Commit 30af93f2 authored by chenpangpang

feat: initial GPU commit

parent 68e98ab8
import math
import numpy as np
import torch
from einops import repeat
def timestep_embedding(time_steps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param time_steps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal projection and simply
        repeat `time_steps` along the embedding dimension.
    :return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=time_steps.device)
args = time_steps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(time_steps, "b -> b d", d=dim)
return embedding
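# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# timestep_embedding maps N scalar timesteps to an [N x dim] tensor: the first
# dim//2 columns are cosines and the next dim//2 are sines at geometrically
# spaced frequencies between 1 and 1/max_period. The `_example_*` helper below
# is an illustrative addition.
def _example_timestep_embedding():
    ts = torch.tensor([0.0, 10.0, 999.0])
    emb = timestep_embedding(ts, dim=128)
    assert emb.shape == (3, 128)
    return emb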
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
elif schedule == "cosine":
time_steps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = time_steps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == "sqrt":
betas = (
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
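# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# Every schedule returns a float64 numpy array of length n_timestep; the
# "linear" schedule interpolates in sqrt-space between linear_start and
# linear_end (the usual DDPM/LDM convention), so betas increase monotonically.
def _example_make_beta_schedule():
    betas = make_beta_schedule("linear", n_timestep=1000)
    assert betas.shape == (1000,)
    assert betas[0] < betas[-1]
    return betas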
def make_ddim_time_steps(
ddim_discr_method, num_ddim_time_steps, num_ddpm_time_steps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_time_steps // num_ddim_time_steps
ddim_time_steps = np.asarray(list(range(0, num_ddpm_time_steps, c)))
steps_out = ddim_time_steps + 1
elif ddim_discr_method == "quad":
ddim_time_steps = (
(np.linspace(0, np.sqrt(num_ddpm_time_steps * 0.8), num_ddim_time_steps))
** 2
).astype(int)
steps_out = ddim_time_steps + 1
elif ddim_discr_method == "uniform_trailing":
c = num_ddpm_time_steps / num_ddim_time_steps
ddim_time_steps = np.flip(
np.round(np.arange(num_ddpm_time_steps, 0, -c))
).astype(np.int64)
steps_out = ddim_time_steps - 1
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_time_steps.shape[0] == num_ddim_time_steps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
if verbose:
print(f"Selected time_steps for ddim sampler: {steps_out}")
return steps_out
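# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# "uniform" keeps every c-th DDPM step and shifts by +1 so that the final alpha
# value is included, e.g. subsampling 1000 DDPM steps to 50 DDIM steps:
def _example_make_ddim_time_steps():
    steps = make_ddim_time_steps("uniform", 50, 1000, verbose=False)
    assert steps.shape == (50,)
    assert steps[0] == 1 and steps[-1] == 981
    return steps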
def make_ddim_sampling_parameters(alphacums, ddim_time_steps, eta, verbose=True):
# select alphas for computing the variance schedule
# print(f'ddim_time_steps={ddim_time_steps}, len_alphacums={len(alphacums)}')
alphas = alphacums[ddim_time_steps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_time_steps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
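# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# The sigmas follow the formula from the DDIM paper (arXiv:2010.02502): with
# eta = 0 the sampler is fully deterministic (all sigmas are zero). The dummy
# `alphacums` array below is only for illustration.
def _example_make_ddim_sampling_parameters():
    alphacums = np.linspace(0.999, 0.01, 1000)  # placeholder cumulative alphas
    steps = make_ddim_time_steps("uniform", 50, 1000, verbose=False)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
        alphacums, steps, eta=0.0, verbose=False
    )
    assert np.allclose(sigmas, 0.0)
    return sigmas, alphas, alphas_prev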
def betas_for_alpha_bar(num_diffusion_time_steps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_time_steps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_time_steps):
t1 = i / num_diffusion_time_steps
t2 = (i + 1) / num_diffusion_time_steps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
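# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# A commonly used alpha_bar is the squared-cosine schedule from "Improved
# Denoising Diffusion Probabilistic Models"; the small offset (s = 0.008)
# prevents beta from being too small near t = 0.
def _example_betas_for_alpha_bar():
    alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    betas = betas_for_alpha_bar(1000, alpha_bar)
    assert betas.shape == (1000,) and betas.max() <= 0.999
    return betas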
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`numpy.ndarray`):
the betas that the scheduler is being initialized with.
Returns:
`numpy.ndarray`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_bar_sqrt = np.sqrt(alphas_cumprod)
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = np.concatenate([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
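# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# After rescaling, the terminal cumulative alpha is zero, i.e. the last timestep
# has zero SNR as required by Algorithm 1 of arXiv:2305.08891.
def _example_rescale_zero_terminal_snr():
    betas = make_beta_schedule("linear", n_timestep=1000)
    betas_rescaled = rescale_zero_terminal_snr(betas)
    alphas_cumprod = np.cumprod(1.0 - betas_rescaled, axis=0)
    assert abs(alphas_cumprod[-1]) < 1e-12
    return betas_rescaled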
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
factor = guidance_rescale * (std_text / std_cfg) + (1 - guidance_rescale)
return noise_cfg * factor
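# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# Typical classifier-free-guidance usage: combine the conditional and
# unconditional noise predictions first, then rescale the result toward the std
# of the text-conditioned branch. The argument names below are placeholders.
def _example_rescale_noise_cfg(noise_uncond, noise_text, guidance_scale=7.5):
    noise_cfg = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    return rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7)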
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from functools import partial
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except ImportError:
XFORMERS_IS_AVAILBLE = False
from core.common import (
gradient_checkpoint,
exists,
default,
)
from core.basics import zero_module
class RelativePosition(nn.Module):
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(
torch.Tensor(max_relative_position * 2 + 1, num_units)
)
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
device = self.embeddings_table.device
range_vec_q = torch.arange(length_q, device=device)
range_vec_k = torch.arange(length_k, device=device)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(
distance_mat, -self.max_relative_position, self.max_relative_position
)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = final_mat.long()
embeddings = self.embeddings_table[final_mat]
return embeddings
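# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# RelativePosition returns one learned embedding per (query, key) offset,
# clipped to +/- max_relative_position, so the output has shape
# [length_q, length_k, num_units].
def _example_relative_position():
    rp = RelativePosition(num_units=64, max_relative_position=16)
    emb = rp(length_q=16, length_k=16)
    assert emb.shape == (16, 16, 64)
    return emb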
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
relative_position=False,
temporal_length=None,
video_length=None,
image_cross_attention=False,
image_cross_attention_scale=1.0,
image_cross_attention_scale_learnable=False,
text_context_len=77,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.relative_position = relative_position
if self.relative_position:
assert temporal_length is not None
self.relative_position_k = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length
)
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length
)
else:
# only used for spatial attention, while NOT for temporal attention
if XFORMERS_IS_AVAILBLE and temporal_length is None:
self.forward = self.efficient_forward
self.video_length = video_length
self.image_cross_attention = image_cross_attention
self.image_cross_attention_scale = image_cross_attention_scale
self.text_context_len = text_context_len
self.image_cross_attention_scale_learnable = (
image_cross_attention_scale_learnable
)
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
if image_cross_attention_scale_learnable:
self.register_parameter("alpha", nn.Parameter(torch.tensor(0.0)))
def forward(self, x, context=None, mask=None):
spatial_self_attn = context is None
k_ip, v_ip, out_ip = None, None, None
h = self.heads
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
context, context_image = (
context[:, : self.text_context_len, :],
context[:, self.text_context_len :, :],
)
k = self.to_k(context)
v = self.to_v(context)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
else:
if not spatial_self_attn:
context = context[:, : self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
if self.relative_position:
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
k2 = self.relative_position_k(len_q, len_k)
sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
sim += sim2
del k
if exists(mask):
# feasible for causal attention mask only
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b i j -> (b h) i j", h=h)
sim.masked_fill_(~(mask > 0.5), max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", sim, v)
if self.relative_position:
v2 = self.relative_position_v(len_q, len_v)
out2 = einsum("b t s, t s d -> b t d", sim, v2)
out += out2
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# for image cross-attention
if k_ip is not None:
k_ip, v_ip = map(
lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_ip, v_ip)
)
sim_ip = torch.einsum("b i d, b j d -> b i j", q, k_ip) * self.scale
del k_ip
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum("b i j, b j d -> b i d", sim_ip, v_ip)
out_ip = rearrange(out_ip, "(b h) n d -> b n (h d)", h=h)
if out_ip is not None:
if self.image_cross_attention_scale_learnable:
out = out + self.image_cross_attention_scale * out_ip * (
torch.tanh(self.alpha) + 1
)
else:
out = out + self.image_cross_attention_scale * out_ip
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
spatial_self_attn = context is None
k_ip, v_ip, out_ip = None, None, None
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
context, context_image = (
context[:, : self.text_context_len, :],
context[:, self.text_context_len :, :],
)
k = self.to_k(context)
v = self.to_v(context)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
else:
if not spatial_self_attn:
context = context[:, : self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
# for image cross-attention
if k_ip is not None:
k_ip, v_ip = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(k_ip, v_ip),
)
out_ip = xformers.ops.memory_efficient_attention(
q, k_ip, v_ip, attn_bias=None, op=None
)
out_ip = (
out_ip.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
if out_ip is not None:
if self.image_cross_attention_scale_learnable:
out = out + self.image_cross_attention_scale * out_ip * (
torch.tanh(self.alpha) + 1
)
else:
out = out + self.image_cross_attention_scale * out_ip
return self.to_out(out)
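# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# CrossAttention acts as self-attention when `context` is None and as
# cross-attention otherwise; the output token count always matches the query.
# Note that when xformers is available the efficient path is selected, which
# typically expects CUDA tensors.
def _example_cross_attention():
    attn = CrossAttention(query_dim=320, context_dim=1024, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)     # (batch, query tokens, query_dim)
    ctx = torch.randn(2, 77, 1024)  # (batch, text tokens, context_dim)
    out = attn(x, context=ctx)
    assert out.shape == x.shape
    return out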
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attention_cls=None,
video_length=None,
image_cross_attention=False,
image_cross_attention_scale=1.0,
image_cross_attention_scale_learnable=False,
text_context_len=77,
enable_lora=False,
):
super().__init__()
attn_cls = CrossAttention if attention_cls is None else attention_cls
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
)
self.ff = FeedForward(
dim, dropout=dropout, glu=gated_ff, enable_lora=enable_lora
)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
video_length=video_length,
image_cross_attention=image_cross_attention,
image_cross_attention_scale=image_cross_attention_scale,
image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
text_context_len=text_context_len,
)
self.image_cross_attention = image_cross_attention
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
self.enable_lora = enable_lora
def forward(self, x, context=None, mask=None, with_lora=False, **kwargs):
        # implementation trick: gradient checkpointing does not support non-tensor
        # (e.g. None or scalar) arguments, so only tensors go into input_tuple.
        # Note it must be (x,) and not (x); otherwise *input_tuple would unpack x
        # itself into multiple arguments.
input_tuple = (x,)
if context is not None:
input_tuple = (x, context)
        if mask is not None:
            _forward = partial(self._forward, mask=mask, with_lora=with_lora)
        else:
            _forward = partial(self._forward, mask=None, with_lora=with_lora)
return gradient_checkpoint(
_forward, input_tuple, self.parameters(), self.checkpoint
)
def _forward(self, x, context=None, mask=None, with_lora=False):
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
mask=mask,
)
+ x
)
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x), with_lora=with_lora) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data in spatial axis.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
use_checkpoint=True,
disable_self_attn=False,
use_linear=False,
video_length=None,
image_cross_attention=False,
image_cross_attention_scale_learnable=False,
enable_lora=False,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.enable_lora = enable_lora
attention_cls = None
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
attention_cls=attention_cls,
video_length=video_length,
image_cross_attention=image_cross_attention,
image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
enable_lora=self.enable_lora,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None, with_lora=False, **kwargs):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, with_lora=with_lora, **kwargs)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
class TemporalTransformer(nn.Module):
"""
Transformer block for image-like data in temporal axis.
First, reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
use_checkpoint=True,
use_linear=False,
only_self_att=True,
causal_attention=False,
causal_block_size=1,
relative_position=False,
temporal_length=None,
use_extra_spatial_temporal_self_attention=False,
enable_lora=False,
full_spatial_temporal_attention=False,
enhance_multi_view_correspondence=False,
):
super().__init__()
self.only_self_att = only_self_att
self.relative_position = relative_position
self.causal_attention = causal_attention
self.causal_block_size = causal_block_size
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.proj_in = nn.Conv1d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
if not use_linear:
self.proj_in = nn.Conv1d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
if relative_position:
assert temporal_length is not None
attention_cls = partial(
CrossAttention, relative_position=True, temporal_length=temporal_length
)
else:
attention_cls = partial(CrossAttention, temporal_length=temporal_length)
if self.causal_attention:
assert temporal_length is not None
self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
if self.only_self_att:
context_dim = None
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
attention_cls=attention_cls,
checkpoint=use_checkpoint,
enable_lora=enable_lora,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
self.use_extra_spatial_temporal_self_attention = (
use_extra_spatial_temporal_self_attention
)
if use_extra_spatial_temporal_self_attention:
from core.modules.attention_mv import MultiViewSelfAttentionTransformer
self.extra_spatial_time_self_attention = MultiViewSelfAttentionTransformer(
in_channels=in_channels,
n_heads=n_heads,
d_head=d_head,
num_views=temporal_length,
depth=depth,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
full_spatial_temporal_attention=full_spatial_temporal_attention,
enhance_multi_view_correspondence=enhance_multi_view_correspondence,
)
def forward(self, x, context=None, with_lora=False, time_steps=None):
b, c, t, h, w = x.shape
x_in = x
x = self.norm(x)
x = rearrange(x, "b c t h w -> (b h w) c t").contiguous()
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "bhw c t -> bhw t c").contiguous()
if self.use_linear:
x = self.proj_in(x)
temp_mask = None
if self.causal_attention:
            # slice the causal mask to the current temporal length
temp_mask = self.mask[:, :t, :t].to(x.device)
if temp_mask is not None:
mask = temp_mask.to(x.device)
mask = repeat(mask, "l i j -> (l bhw) i j", bhw=b * h * w)
else:
mask = None
if self.only_self_att:
# note: if no context is given, cross-attention defaults to self-attention
for i, block in enumerate(self.transformer_blocks):
x = block(x, mask=mask, with_lora=with_lora)
x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
else:
x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
context = rearrange(context, "(b t) l con -> b t l con", t=t).contiguous()
for i, block in enumerate(self.transformer_blocks):
                # process each batch element separately (some kernels cannot handle a batch dimension larger than 65,535)
for j in range(b):
context_j = repeat(
context[j], "t l con -> (t r) l con", r=(h * w) // t, t=t
).contiguous()
                    # note: the causal mask is not applied in the cross-attention case
x[j] = block(x[j], context=context_j, with_lora=with_lora)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) t c -> b c t h w", h=h, w=w).contiguous()
if not self.use_linear:
x = rearrange(x, "b hw t c -> (b hw) c t").contiguous()
x = self.proj_out(x)
x = rearrange(x, "(b h w) c t -> b c t h w", b=b, h=h, w=w).contiguous()
res = x + x_in
if self.use_extra_spatial_temporal_self_attention:
res = rearrange(res, "b c t h w -> (b t) c h w", b=b, h=h, w=w).contiguous()
res = self.extra_spatial_time_self_attention(res, time_steps=time_steps)
res = rearrange(res, "(b t) c h w -> b c t h w", b=b, h=h, w=w).contiguous()
return res
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
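# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# GEGLU projects to 2 * dim_out and gates one half with GELU of the other, so
# the output dimension is dim_out.
def _example_geglu():
    geglu = GEGLU(dim_in=320, dim_out=1280)
    x = torch.randn(2, 64, 320)
    assert geglu(x).shape == (2, 64, 1280)
    return geglu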
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out=None,
mult=4,
glu=False,
dropout=0.0,
enable_lora=False,
lora_rank=32,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
self.enable_lora = enable_lora
self.lora_rank = lora_rank
self.lora_alpha = 16
if self.enable_lora:
assert (
self.lora_rank is not None
), "`lora_rank` must be given when `enable_lora` is True."
assert (
0 < self.lora_rank < min(dim, dim_out)
), f"`lora_rank` must be range [0, min(inner_dim={inner_dim}, dim_out={dim_out})], but got {self.lora_rank}."
self.lora_a = nn.Parameter(
torch.zeros((inner_dim, self.lora_rank), requires_grad=True)
)
self.lora_b = nn.Parameter(
torch.zeros((self.lora_rank, dim_out), requires_grad=True)
)
self.scaling = self.lora_alpha / self.lora_rank
def forward(self, x, with_lora=False):
if with_lora:
projected_x = self.net[1](self.net[0](x))
lora_x = (
torch.matmul(projected_x, torch.matmul(self.lora_a, self.lora_b))
* self.scaling
)
original_x = self.net[2](projected_x)
return original_x + lora_x
else:
return self.net(x)
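# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# With with_lora=True the output is net(x) plus a rank-`lora_rank` update
# (projected_x @ lora_a @ lora_b * alpha/r) applied in parallel to the final
# Linear layer. Since lora_a and lora_b are zero-initialized, the LoRA branch
# starts as a no-op.
def _example_feedforward_lora():
    ff = FeedForward(dim=320, glu=True, enable_lora=True, lora_rank=16)
    x = torch.randn(2, 64, 320)
    assert torch.allclose(ff(x, with_lora=True), ff(x, with_lora=False))
    return ff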
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from core.common import gradient_checkpoint
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except ImportError:
XFORMERS_IS_AVAILBLE = False
print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")
def get_group_norm_layer(in_channels):
if in_channels < 32:
if in_channels % 2 == 0:
num_groups = in_channels // 2
else:
num_groups = in_channels
else:
num_groups = 32
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
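# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# GroupNorm requires num_channels % num_groups == 0; for channel counts below
# 32 the helper falls back to in_channels // 2 groups (even counts) or one
# group per channel (odd counts).
def _example_get_group_norm_layer():
    assert get_group_norm_layer(64).num_groups == 32
    assert get_group_norm_layer(20).num_groups == 10
    assert get_group_norm_layer(15).num_groups == 15
    return get_group_norm_layer(64)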
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
if dim_out is None:
dim_out = dim
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class SpatialTemporalAttention(nn.Module):
"""Uses xformers to implement efficient epipolar masking for cross-attention between views."""
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
if context_dim is None:
context_dim = query_dim
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op = None
def forward(self, x, context=None, enhance_multi_view_correspondence=False):
q = self.to_q(x)
if context is None:
context = x
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
if enhance_multi_view_correspondence:
with torch.no_grad():
normalized_x = torch.nn.functional.normalize(x.detach(), p=2, dim=-1)
cosine_sim_map = torch.bmm(normalized_x, normalized_x.transpose(-1, -2))
attn_bias = torch.where(cosine_sim_map > 0.0, 0.0, -1e9).to(
dtype=q.dtype
)
attn_bias = rearrange(
attn_bias.unsqueeze(1).expand(-1, self.heads, -1, -1),
"b h d1 d2 -> (b h) d1 d2",
).detach()
else:
attn_bias = None
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=attn_bias, op=self.attention_op
)
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
del q, k, v, attn_bias
return self.to_out(out)
class MultiViewSelfAttentionTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
gated_ff=True,
use_checkpoint=True,
full_spatial_temporal_attention=False,
enhance_multi_view_correspondence=False,
):
super().__init__()
attn_cls = SpatialTemporalAttention
# self.self_attention_only = self_attention_only
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
if enhance_multi_view_correspondence:
            # Zero initialization when MVCorr is enabled.
zero_module_fn = zero_module
else:
def zero_module_fn(x):
return x
self.attn2 = zero_module_fn(
attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=None,
)
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint
self.full_spatial_temporal_attention = full_spatial_temporal_attention
self.enhance_multi_view_correspondence = enhance_multi_view_correspondence
def forward(self, x, time_steps=None):
return gradient_checkpoint(
self.many_stream_forward, (x, time_steps), None, flag=self.use_checkpoint
)
def many_stream_forward(self, x, time_steps=None):
n, v, hw = x.shape[:3]
x = rearrange(x, "n v hw c -> n (v hw) c")
x = (
self.attn1(
self.norm1(x), context=None, enhance_multi_view_correspondence=False
)
+ x
)
if not self.full_spatial_temporal_attention:
x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
x = rearrange(x, "n v hw c -> (n v) hw c")
x = (
self.attn2(
self.norm2(x),
context=None,
enhance_multi_view_correspondence=self.enhance_multi_view_correspondence
and hw <= 256,
)
+ x
)
x = self.ff(self.norm3(x)) + x
if self.full_spatial_temporal_attention:
x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
else:
x = rearrange(x, "(n v) hw c -> n v hw c", v=v)
return x
class MultiViewSelfAttentionTransformer(nn.Module):
"""Spatial Transformer block with post init to add cross attn."""
def __init__(
self,
in_channels,
n_heads,
d_head,
num_views,
depth=1,
dropout=0.0,
use_linear=True,
use_checkpoint=True,
zero_out_initialization=True,
full_spatial_temporal_attention=False,
enhance_multi_view_correspondence=False,
):
super().__init__()
self.num_views = num_views
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = get_group_norm_layer(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
MultiViewSelfAttentionTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
use_checkpoint=use_checkpoint,
full_spatial_temporal_attention=full_spatial_temporal_attention,
enhance_multi_view_correspondence=enhance_multi_view_correspondence,
)
for d in range(depth)
]
)
self.zero_out_initialization = zero_out_initialization
if zero_out_initialization:
_zero_func = zero_module
else:
def _zero_func(x):
return x
if not use_linear:
self.proj_out = _zero_func(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = _zero_func(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, time_steps=None):
# x: bt c h w
_, c, h, w = x.shape
n_views = self.num_views
x_in = x
x = self.norm(x)
x = rearrange(x, "(n v) c h w -> n v (h w) c", v=n_views)
if self.use_linear:
x = rearrange(x, "n v x c -> (n v) x c")
x = self.proj_in(x)
x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
for i, block in enumerate(self.transformer_blocks):
x = block(x, time_steps=time_steps)
if self.use_linear:
x = rearrange(x, "n v x c -> (n v) x c")
x = self.proj_out(x)
x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
x = rearrange(x, "n v (h w) c -> (n v) c h w", h=h, w=w).contiguous()
return x + x_in
import math
import torch
import torch as th
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except ImportError:
XFORMERS_IS_AVAILBLE = False
from core.common import gradient_checkpoint, exists, default
from core.basics import conv_nd, zero_module, normalization
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class RelativePosition(nn.Module):
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(
th.Tensor(max_relative_position * 2 + 1, num_units)
)
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
device = self.embeddings_table.device
range_vec_q = th.arange(length_q, device=device)
range_vec_k = th.arange(length_k, device=device)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = th.clamp(
distance_mat, -self.max_relative_position, self.max_relative_position
)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = final_mat.long()
embeddings = self.embeddings_table[final_mat]
return embeddings
class TemporalCrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
# For relative positional representation and image-video joint training.
temporal_length=None,
image_length=None, # For image-video joint training.
        # whether to use relative positional representation in temporal attention.
use_relative_position=False,
# For image-video joint training.
img_video_joint_train=False,
use_tempoal_causal_attn=False,
bidirectional_causal_attn=False,
tempoal_attn_type=None,
joint_train_mode="same_batch",
**kwargs,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.context_dim = context_dim
self.scale = dim_head**-0.5
self.heads = heads
self.temporal_length = temporal_length
self.use_relative_position = use_relative_position
self.img_video_joint_train = img_video_joint_train
self.bidirectional_causal_attn = bidirectional_causal_attn
self.joint_train_mode = joint_train_mode
assert joint_train_mode in ["same_batch", "diff_batch"]
self.tempoal_attn_type = tempoal_attn_type
if bidirectional_causal_attn:
assert use_tempoal_causal_attn
if tempoal_attn_type:
assert tempoal_attn_type in ["sparse_causal", "sparse_causal_first"]
assert not use_tempoal_causal_attn
assert not (
img_video_joint_train and (self.joint_train_mode == "same_batch")
)
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
assert not (
img_video_joint_train
and (self.joint_train_mode == "same_batch")
and use_tempoal_causal_attn
)
if img_video_joint_train:
if self.joint_train_mode == "same_batch":
mask = torch.ones(
[1, temporal_length + image_length, temporal_length + image_length]
)
mask[:, temporal_length:, :] = 0
mask[:, :, temporal_length:] = 0
self.mask = mask
else:
self.mask = None
elif use_tempoal_causal_attn:
# normal causal attn
self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
elif tempoal_attn_type == "sparse_causal":
# true indicates keeping
mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()
# initialize to same shape with mask1
mask2 = torch.zeros([1, temporal_length, temporal_length])
mask2[:, 2:temporal_length, : temporal_length - 2] = torch.tril(
torch.ones([1, temporal_length - 2, temporal_length - 2])
)
mask2 = (1 - mask2).bool() # false indicates masking
self.mask = mask1 & mask2
elif tempoal_attn_type == "sparse_causal_first":
# true indicates keeping
mask1 = torch.tril(torch.ones([1, temporal_length, temporal_length])).bool()
mask2 = torch.zeros([1, temporal_length, temporal_length])
mask2[:, 2:temporal_length, 1 : temporal_length - 1] = torch.tril(
torch.ones([1, temporal_length - 2, temporal_length - 2])
)
mask2 = (1 - mask2).bool() # false indicates masking
self.mask = mask1 & mask2
else:
self.mask = None
if use_relative_position:
assert temporal_length is not None
self.relative_position_k = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length
)
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
nn.init.constant_(self.to_q.weight, 0)
nn.init.constant_(self.to_k.weight, 0)
nn.init.constant_(self.to_v.weight, 0)
nn.init.constant_(self.to_out[0].weight, 0)
nn.init.constant_(self.to_out[0].bias, 0)
def forward(self, x, context=None, mask=None):
nh = self.heads
out = x
q = self.to_q(out)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if self.use_relative_position:
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
k2 = self.relative_position_k(len_q, len_k)
sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale
sim += sim2
if exists(self.mask):
if mask is None:
mask = self.mask.to(sim.device)
else:
# .to(sim.device)
mask = self.mask.to(sim.device).bool() & mask
else:
mask = mask
if mask is not None:
max_neg_value = -1e9
            sim = sim + (1 - mask.float()) * max_neg_value  # mask == 1 keeps a position, mask == 0 masks it out
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
if self.bidirectional_causal_attn:
mask_reverse = torch.triu(
torch.ones(
[1, self.temporal_length, self.temporal_length], device=sim.device
)
)
sim_reverse = sim.float().masked_fill(mask_reverse == 0, max_neg_value)
attn_reverse = sim_reverse.softmax(dim=-1)
out_reverse = einsum("b i j, b j d -> b i d", attn_reverse, v)
out += out_reverse
if self.use_relative_position:
v2 = self.relative_position_v(len_q, len_v)
out2 = einsum("b t s, t s d -> b t d", attn, v2)
out += out2
out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)
return self.to_out(out)
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
sa_shared_kv=False,
shared_type="only_first",
**kwargs,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.sa_shared_kv = sa_shared_kv
assert shared_type in [
"only_first",
"all_frames",
"first_and_prev",
"only_prev",
"full",
"causal",
"full_qkv",
]
self.shared_type = shared_type
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
if XFORMERS_IS_AVAILBLE:
self.forward = self.efficient_forward
def forward(self, x, context=None, mask=None):
h = self.heads
b = x.shape[0]
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if self.sa_shared_kv:
if self.shared_type == "only_first":
k, v = map(
lambda xx: rearrange(xx[0].unsqueeze(0), "b n c -> (b n) c")
.unsqueeze(0)
.repeat(b, 1, 1),
(k, v),
)
else:
raise NotImplementedError
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class VideoSpatialCrossAttention(CrossAttention):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
super().__init__(query_dim, context_dim, heads, dim_head, dropout)
    def forward(self, x, context=None, mask=None):
        b, c, t, h, w = x.shape
        if context is not None:
            context = context.repeat(t, 1, 1)
        x_reshaped = spatial_attn_reshape(x)
        x_reshaped = super().forward(x_reshaped, context=context) + x_reshaped
        return spatial_attn_reshape_back(x_reshaped, b, h)
class BasicTransformerBlockST(nn.Module):
def __init__(
self,
# Spatial Stuff
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
# Temporal Stuff
temporal_length=None,
image_length=None,
use_relative_position=True,
img_video_joint_train=False,
cross_attn_on_tempoal=False,
temporal_crossattn_type="selfattn",
order="stst",
temporalcrossfirst=False,
temporal_context_dim=None,
split_stcontext=False,
local_spatial_temporal_attn=False,
window_size=2,
**kwargs,
):
super().__init__()
# Self attention
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
**kwargs,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
# cross attention if context is not None
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
**kwargs,
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
self.order = order
assert self.order in ["stst", "sstt", "st_parallel"]
self.temporalcrossfirst = temporalcrossfirst
self.split_stcontext = split_stcontext
self.local_spatial_temporal_attn = local_spatial_temporal_attn
if self.local_spatial_temporal_attn:
            assert self.order == "stst"
self.window_size = window_size
if not split_stcontext:
temporal_context_dim = context_dim
# Temporal attention
assert temporal_crossattn_type in ["selfattn", "crossattn", "skip"]
self.temporal_crossattn_type = temporal_crossattn_type
self.attn1_tmp = TemporalCrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
temporal_length=temporal_length,
image_length=image_length,
use_relative_position=use_relative_position,
img_video_joint_train=img_video_joint_train,
**kwargs,
)
self.attn2_tmp = TemporalCrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
# cross attn
context_dim=(
temporal_context_dim if temporal_crossattn_type == "crossattn" else None
),
# temporal attn
temporal_length=temporal_length,
image_length=image_length,
use_relative_position=use_relative_position,
img_video_joint_train=img_video_joint_train,
**kwargs,
)
self.norm4 = nn.LayerNorm(dim)
self.norm5 = nn.LayerNorm(dim)
def forward(
self,
x,
context=None,
temporal_context=None,
no_temporal_attn=None,
attn_mask=None,
**kwargs,
):
        if not self.split_stcontext:
            # spatial and temporal cross attention share the same context vector
            temporal_context = (
                context.detach().clone() if context is not None else None
            )
if context is None and temporal_context is None:
# self-attention models
if no_temporal_attn:
raise NotImplementedError
return gradient_checkpoint(
                self._forward_nocontext, (x,), self.parameters(), self.checkpoint
)
else:
# cross-attention models
if no_temporal_attn:
forward_func = self._forward_no_temporal_attn
else:
forward_func = self._forward
inputs = (
(x, context, temporal_context)
if temporal_context is not None
else (x, context)
)
return gradient_checkpoint(
forward_func, inputs, self.parameters(), self.checkpoint
)
def _forward(
self,
x,
context=None,
temporal_context=None,
mask=None,
no_temporal_attn=None,
):
assert x.dim() == 5, f"x shape = {x.shape}"
b, c, t, h, w = x.shape
if self.order in ["stst", "sstt"]:
x = self._st_cross_attn(
x,
context,
temporal_context=temporal_context,
order=self.order,
mask=mask,
) # no_temporal_attn=no_temporal_attn,
elif self.order == "st_parallel":
x = self._st_cross_attn_parallel(
x,
context,
temporal_context=temporal_context,
order=self.order,
) # no_temporal_attn=no_temporal_attn,
else:
raise NotImplementedError
x = self.ff(self.norm3(x)) + x
if (no_temporal_attn is None) or (not no_temporal_attn):
x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
elif no_temporal_attn:
x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
return x
def _forward_no_temporal_attn(
self,
x,
context=None,
temporal_context=None,
):
assert x.dim() == 5, f"x shape = {x.shape}"
b, c, t, h, w = x.shape
if self.order in ["stst", "sstt"]:
mask = torch.zeros([1, t, t], device=x.device).bool()
x = self._st_cross_attn(
x,
context,
temporal_context=temporal_context,
order=self.order,
mask=mask,
)
elif self.order == "st_parallel":
x = self._st_cross_attn_parallel(
x,
context,
temporal_context=temporal_context,
order=self.order,
no_temporal_attn=True,
)
else:
raise NotImplementedError
x = self.ff(self.norm3(x)) + x
x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
return x
def _forward_nocontext(self, x, no_temporal_attn=None):
assert x.dim() == 5, f"x shape = {x.shape}"
b, c, t, h, w = x.shape
if self.order in ["stst", "sstt"]:
x = self._st_cross_attn(
x, order=self.order, no_temporal_attn=no_temporal_attn
)
elif self.order == "st_parallel":
x = self._st_cross_attn_parallel(
x, order=self.order, no_temporal_attn=no_temporal_attn
)
else:
raise NotImplementedError
x = self.ff(self.norm3(x)) + x
x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
return x
    def _st_cross_attn(
        self, x, context=None, temporal_context=None, order="stst", mask=None,
        no_temporal_attn=None,
    ):
b, c, t, h, w = x.shape
if order == "stst":
x = rearrange(x, "b c t h w -> (b t) (h w) c")
x = self.attn1(self.norm1(x)) + x
x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
if self.local_spatial_temporal_attn:
x = local_spatial_temporal_attn_reshape(x, window_size=self.window_size)
else:
x = rearrange(x, "b c t h w -> (b h w) t c")
x = self.attn1_tmp(self.norm4(x), mask=mask) + x
if self.local_spatial_temporal_attn:
x = local_spatial_temporal_attn_reshape_back(
x, window_size=self.window_size, b=b, h=h, w=w, t=t
)
else:
x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w) # 3d -> 5d
# spatial cross attention
x = rearrange(x, "b c t h w -> (b t) (h w) c")
if context is not None:
                if context.shape[0] == t:
                    # one context entry per frame (e.g. image captions); use as-is
                    context_ = context
else:
context_ = []
for i in range(context.shape[0]):
context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
context_ = torch.cat(context_, dim=0)
else:
context_ = None
x = self.attn2(self.norm2(x), context=context_) + x
# temporal cross attention
# if (no_temporal_attn is None) or (not no_temporal_attn):
x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
x = rearrange(x, "b c t h w -> (b h w) t c")
if self.temporal_crossattn_type == "crossattn":
                # temporal cross attention
if temporal_context is not None:
# print(f'STATTN context={context.shape}, temporal_context={temporal_context.shape}')
temporal_context = torch.cat(
[context, temporal_context], dim=1
) # blc
# print(f'STATTN after concat temporal_context={temporal_context.shape}')
temporal_context = temporal_context.repeat(h * w, 1, 1)
# print(f'after repeat temporal_context={temporal_context.shape}')
else:
temporal_context = context[0:1, ...].repeat(h * w, 1, 1)
# print(f'STATTN after concat x={x.shape}')
x = (
self.attn2_tmp(self.norm5(x), context=temporal_context, mask=mask)
+ x
)
elif self.temporal_crossattn_type == "selfattn":
# temporal self attention
x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
elif self.temporal_crossattn_type == "skip":
# no temporal cross and self attention
pass
else:
raise NotImplementedError
elif order == "sstt":
# spatial self attention
x = rearrange(x, "b c t h w -> (b t) (h w) c")
x = self.attn1(self.norm1(x)) + x
# spatial cross attention
context_ = context.repeat(t, 1, 1) if context is not None else None
x = self.attn2(self.norm2(x), context=context_) + x
x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
if (no_temporal_attn is None) or (not no_temporal_attn):
if self.temporalcrossfirst:
# temporal cross attention
if self.temporal_crossattn_type == "crossattn":
# if temporal_context is not None:
temporal_context = context.repeat(h * w, 1, 1)
x = (
self.attn2_tmp(
self.norm5(x), context=temporal_context, mask=mask
)
+ x
)
elif self.temporal_crossattn_type == "selfattn":
x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
elif self.temporal_crossattn_type == "skip":
pass
else:
raise NotImplementedError
# temporal self attention
x = rearrange(x, "b c t h w -> (b h w) t c")
x = self.attn1_tmp(self.norm4(x), mask=mask) + x
else:
# temporal self attention
x = rearrange(x, "b c t h w -> (b h w) t c")
x = self.attn1_tmp(self.norm4(x), mask=mask) + x
# temporal cross attention
if self.temporal_crossattn_type == "crossattn":
if temporal_context is not None:
temporal_context = context.repeat(h * w, 1, 1)
x = (
self.attn2_tmp(
self.norm5(x), context=temporal_context, mask=mask
)
+ x
)
elif self.temporal_crossattn_type == "selfattn":
x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
elif self.temporal_crossattn_type == "skip":
pass
else:
raise NotImplementedError
else:
raise NotImplementedError
return x
def _st_cross_attn_parallel(
self, x, context=None, temporal_context=None, order="sst", no_temporal_attn=None
):
"""order: x -> Self Attn -> Cross Attn -> attn_s
x -> Temp Self Attn -> attn_t
x' = x + attn_s + attn_t
"""
if no_temporal_attn is not None:
raise NotImplementedError
B, C, T, H, W = x.shape
# spatial self attention
h = x
h = rearrange(h, "b c t h w -> (b t) (h w) c")
h = self.attn1(self.norm1(h)) + h
# spatial cross
# context_ = context.repeat(T, 1, 1) if context is not None else None
if context is not None:
context_ = []
for i in range(context.shape[0]):
context_.append(context[i].unsqueeze(0).repeat(T, 1, 1))
context_ = torch.cat(context_, dim=0)
else:
context_ = None
h = self.attn2(self.norm2(h), context=context_) + h
h = rearrange(h, "(b t) (h w) c -> b c t h w", b=B, h=H)
# temporal self
h2 = x
h2 = rearrange(h2, "b c t h w -> (b h w) t c")
h2 = self.attn1_tmp(self.norm4(h2)) # + h2
h2 = rearrange(h2, "(b h w) t c -> b c t h w", b=B, h=H, w=W)
out = h + h2
return rearrange(out, "b c t h w -> (b h w) t c")
def spatial_attn_reshape(x):
return rearrange(x, "b c t h w -> (b t) (h w) c")
def spatial_attn_reshape_back(x, b, h):
return rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
def temporal_attn_reshape(x):
return rearrange(x, "b c t h w -> (b h w) t c")
def temporal_attn_reshape_back(x, b, h, w):
return rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
def local_spatial_temporal_attn_reshape(x, window_size):
B, C, T, H, W = x.shape
NH = H // window_size
NW = W // window_size
# x = x.view(B, C, T, NH, window_size, NW, window_size)
# tokens = x.permute(0, 1, 2, 3, 5, 4, 6).contiguous()
# tokens = tokens.view(-1, window_size, window_size, C)
    # -> B, C, T, NH, window_size, NW, window_size
    x = rearrange(
        x,
        "b c t (nh wh) (nw ww) -> b c t nh wh nw ww",
        nh=NH,
        nw=NW,
        wh=window_size,
        ww=window_size,
    ).contiguous()
# (B, NH, NW) (T, window_size, window_size) C
x = rearrange(x, "b c t nh wh nw ww -> (b nh nw) (t wh ww) c")
return x
def local_spatial_temporal_attn_reshape_back(x, window_size, b, h, w, t):
B, L, C = x.shape
NH = h // window_size
NW = w // window_size
x = rearrange(
x,
"(b nh nw) (t wh ww) c -> b c t nh wh nw ww",
b=b,
nh=NH,
nw=NW,
t=t,
wh=window_size,
ww=window_size,
)
x = rearrange(x, "b c t nh wh nw ww -> b c t (nh wh) (nw ww)")
return x
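# --- Illustrative sketch (added for clarity; not part of the original commit) ---
# The two helpers above are inverses of each other: features are grouped into
# non-overlapping window_size x window_size spatial windows, and attention then
# runs over (t * window_size * window_size) tokens inside each window.
def _example_local_spatial_temporal_reshape():
    x = torch.randn(2, 8, 4, 16, 16)  # b c t h w
    tokens = local_spatial_temporal_attn_reshape(x, window_size=2)
    assert tokens.shape == (2 * 8 * 8, 4 * 2 * 2, 8)
    x_back = local_spatial_temporal_attn_reshape_back(
        tokens, window_size=2, b=2, h=16, w=16, t=4
    )
    assert torch.equal(x_back, x)
    return tokens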
class SpatialTemporalTransformer(nn.Module):
"""
Transformer block for video-like data (5D tensor).
First, project the input (aka embedding) with NO reshape.
Then apply standard transformer action.
The 5D -> 3D reshape operation will be done in the specific attention module.
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
# Temporal stuff
temporal_length=None,
image_length=None,
use_relative_position=True,
img_video_joint_train=False,
cross_attn_on_tempoal=False,
temporal_crossattn_type=False,
order="stst",
temporalcrossfirst=False,
split_stcontext=False,
temporal_context_dim=None,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv3d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlockST(
inner_dim,
n_heads,
d_head,
dropout=dropout,
# cross attn
context_dim=context_dim,
# temporal attn
temporal_length=temporal_length,
image_length=image_length,
use_relative_position=use_relative_position,
img_video_joint_train=img_video_joint_train,
temporal_crossattn_type=temporal_crossattn_type,
order=order,
temporalcrossfirst=temporalcrossfirst,
split_stcontext=split_stcontext,
temporal_context_dim=temporal_context_dim,
**kwargs,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv3d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None, temporal_context=None, **kwargs):
# note: if no context is given, cross-attention defaults to self-attention
assert x.dim() == 5, f"x shape = {x.shape}"
b, c, t, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
for block in self.transformer_blocks:
x = block(x, context=context, temporal_context=temporal_context, **kwargs)
x = self.proj_out(x)
return x + x_in
class STAttentionBlock2(nn.Module):
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False, # not used, only used in ResBlock
use_new_attention_order=False, # QKVAttention or QKVAttentionLegacy
temporal_length=16, # used in relative positional representation.
image_length=8, # used for image-video joint training.
        # whether to use relative positional representation in temporal attention.
use_relative_position=False,
img_video_joint_train=False,
# norm_type="groupnorm",
attn_norm_type="group",
use_tempoal_causal_attn=False,
):
"""
version 1: guided_diffusion implemented version
version 2: remove args input argument
"""
super().__init__()
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.temporal_length = temporal_length
self.image_length = image_length
self.use_relative_position = use_relative_position
self.img_video_joint_train = img_video_joint_train
self.attn_norm_type = attn_norm_type
assert self.attn_norm_type in ["group", "no_norm"]
self.use_tempoal_causal_attn = use_tempoal_causal_attn
if self.attn_norm_type == "group":
self.norm_s = normalization(channels)
self.norm_t = normalization(channels)
self.qkv_s = conv_nd(1, channels, channels * 3, 1)
self.qkv_t = conv_nd(1, channels, channels * 3, 1)
if self.img_video_joint_train:
mask = th.ones(
[1, temporal_length + image_length, temporal_length + image_length]
)
mask[:, temporal_length:, :] = 0
mask[:, :, temporal_length:] = 0
self.register_buffer("mask", mask)
else:
self.mask = None
if use_new_attention_order:
# split qkv before split heads
self.attention_s = QKVAttention(self.num_heads)
self.attention_t = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention_s = QKVAttentionLegacy(self.num_heads)
self.attention_t = QKVAttentionLegacy(self.num_heads)
if use_relative_position:
self.relative_position_k = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=temporal_length,
)
self.relative_position_v = RelativePosition(
num_units=channels // self.num_heads,
max_relative_position=temporal_length,
)
self.proj_out_s = zero_module(
# conv_dim, in_channels, out_channels, kernel_size
conv_nd(1, channels, channels, 1)
)
self.proj_out_t = zero_module(
# conv_dim, in_channels, out_channels, kernel_size
conv_nd(1, channels, channels, 1)
)
def forward(self, x, mask=None):
b, c, t, h, w = x.shape
# spatial
out = rearrange(x, "b c t h w -> (b t) c (h w)")
if self.attn_norm_type == "no_norm":
qkv = self.qkv_s(out)
else:
qkv = self.qkv_s(self.norm_s(out))
out = self.attention_s(qkv)
out = self.proj_out_s(out)
out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
        x = x + out  # avoid modifying the input tensor in-place
# temporal
out = rearrange(x, "b c t h w -> (b h w) c t")
if self.attn_norm_type == "no_norm":
qkv = self.qkv_t(out)
else:
qkv = self.qkv_t(self.norm_t(out))
# relative positional embedding
if self.use_relative_position:
len_q = qkv.size()[-1]
len_k, len_v = len_q, len_q
k_rp = self.relative_position_k(len_q, len_k)
v_rp = self.relative_position_v(len_q, len_v) # [T,T,head_dim]
out = self.attention_t(
qkv,
rp=(k_rp, v_rp),
mask=self.mask,
use_tempoal_causal_attn=self.use_tempoal_causal_attn,
)
else:
out = self.attention_t(
qkv,
rp=None,
mask=self.mask,
use_tempoal_causal_attn=self.use_tempoal_causal_attn,
)
out = self.proj_out_t(out)
out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
return x + out
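# Illustrative usage sketch (not referenced elsewhere in this file): STAttentionBlock2
# applies spatial then temporal self-attention and preserves the (b, c, t, h, w) shape.
def _st_attention_block2_shape_demo():
    block = STAttentionBlock2(channels=64, num_heads=4, use_new_attention_order=True)
    x = th.randn(2, 64, 8, 16, 16)  # (batch, channels, frames, height, width)
    out = block(x)
    assert out.shape == x.shape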
class QKVAttentionLegacy(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
    def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
        if rp is not None or mask is not None or use_tempoal_causal_attn:
            raise NotImplementedError
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, rp=None, mask=None, use_tempoal_causal_attn=False):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
# print('qkv', qkv.size())
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
# print('bs, self.n_heads, ch, length', bs, self.n_heads, ch, length)
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
# weight:[b,t,s] b=bs*n_heads*T
if rp is not None:
k_rp, v_rp = rp # [length, length, head_dim] [8, 8, 48]
weight2 = th.einsum(
"bct,tsc->bst", (q * scale).view(bs * self.n_heads, ch, length), k_rp
)
weight += weight2
if use_tempoal_causal_attn:
# weight = torch.tril(weight)
            assert mask is None, "Not implemented for merging two masks!"
            mask = th.tril(th.ones(weight.shape, device=weight.device))
else:
if mask is not None: # only keep upper-left matrix
# process mask
c, t, _ = weight.shape
if mask.shape[-1] > t:
mask = mask[:, :t, :t]
elif mask.shape[-1] < t: # pad ones
mask_ = th.zeros([c, t, t]).to(mask.device)
t_ = mask.shape[-1]
mask_[:, :t_, :t_] = mask
mask = mask_
else:
assert (
weight.shape[-1] == mask.shape[-1]
), f"weight={weight.shape}, mask={mask.shape}"
if mask is not None:
INF = -1e8 # float('-inf')
weight = weight.float().masked_fill(mask == 0, INF)
weight = F.softmax(weight.float(), dim=-1).type(
weight.dtype
) # [256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
# weight = F.softmax(weight, dim=-1)#[256, 8, 8] [b, t, t] b=bs*n_heads*h*w,t=nframes
# [256, 48, 8] [b, head_dim, t]
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
if rp is not None:
a2 = th.einsum("bts,tsc->btc", weight, v_rp).transpose(1, 2) # btc->bct
a += a2
return a.reshape(bs, -1, length)
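# Illustrative usage sketch (not referenced elsewhere in this file): the QKV attention
# modules expect a packed [N x (3 * H * C) x T] tensor and return [N x (H * C) x T].
def _qkv_attention_shape_demo():
    heads, ch, length = 4, 16, 8
    attn = QKVAttention(heads)
    qkv = th.randn(2, 3 * heads * ch, length)
    out = attn(qkv)
    assert out.shape == (2, heads * ch, length)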
import torch
import torch.nn as nn
from collections import OrderedDict
from extralibs.cond_api import ExtraCondition
from core.modules.x_transformer import FixedPositionalEmbedding
from core.basics import zero_module, conv_nd, avg_pool_nd
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResnetBlock(nn.Module):
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
super().__init__()
ps = ksize // 2
if in_c != out_c or sk == False:
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
else:
self.in_conv = None
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
if sk == False:
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
else:
self.skep = None
self.down = down
if self.down == True:
self.down_opt = Downsample(in_c, use_conv=use_conv)
def forward(self, x):
if self.down == True:
x = self.down_opt(x)
if self.in_conv is not None:
x = self.in_conv(x)
h = self.block1(x)
h = self.act(h)
h = self.block2(h)
if self.skep is not None:
return h + self.skep(x)
else:
return h + x
class Adapter(nn.Module):
def __init__(
self,
channels=[320, 640, 1280, 1280],
nums_rb=3,
cin=64,
ksize=3,
sk=True,
use_conv=True,
stage_downscale=True,
use_identity=False,
):
super(Adapter, self).__init__()
if use_identity:
self.inlayer = nn.Identity()
else:
self.inlayer = nn.PixelUnshuffle(8)
self.channels = channels
self.nums_rb = nums_rb
self.body = []
for i in range(len(channels)):
for j in range(nums_rb):
if (i != 0) and (j == 0):
self.body.append(
ResnetBlock(
channels[i - 1],
channels[i],
down=stage_downscale,
ksize=ksize,
sk=sk,
use_conv=use_conv,
)
)
else:
self.body.append(
ResnetBlock(
channels[i],
channels[i],
down=False,
ksize=ksize,
sk=sk,
use_conv=use_conv,
)
)
self.body = nn.ModuleList(self.body)
self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
def forward(self, x):
# unshuffle
x = self.inlayer(x)
# extract features
features = []
x = self.conv_in(x)
for i in range(len(self.channels)):
for j in range(self.nums_rb):
idx = i * self.nums_rb + j
x = self.body[idx](x)
features.append(x)
return features
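# Illustrative usage sketch (not referenced elsewhere in this file): with the default
# PixelUnshuffle(8) stem, a single-channel 256x256 condition map (an assumed example
# input) yields a 4-level feature pyramid at 1/8 .. 1/64 of the input resolution.
def _adapter_feature_pyramid_demo():
    adapter = Adapter()  # default cin=64 matches a 1-channel input after PixelUnshuffle(8)
    cond = torch.randn(1, 1, 256, 256)
    feats = adapter(cond)
    shapes = [tuple(f.shape) for f in feats]
    assert shapes == [(1, 320, 32, 32), (1, 640, 16, 16), (1, 1280, 8, 8), (1, 1280, 4, 4)]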
class PositionNet(nn.Module):
def __init__(self, input_size=(40, 64), cin=4, dim=512, out_dim=1024):
super().__init__()
self.input_size = input_size
self.out_dim = out_dim
        self.down_factor = 8  # total downsampling factor of the backbone (three stride-2 stages)
feature_dim = dim
self.backbone = Adapter(
channels=[64, 128, 256, feature_dim],
nums_rb=2,
cin=cin,
stage_downscale=True,
use_identity=True,
)
self.num_tokens = (self.input_size[0] // self.down_factor) * (
self.input_size[1] // self.down_factor
)
self.pos_embedding = nn.Parameter(
torch.empty(1, self.num_tokens, feature_dim).normal_(std=0.02)
) # from BERT
self.linears = nn.Sequential(
nn.Linear(feature_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
# self.null_feature = torch.nn.Parameter(torch.zeros([feature_dim]))
def forward(self, x, mask=None):
B = x.shape[0]
# token from edge map
# x = torch.nn.functional.interpolate(x, self.input_size)
feature = self.backbone(x)[-1]
objs = feature.reshape(B, -1, self.num_tokens)
objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
"""
# expand null token
null_objs = self.null_feature.view(1,1,-1)
null_objs = null_objs.repeat(B,self.num_tokens,1)
# mask replacing
mask = mask.view(-1,1,1)
objs = objs*mask + null_objs*(1-mask)
"""
# add pos
objs = objs + self.pos_embedding
# fuse them
objs = self.linears(objs)
assert objs.shape == torch.Size([B, self.num_tokens, self.out_dim])
return objs
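# Illustrative usage sketch (not referenced elsewhere in this file): a 4-channel
# 40x64 spatial condition becomes (40 / 8) * (64 / 8) = 40 tokens of dimension out_dim.
def _position_net_demo():
    net = PositionNet(input_size=(40, 64), cin=4, dim=512, out_dim=1024)
    x = torch.randn(2, 4, 40, 64)
    tokens = net(x)
    assert tokens.shape == (2, 40, 1024)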
class PositionNet2(nn.Module):
def __init__(self, input_size=(40, 64), cin=4, dim=320, out_dim=1024):
super().__init__()
self.input_size = input_size
self.out_dim = out_dim
        self.down_factor = 8  # total downsampling factor of the backbone (three stride-2 stages)
self.dim = dim
self.backbone = Adapter(
channels=[dim, dim, dim, dim],
nums_rb=2,
cin=cin,
stage_downscale=True,
use_identity=True,
)
self.pos_embedding = FixedPositionalEmbedding(dim=self.dim)
self.linears = nn.Sequential(
nn.Linear(dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
def forward(self, x, mask=None):
B = x.shape[0]
features = self.backbone(x)
token_lists = []
for feature in features:
objs = feature.reshape(B, self.dim, -1)
objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
# add pos
objs = objs + self.pos_embedding(objs)
# fuse them
objs = self.linears(objs)
token_lists.append(objs)
return token_lists
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class StyleAdapter(nn.Module):
def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
super().__init__()
scale = width**-0.5
self.transformer_layes = nn.Sequential(
*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)]
)
self.num_token = num_token
self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
self.ln_post = LayerNorm(width)
self.ln_pre = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
def forward(self, x):
# x shape [N, HW+1, C]
style_embedding = self.style_embedding + torch.zeros(
(x.shape[0], self.num_token, self.style_embedding.shape[-1]),
device=x.device,
)
x = torch.cat([x, style_embedding], dim=1)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer_layes(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, -self.num_token :, :])
x = x @ self.proj
return x
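# Illustrative usage sketch (not referenced elsewhere in this file): StyleAdapter turns a
# CLIP image token sequence (a 257 x 1024 ViT-L/14 layout is assumed here) into
# num_token style tokens in the cross-attention dimension.
def _style_adapter_demo():
    adapter = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4)
    clip_tokens = torch.randn(2, 257, 1024)
    style_ctx = adapter(clip_tokens)
    assert style_ctx.shape == (2, 4, 768)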
class ResnetBlock_light(nn.Module):
def __init__(self, in_c):
super().__init__()
self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)
def forward(self, x):
h = self.block1(x)
h = self.act(h)
h = self.block2(h)
return h + x
class extractor(nn.Module):
def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
super().__init__()
self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
self.body = []
for _ in range(nums_rb):
self.body.append(ResnetBlock_light(inter_c))
self.body = nn.Sequential(*self.body)
self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
self.down = down
if self.down == True:
self.down_opt = Downsample(in_c, use_conv=False)
def forward(self, x):
if self.down == True:
x = self.down_opt(x)
x = self.in_conv(x)
x = self.body(x)
x = self.out_conv(x)
return x
class Adapter_light(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
super(Adapter_light, self).__init__()
self.unshuffle = nn.PixelUnshuffle(8)
self.channels = channels
self.nums_rb = nums_rb
self.body = []
for i in range(len(channels)):
if i == 0:
self.body.append(
extractor(
in_c=cin,
inter_c=channels[i] // 4,
out_c=channels[i],
nums_rb=nums_rb,
down=False,
)
)
else:
self.body.append(
extractor(
in_c=channels[i - 1],
inter_c=channels[i] // 4,
out_c=channels[i],
nums_rb=nums_rb,
down=True,
)
)
self.body = nn.ModuleList(self.body)
def forward(self, x):
# unshuffle
x = self.unshuffle(x)
# extract features
features = []
for i in range(len(self.channels)):
x = self.body[i](x)
features.append(x)
return features
class CoAdapterFuser(nn.Module):
def __init__(
self, unet_channels=[320, 640, 1280, 1280], width=768, num_head=8, n_layes=3
):
super(CoAdapterFuser, self).__init__()
scale = width**0.5
self.task_embedding = nn.Parameter(scale * torch.randn(16, width))
self.positional_embedding = nn.Parameter(
scale * torch.randn(len(unet_channels), width)
)
self.spatial_feat_mapping = nn.ModuleList()
for ch in unet_channels:
self.spatial_feat_mapping.append(
nn.Sequential(
nn.SiLU(),
nn.Linear(ch, width),
)
)
self.transformer_layes = nn.Sequential(
*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)]
)
self.ln_post = LayerNorm(width)
self.ln_pre = LayerNorm(width)
self.spatial_ch_projs = nn.ModuleList()
for ch in unet_channels:
self.spatial_ch_projs.append(zero_module(nn.Linear(width, ch)))
self.seq_proj = nn.Parameter(torch.zeros(width, width))
def forward(self, features):
if len(features) == 0:
return None, None
inputs = []
for cond_name in features.keys():
task_idx = getattr(ExtraCondition, cond_name).value
if not isinstance(features[cond_name], list):
inputs.append(features[cond_name] + self.task_embedding[task_idx])
continue
feat_seq = []
for idx, feature_map in enumerate(features[cond_name]):
feature_vec = torch.mean(feature_map, dim=(2, 3))
feature_vec = self.spatial_feat_mapping[idx](feature_vec)
feat_seq.append(feature_vec)
feat_seq = torch.stack(feat_seq, dim=1) # Nx4xC
feat_seq = feat_seq + self.task_embedding[task_idx]
feat_seq = feat_seq + self.positional_embedding
inputs.append(feat_seq)
x = torch.cat(inputs, dim=1) # NxLxC
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer_layes(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x)
ret_feat_map = None
ret_feat_seq = None
cur_seq_idx = 0
for cond_name in features.keys():
if not isinstance(features[cond_name], list):
length = features[cond_name].size(1)
transformed_feature = features[cond_name] * (
(x[:, cur_seq_idx : cur_seq_idx + length] @ self.seq_proj) + 1
)
if ret_feat_seq is None:
ret_feat_seq = transformed_feature
else:
ret_feat_seq = torch.cat([ret_feat_seq, transformed_feature], dim=1)
cur_seq_idx += length
continue
length = len(features[cond_name])
transformed_feature_list = []
for idx in range(length):
alpha = self.spatial_ch_projs[idx](x[:, cur_seq_idx + idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1) + 1
transformed_feature_list.append(features[cond_name][idx] * alpha)
if ret_feat_map is None:
ret_feat_map = transformed_feature_list
else:
ret_feat_map = list(
map(lambda x, y: x + y, ret_feat_map, transformed_feature_list)
)
cur_seq_idx += length
assert cur_seq_idx == x.size(1)
return ret_feat_map, ret_feat_seq
import torch
import torch.nn as nn
import kornia
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
import open_clip
from core.common import autocast
from utils.utils import count_params
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0.0 and not disable_dropout:
mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
        # the last class index (n_classes - 1) is reserved as the unconditional (ucg) class
uc_class = self.n_classes - 1
uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
):
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(
input_ids=tokens, output_hidden_states=self.layer == "hidden"
)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
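# Illustrative usage sketch (not referenced elsewhere in this file; downloads the
# Hugging Face checkpoint on first use): a prompt batch becomes a [B, 77, 768]
# context tensor for cross-attention with clip-vit-large-patch14.
def _frozen_clip_text_demo():
    encoder = FrozenCLIPEmbedder(device="cpu")
    z = encoder.encode(["a photo of a cat", ""])
    assert z.shape == (2, 77, 768)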
class ClipImageEmbedder(nn.Module):
def __init__(
self,
model,
jit=False,
device="cuda" if torch.cuda.is_available() else "cpu",
antialias=True,
ucg_rate=0.0,
):
super().__init__()
from clip import load as load_clip
self.model, _ = load_clip(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# re-normalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x, no_dropout=False):
# x is assumed to be in range [-1,1]
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0.0 and not no_dropout:
out = (
torch.bernoulli(
(1.0 - self.ucg_rate) * torch.ones(out.shape[0], device=out.device)
)[:, None]
* out
)
return out
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
"last",
"penultimate",
]
def __init__(
self,
arch="ViT-H-14",
version=None,
device="cuda",
max_length=77,
freeze=True,
layer="last",
):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device("cpu"), pretrained=version
)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if (
self.model.transformer.grad_checkpointing
and not torch.jit.is_scripting()
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(
self,
arch="ViT-H-14",
version=None,
device="cuda",
max_length=77,
freeze=True,
layer="pooled",
antialias=True,
ucg_rate=0.0,
):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device("cpu"), pretrained=version
)
del model.transformer
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
@autocast
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0.0 and not no_dropout:
z = (
torch.bernoulli(
(1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
)[:, None]
* z
)
return z
def encode_with_vision_transformer(self, img):
img = self.preprocess(img)
x = self.model.visual(img)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(
self,
arch="ViT-H-14",
version=None,
device="cuda",
freeze=True,
layer="pooled",
antialias=True,
):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device("cpu"),
pretrained=version,
)
del model.transformer
self.model = model
self.device = device
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def forward(self, image, no_dropout=False):
# image: b c h w
z = self.encode_with_vision_transformer(image)
return z
def encode_with_vision_transformer(self, x):
x = self.preprocess(x)
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.model.visual.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(
x.shape[0],
x.shape[1],
self.model.visual.grid_size[0],
self.model.visual.patch_size[0],
self.model.visual.grid_size[1],
self.model.visual.patch_size[1],
)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0],
self.model.visual.grid_size[0] * self.model.visual.grid_size[1],
-1,
)
x = self.model.visual.patchnorm_pre_ln(x)
x = self.model.visual.conv1(x)
else:
x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
# shape = [*, width, grid ** 2]
x = x.reshape(x.shape[0], x.shape[1], -1)
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[
self.model.visual.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.model.visual.positional_embedding.to(x.dtype)
        # a patch_dropout of 0.0 means it is disabled; this call is then a no-op
x = self.model.visual.patch_dropout(x)
x = self.model.visual.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.model.visual.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
return x
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(
self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77,
):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
)
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
class ImageProjModel(nn.Module):
"""Projection Model"""
def __init__(
self,
cross_attention_dim=1024,
clip_embeddings_dim=1024,
clip_extra_context_tokens=4,
):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = nn.Linear(
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
# embeds = image_embeds
embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
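# Illustrative usage sketch (not referenced elsewhere in this file): a pooled CLIP image
# embedding of size clip_embeddings_dim is expanded into clip_extra_context_tokens
# tokens in the cross-attention dimension.
def _image_proj_model_demo():
    proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
    image_embeds = torch.randn(2, 1024)
    tokens = proj(image_embeds)
    assert tokens.shape == (2, 4, 1024)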
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
    # keep (bs, n_heads, length, dim_per_head); the reshape makes the transposed tensor contiguous
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
# More stable with f16 than dividing afterwards
weight = (q * scale) @ (k * scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
video_length=None,
):
super().__init__()
self.num_queries = num_queries
self.video_length = video_length
if video_length is not None:
num_queries = num_queries * video_length
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1) # B (T L) C
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents) # B L C or B (T L) C
return latents
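# Illustrative usage sketch (not referenced elsewhere in this file): the Resampler
# compresses a variable-length token sequence (a 257 x 768 CLIP layout is assumed
# here) into num_queries learned latents of output_dim.
def _resampler_demo():
    resampler = Resampler(dim=1024, depth=2, num_queries=8, embedding_dim=768, output_dim=1024)
    x = torch.randn(2, 257, 768)
    latents = resampler(x)
    assert latents.shape == (2, 8, 1024)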
class CameraPoseQueryTransformer(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
num_views=None,
use_multi_view_attention=True,
):
super().__init__()
self.num_queries = num_queries
self.num_views = num_views
        assert num_views is not None, "num_views must be given."
self.use_multi_view_attention = use_multi_view_attention
self.camera_pose_embedding_layers = nn.Sequential(
nn.Linear(12, dim),
nn.SiLU(),
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim),
)
nn.init.zeros_(self.camera_pose_embedding_layers[-1].weight)
nn.init.zeros_(self.camera_pose_embedding_layers[-1].bias)
self.latents = nn.Parameter(
torch.randn(1, num_views * num_queries, dim) / dim**0.5
)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x, camera_poses):
# camera_poses: (b, t, 12)
batch_size, num_views, _ = camera_poses.shape
# latents: (1, t*q, d) -> (b, t*q, d)
latents = self.latents.repeat(batch_size, 1, 1)
x = self.proj_in(x)
# camera_poses: (b*t, 12)
camera_poses = rearrange(camera_poses, "b t d -> (b t) d", t=num_views)
camera_poses = self.camera_pose_embedding_layers(
camera_poses
) # camera_poses: (b*t, d)
# camera_poses: (b, t, d)
camera_poses = rearrange(camera_poses, "(b t) d -> b t d", t=num_views)
# camera_poses: (b, t*q, d)
camera_poses = repeat(camera_poses, "b t d -> b (t q) d", q=self.num_queries)
latents = latents + camera_poses # b, t*q, d
latents = rearrange(
latents,
"b (t q) d -> (b t) q d",
b=batch_size,
t=num_views,
q=self.num_queries,
) # (b*t, q, d)
_, x_seq_size, _ = x.shape
for layer_idx, (attn, ff) in enumerate(self.layers):
if self.use_multi_view_attention and layer_idx % 2 == 1:
# latents: (b*t, q, d)
latents = rearrange(
latents,
"(b t) q d -> b (t q) d",
b=batch_size,
t=num_views,
q=self.num_queries,
)
# x: (b*t, s, d)
x = rearrange(
x, "(b t) s d -> b (t s) d", b=batch_size, t=num_views, s=x_seq_size
)
# print("After rearrange: latents.shape=", latents.shape)
# print("After rearrange: x.shape=", camera_poses.shape)
latents = attn(x, latents) + latents
latents = ff(latents) + latents
if self.use_multi_view_attention and layer_idx % 2 == 1:
# latents: (b*q, t, d)
latents = rearrange(
latents,
"b (t q) d -> (b t) q d",
b=batch_size,
t=num_views,
q=self.num_queries,
)
# x: (b*s, t, d)
x = rearrange(
x, "b (t s) d -> (b t) s d", b=batch_size, t=num_views, s=x_seq_size
)
latents = self.proj_out(latents)
latents = self.norm_out(latents) # B L C or B (T L) C
return latents
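# Illustrative usage sketch (not referenced elsewhere in this file; the 3x4 camera
# matrix flattened to 12 values and the per-view token layout are assumed from the
# rearranges above): per-view tokens plus camera poses yield num_queries latents per view.
def _camera_pose_query_demo():
    model = CameraPoseQueryTransformer(dim=1024, depth=2, num_queries=8,
                                       embedding_dim=768, output_dim=1024, num_views=4)
    b, t, s = 2, 4, 16
    x = torch.randn(b * t, s, 768)  # per-view image tokens
    poses = torch.randn(b, t, 12)   # flattened camera extrinsics
    out = model(x, poses)
    assert out.shape == (b * t, 8, 1024)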
# pytorch_diffusion + derived encoder decoder
import math
import torch
import numpy as np
import torch.nn as nn
from einops import rearrange
from utils.utils import instantiate_from_config
from core.modules.attention import LinearAttention
# NOTE: FirstStagePostProcessor below references DiagonalGaussianDistribution; the import
# path used here is assumed and may need adjusting to wherever that class lives in this repo.
from core.modules.distributions import DiagonalGaussianDistribution
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w) # bcl
q = q.permute(0, 2, 1) # bcl -> blc l=hw
k = k.reshape(b, c, h * w) # bcl
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
self.in_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
self.in_channels = in_channels
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
def get_timestep_embedding(time_steps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(time_steps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=time_steps.device)
emb = time_steps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
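# Illustrative usage sketch (not referenced elsewhere in this file): integer timesteps
# map to [N, embedding_dim] sinusoidal features; odd dimensions get one zero-pad column.
def _timestep_embedding_demo():
    t = torch.arange(4)
    assert get_timestep_embedding(t, 128).shape == (4, 128)
    assert get_timestep_embedding(t, 129).shape == (4, 129)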
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class Model(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
# timestep embedding
temb = None
# print(f'encoder-input={x.shape}')
# downsampling
hs = [self.conv_in(x)]
# print(f'encoder-conv in feat={hs[0].shape}')
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
# print(f'encoder-down feat={h.shape}')
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
# print(f'encoder-downsample (input)={hs[-1].shape}')
hs.append(self.down[i_level].downsample(hs[-1]))
# print(f'encoder-downsample (output)={hs[-1].shape}')
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
# print(f'encoder-mid1 feat={h.shape}')
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# print(f'encoder-mid2 feat={h.shape}')
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
# print(f'end feat={h.shape}')
return h
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignored_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# print("AE working on z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# print(f'decoder-input={z.shape}')
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# print(f'decoder-conv in feat={h.shape}')
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# print(f'decoder-mid feat={h.shape}')
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
# print(f'decoder-up feat={h.shape}')
if i_level != 0:
h = self.up[i_level].upsample(h)
# print(f'decoder-upsample feat={h.shape}')
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
# print(f'decoder-conv_out feat={h.shape}')
if self.tanh_out:
h = torch.tanh(h)
return h
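# Illustrative usage sketch (not referenced elsewhere in this file): with ch_mult of
# length 3 the Encoder downsamples by 2 ** (3 - 1) = 4, and a matching Decoder restores
# the input resolution; the small channel counts here are assumed purely for the demo.
def _autoencoder_shape_demo():
    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[], in_channels=3, resolution=64,
                  z_channels=4, double_z=True)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[], in_channels=3, resolution=64, z_channels=4)
    x = torch.randn(1, 3, 64, 64)
    moments = enc(x)          # (1, 2 * z_channels, 16, 16): concatenated mean and logvar
    z = moments[:, :4]        # take the mean as a stand-in for sampling
    rec = dec(z)
    assert moments.shape == (1, 8, 16, 16) and rec.shape == x.shape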
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList(
[
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(
in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0,
),
ResnetBlock(
in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0,
dropout=0.0,
),
ResnetBlock(
in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0,
),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True),
]
)
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
ch,
num_res_blocks,
resolution,
ch_mult=(2, 2),
dropout=0.0,
):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1
)
self.res_block1 = nn.ModuleList(
[
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0,
)
for _ in range(depth)
]
)
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList(
[
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0,
)
for _ in range(depth)
]
)
self.conv_out = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(
x,
size=(
int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor)),
),
)
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
class MergedRescaleEncoder(nn.Module):
def __init__(
self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(
in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(
self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1,
):
super().__init__()
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(
out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch,
)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth,
)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1.0 + (out_size % in_size)
print(
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
)
self.rescaler = LatentRescaler(
factor=factor_up,
in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels,
)
self.decoder = Decoder(
out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)],
)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
print(
f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
)
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1
)
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
)
return x
class FirstStagePostProcessor(nn.Module):
def __init__(
self,
ch_mult: list,
in_channels,
pretrained_model: nn.Module = None,
reshape=False,
n_channels=None,
dropout=0.0,
pretrained_config=None,
):
super().__init__()
if pretrained_config is None:
assert (
pretrained_model is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert (
pretrained_config is not None
), 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
self.proj = nn.Conv2d(
in_channels, n_channels, kernel_size=3, stride=1, padding=1
)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
blocks.append(
ResnetBlock(
in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
)
)
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
# self.pretrained_model.train = False
for param in self.pretrained_model.parameters():
param.requires_grad = False
@torch.no_grad()
def encode_with_pretrained(self, x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
return c
def forward(self, x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
for submodel, downmodel in zip(self.model, self.downsampler):
z = submodel(z, temb=None)
z = downmodel(z)
if self.do_reshape:
z = rearrange(z, "b c h w -> b (h w) c")
return z
from functools import partial
from abc import abstractmethod
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from core.models.utils_diffusion import timestep_embedding
from core.common import gradient_checkpoint
from core.basics import zero_module, conv_nd, linear, avg_pool_nd, normalization
from core.modules.attention import SpatialTransformer, TemporalTransformer
TASK_IDX_IMAGE = 0
TASK_IDX_RAY = 1
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(
self, x, emb, context=None, batch_size=None, with_lora=False, time_steps=None
):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb, batch_size=batch_size)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, with_lora=with_lora)
elif isinstance(layer, TemporalTransformer):
x = rearrange(x, "(b f) c h w -> b c f h w", b=batch_size)
x = layer(x, context, with_lora=with_lora, time_steps=time_steps)
x = rearrange(x, "b c f h w -> (b f) c h w")
else:
x = layer(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
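# Illustrative shape sketch (not part of the original module). It assumes that
# `conv_nd` / `avg_pool_nd` from core.basics behave like the corresponding
# torch.nn layers; channel and spatial sizes are arbitrary for demonstration.
def _example_down_up_sample_shapes():
    x = torch.randn(2, 64, 32, 32)
    down = Downsample(64, use_conv=True)  # stride-2 conv: 32x32 -> 16x16
    up = Upsample(64, use_conv=True)      # nearest x2 + conv: 16x16 -> 32x32
    d = down(x)
    assert d.shape == (2, 64, 16, 16)
    assert up(d).shape == (2, 64, 32, 32)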
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
:param use_temporal_conv: if True, use the temporal convolution.
:param use_image_dataset: if True, the temporal parameters will not be optimized.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
use_conv=False,
up=False,
down=False,
use_temporal_conv=False,
tempspatial_aware=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.use_temporal_conv = use_temporal_conv
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
if self.use_temporal_conv:
self.temopral_conv = TemporalConvBlock(
self.out_channels,
self.out_channels,
dropout=0.1,
spatial_aware=tempspatial_aware,
)
def forward(self, x, emb, batch_size=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
input_tuple = (x, emb)
if batch_size:
forward_batchsize = partial(self._forward, batch_size=batch_size)
return gradient_checkpoint(
forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint
)
return gradient_checkpoint(
self._forward, input_tuple, self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb, batch_size=None):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
h = self.skip_connection(x) + h
if self.use_temporal_conv and batch_size:
h = rearrange(h, "(b t) c h w -> b c t h w", b=batch_size)
h = self.temopral_conv(h)
h = rearrange(h, "b c t h w -> (b t) c h w")
return h
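# Illustrative sketch (not part of the original module): a ResBlock conditioned on
# a timestep embedding keeps the spatial size and, here, the channel count.
# It assumes `gradient_checkpoint` simply calls the wrapped function when
# use_checkpoint=False and that `normalization` is a GroupNorm-style layer,
# so the channel count should be divisible by 32. Sizes are arbitrary.
def _example_resblock_usage():
    block = ResBlock(
        channels=64, emb_channels=256, dropout=0.0, use_scale_shift_norm=True
    )
    x = torch.randn(2, 64, 16, 16)  # (N, C, H, W)
    emb = torch.randn(2, 256)       # (N, emb_channels)
    h = block(x, emb)               # scale/shift from emb modulates the normalized features
    assert h.shape == x.shape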
class TemporalConvBlock(nn.Module):
def __init__(
self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False
):
super(TemporalConvBlock, self).__init__()
if out_channels is None:
out_channels = in_channels
self.in_channels = in_channels
self.out_channels = out_channels
th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_channels),
nn.SiLU(),
nn.Conv3d(
in_channels, out_channels, th_kernel_shape, padding=th_padding_shape
),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_channels),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(
out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape
),
)
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_channels),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(
out_channels, in_channels, th_kernel_shape, padding=th_padding_shape
),
)
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_channels),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(
out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape
),
)
        # zero out the last layer's parameters so the conv block is an identity mapping at initialization
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return identity + x
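# Illustrative sketch (not part of the original module): the temporal conv block
# operates on (b, c, t, h, w) tensors; because conv4 is zero-initialised, the
# block is an exact identity mapping right after construction. Sizes are
# arbitrary assumptions; channels must suit the GroupNorm(32, ...) layers.
def _example_temporal_conv_block_identity():
    block = TemporalConvBlock(in_channels=64)
    x = torch.randn(2, 64, 8, 16, 16)  # (b, c, t, h, w)
    y = block(x)
    assert y.shape == x.shape
    assert torch.allclose(y, x)        # conv4 output is zero at initialisation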
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: in_channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0.0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
context_dim=None,
use_scale_shift_norm=False,
resblock_updown=False,
num_heads=-1,
num_head_channels=-1,
transformer_depth=1,
use_linear=False,
use_checkpoint=False,
temporal_conv=False,
tempspatial_aware=False,
temporal_attention=True,
use_relative_position=True,
use_causal_attention=False,
temporal_length=None,
use_fp16=False,
addition_attention=False,
temporal_selfatt_only=True,
image_cross_attention=False,
image_cross_attention_scale_learnable=False,
default_fs=4,
fs_condition=False,
use_spatial_temporal_attention=False,
# >>> Extra Ray Options
use_addition_ray_output_head=False,
ray_channels=6,
use_lora_for_rays_in_output_blocks=False,
use_task_embedding=False,
use_ray_decoder=False,
use_ray_decoder_residual=False,
full_spatial_temporal_attention=False,
enhance_multi_view_correspondence=False,
camera_pose_condition=False,
use_feature_alignment=False,
):
super(UNetModel, self).__init__()
if num_heads == -1:
assert (
num_head_channels != -1
), "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1:
assert (
num_heads != -1
), "Either num_heads or num_head_channels has to be set"
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.temporal_attention = temporal_attention
time_embed_dim = model_channels * 4
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
temporal_self_att_only = True
self.addition_attention = addition_attention
self.temporal_length = temporal_length
self.image_cross_attention = image_cross_attention
self.image_cross_attention_scale_learnable = (
image_cross_attention_scale_learnable
)
self.default_fs = default_fs
self.fs_condition = fs_condition
self.use_spatial_temporal_attention = use_spatial_temporal_attention
# >>> Extra Ray Options
self.use_addition_ray_output_head = use_addition_ray_output_head
self.use_lora_for_rays_in_output_blocks = use_lora_for_rays_in_output_blocks
if self.use_lora_for_rays_in_output_blocks:
assert (
use_addition_ray_output_head
), "`use_addition_ray_output_head` is required to be True when using LoRA for rays in output blocks."
assert (
not use_task_embedding
), "`use_task_embedding` cannot be True when `use_lora_for_rays_in_output_blocks` is enabled."
if self.use_addition_ray_output_head:
print("Using additional ray output head...")
assert (self.out_channels == 4) or (
4 + ray_channels == self.out_channels
), f"`out_channels`={out_channels} is invalid."
self.out_channels = 4
out_channels = 4
self.ray_channels = ray_channels
self.use_ray_decoder = use_ray_decoder
if use_ray_decoder:
assert (
not use_task_embedding
), "`use_task_embedding` cannot be True when `use_ray_decoder_layers` is enabled."
assert (
use_addition_ray_output_head
), "`use_addition_ray_output_head` must be True when `use_ray_decoder_layers` is enabled."
self.use_ray_decoder_residual = use_ray_decoder_residual
# >>> Time/Task Embedding Blocks
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if fs_condition:
self.fps_embedding = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
        self.camera_pose_condition = camera_pose_condition
        if camera_pose_condition:
self.camera_pose_embedding = nn.Sequential(
linear(12, model_channels),
nn.SiLU(),
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
nn.init.zeros_(self.camera_pose_embedding[-1].weight)
nn.init.zeros_(self.camera_pose_embedding[-1].bias)
self.use_task_embedding = use_task_embedding
if use_task_embedding:
assert (
not use_lora_for_rays_in_output_blocks
), "`use_lora_for_rays_in_output_blocks` and `use_task_embedding` cannot be True at the same time."
assert (
use_addition_ray_output_head
), "`use_addition_ray_output_head` is required to be True when `use_task_embedding` is enabled."
self.task_embedding = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
nn.init.zeros_(self.task_embedding[-1].weight)
nn.init.zeros_(self.task_embedding[-1].bias)
self.task_parameters = nn.ParameterList(
[
nn.Parameter(
torch.zeros(size=[model_channels], requires_grad=True)
),
nn.Parameter(
torch.zeros(size=[model_channels], requires_grad=True)
),
]
)
# >>> Input Block
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
if self.addition_attention:
self.init_attn = TimestepEmbedSequential(
TemporalTransformer(
model_channels,
n_heads=8,
d_head=num_head_channels,
depth=transformer_depth,
context_dim=context_dim,
use_checkpoint=use_checkpoint,
only_self_att=temporal_selfatt_only,
causal_attention=False,
relative_position=use_relative_position,
temporal_length=temporal_length,
)
)
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
image_cross_attention=self.image_cross_attention,
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
)
)
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv,
),
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
image_cross_attention=self.image_cross_attention,
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
),
]
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length,
)
)
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv,
)
)
# >>> Middle Block
self.middle_block = TimestepEmbedSequential(*layers)
# >>> Ray Decoder
if use_ray_decoder:
self.ray_decoder_blocks = nn.ModuleList([])
# >>> Output Block
is_first_layer = True
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv,
)
]
if use_ray_decoder:
if self.use_ray_decoder_residual:
ray_residual_ch = ich
else:
ray_residual_ch = 0
ray_decoder_layers = [
ResBlock(
(ch if is_first_layer else (ch // 10)) + ray_residual_ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels // 10,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=True,
)
]
is_first_layer = False
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
image_cross_attention=self.image_cross_attention,
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
enable_lora=self.use_lora_for_rays_in_output_blocks,
)
)
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length,
use_extra_spatial_temporal_self_attention=use_spatial_temporal_attention,
enable_lora=self.use_lora_for_rays_in_output_blocks,
full_spatial_temporal_attention=full_spatial_temporal_attention,
enhance_multi_view_correspondence=enhance_multi_view_correspondence,
)
)
if level and i == num_res_blocks:
out_ch = ch
# out_ray_ch = ray_ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
if use_ray_decoder:
ray_decoder_layers.append(
ResBlock(
ch // 10,
time_embed_dim,
dropout,
out_channels=out_ch // 10,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(
ch // 10,
conv_resample,
dims=dims,
out_channels=out_ch // 10,
)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
if use_ray_decoder:
self.ray_decoder_blocks.append(
TimestepEmbedSequential(*ray_decoder_layers)
)
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.use_addition_ray_output_head:
ray_model_channels = model_channels // 10
self.ray_output_head = nn.Sequential(
normalization(ray_model_channels),
nn.SiLU(),
conv_nd(dims, ray_model_channels, ray_model_channels, 3, padding=1),
nn.SiLU(),
conv_nd(dims, ray_model_channels, ray_model_channels, 3, padding=1),
nn.SiLU(),
zero_module(
conv_nd(dims, ray_model_channels, self.ray_channels, 3, padding=1)
),
)
self.use_feature_alignment = use_feature_alignment
if self.use_feature_alignment:
self.feature_alignment_adapter = FeatureAlignmentAdapter(
time_embed_dim=time_embed_dim, use_checkpoint=use_checkpoint
)
def forward(
self,
x,
time_steps,
context=None,
features_adapter=None,
fs=None,
task_idx=None,
camera_poses=None,
return_input_block_features=False,
return_middle_feature=False,
return_output_block_features=False,
**kwargs,
):
intermediate_features = {}
if return_input_block_features:
intermediate_features["input"] = []
if return_output_block_features:
intermediate_features["output"] = []
b, t, _, _, _ = x.shape
t_emb = timestep_embedding(
time_steps, self.model_channels, repeat_only=False
).type(x.dtype)
emb = self.time_embed(t_emb)
# repeat t times for context [(b t) 77 768] & time embedding
# check if we use per-frame image conditioning
_, l_context, _ = context.shape
        if l_context == 77 + t * 16:  # NOTE: hard-coded layout: 77 text tokens + 16 image tokens per frame
context_text, context_img = context[:, :77, :], context[:, 77:, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img, "b (t l) c -> (b t) l c", t=t)
context = torch.cat([context_text, context_img], dim=1)
else:
context = context.repeat_interleave(repeats=t, dim=0)
emb = emb.repeat_interleave(repeats=t, dim=0)
# always in shape (b t) c h w, except for temporal layer
x = rearrange(x, "b t c h w -> (b t) c h w")
# combine emb
if self.fs_condition:
if fs is None:
fs = torch.tensor(
[self.default_fs] * b, dtype=torch.long, device=x.device
)
fs_emb = timestep_embedding(
fs, self.model_channels, repeat_only=False
).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed
if self.camera_pose_condition:
# camera_poses: (b, t, 12)
camera_poses = rearrange(camera_poses, "b t x y -> (b t) (x y)") # x=3, y=4
camera_poses_embed = self.camera_pose_embedding(camera_poses)
emb = emb + camera_poses_embed
if self.use_task_embedding:
assert (
task_idx is not None
), "`task_idx` should not be None when `use_task_embedding` is enabled."
task_embed = self.task_embedding(
self.task_parameters[task_idx]
.reshape(1, self.model_channels)
.repeat(b, 1)
)
task_embed = task_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + task_embed
h = x.type(self.dtype)
adapter_idx = 0
hs = []
for _id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if _id == 0 and self.addition_attention:
h = self.init_attn(h, emb, context=context, batch_size=b)
# plug-in adapter features
if ((_id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
adapter_idx += 1
hs.append(h)
if return_input_block_features:
intermediate_features["input"].append(h)
if features_adapter is not None:
assert len(features_adapter) == adapter_idx, "Wrong features_adapter"
h = self.middle_block(h, emb, context=context, batch_size=b)
if return_middle_feature:
intermediate_features["middle"] = h
if self.use_feature_alignment:
feature_alignment_output = self.feature_alignment_adapter(
hs[2], hs[5], hs[8], emb=emb
)
# >>> Output Blocks Forward
if self.use_ray_decoder:
h_original = h
h_ray = h
for original_module, ray_module in zip(
self.output_blocks, self.ray_decoder_blocks
):
cur_hs = hs.pop()
h_original = torch.cat([h_original, cur_hs], dim=1)
h_original = original_module(
h_original,
emb,
context=context,
batch_size=b,
time_steps=time_steps,
)
if self.use_ray_decoder_residual:
h_ray = torch.cat([h_ray, cur_hs], dim=1)
h_ray = ray_module(h_ray, emb, context=context, batch_size=b)
if return_output_block_features:
print(
"return_output_block_features: h_original.shape=",
h_original.shape,
)
intermediate_features["output"].append(h_original.detach())
h_original = h_original.type(x.dtype)
h_ray = h_ray.type(x.dtype)
y_original = self.out(h_original)
y_ray = self.ray_output_head(h_ray)
y = torch.cat([y_original, y_ray], dim=1)
else:
if self.use_lora_for_rays_in_output_blocks:
middle_h = h
h_original = middle_h
h_lora = middle_h
for output_idx, module in enumerate(self.output_blocks):
cur_hs = hs.pop()
h_original = torch.cat([h_original, cur_hs], dim=1)
h_original = module(
h_original, emb, context=context, batch_size=b, with_lora=False
)
h_lora = torch.cat([h_lora, cur_hs], dim=1)
h_lora = module(
h_lora, emb, context=context, batch_size=b, with_lora=True
)
h_original = h_original.type(x.dtype)
h_lora = h_lora.type(x.dtype)
y_original = self.out(h_original)
y_lora = self.ray_output_head(h_lora)
y = torch.cat([y_original, y_lora], dim=1)
else:
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
h = h.type(x.dtype)
if self.use_task_embedding:
                    # Separated input (branch control on CPU);
                    # serial execution for now (GPU vectorization pending)
if task_idx == TASK_IDX_IMAGE:
y = self.out(h)
elif task_idx == TASK_IDX_RAY:
y = self.ray_output_head(h)
else:
raise NotImplementedError(f"Unsupported `task_idx`: {task_idx}")
else:
# Output ray and images at the same forward
y = self.out(h)
if self.use_addition_ray_output_head:
y_ray = self.ray_output_head(h)
y = torch.cat([y, y_ray], dim=1)
# reshape back to (b c t h w)
y = rearrange(y, "(b t) c h w -> b t c h w", b=b)
if (
return_input_block_features
or return_output_block_features
or return_middle_feature
):
return y, intermediate_features
        # Intermediate features are assumed to be requested only in non-training
        # scenarios (e.g., feature visualization), so the feature-alignment output
        # is not returned on that path.
if self.use_feature_alignment:
return y, feature_alignment_output
return y
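# Illustrative sketch (not part of the original module): a minimal UNetModel
# configuration and forward pass, assuming the SpatialTransformer /
# TemporalTransformer modules from core.modules.attention accept the shapes used
# below. All sizes (context_dim=128, 77 context tokens, 16x16 latents, 4 frames)
# are arbitrary assumptions for demonstration only.
def _example_unet_forward():
    unet = UNetModel(
        in_channels=4,
        model_channels=64,
        out_channels=4,
        num_res_blocks=1,
        attention_resolutions=[2],
        channel_mult=(1, 2),
        num_head_channels=32,
        context_dim=128,
        temporal_length=4,
    )
    b, t = 1, 4
    x = torch.randn(b, t, 4, 16, 16)          # (b, t, c, h, w) latent frames
    time_steps = torch.randint(0, 1000, (b,))
    context = torch.randn(b, 77, 128)         # per-sample text conditioning
    y = unet(x, time_steps, context=context)  # -> (b, t, out_channels, h, w)
    assert y.shape == (b, t, 4, 16, 16)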
class FeatureAlignmentAdapter(torch.nn.Module):
def __init__(self, time_embed_dim, use_checkpoint, dropout=0.0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.channel_adapter_conv_16 = torch.nn.Conv2d(
in_channels=1280, out_channels=320, kernel_size=1
)
self.channel_adapter_conv_32 = torch.nn.Conv2d(
in_channels=640, out_channels=320, kernel_size=1
)
self.upsampler_x2 = torch.nn.UpsamplingBilinear2d(scale_factor=2)
self.upsampler_x4 = torch.nn.UpsamplingBilinear2d(scale_factor=4)
self.res_block = ResBlock(
320 * 3,
time_embed_dim,
dropout,
out_channels=32 * 3,
dims=2,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=False,
)
self.final_conv = conv_nd(
dims=2, in_channels=32 * 3, out_channels=6, kernel_size=1
)
def forward(self, feature_64, feature_32, feature_16, emb):
feature_16_adapted = self.channel_adapter_conv_16(feature_16)
feature_32_adapted = self.channel_adapter_conv_32(feature_32)
feature_16_upsampled = self.upsampler_x4(feature_16_adapted)
feature_32_upsampled = self.upsampler_x2(feature_32_adapted)
feature_all = torch.concat(
[feature_16_upsampled, feature_32_upsampled, feature_64], dim=1
)
        # output shape: (b t), 6, h, w
return self.final_conv(self.res_block(feature_all, emb=emb))
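# Illustrative shape sketch (not part of the original module): the adapter fuses
# three UNet encoder feature maps (hard-coded to 320/640/1280 channels at
# 64/32/16 resolution) into a 6-channel map at the highest resolution.
# `time_embed_dim` and the batch size are arbitrary assumptions.
def _example_feature_alignment_adapter_shapes():
    time_embed_dim = 1024
    adapter = FeatureAlignmentAdapter(time_embed_dim=time_embed_dim, use_checkpoint=False)
    bt = 2  # flattened (batch * time)
    feature_64 = torch.randn(bt, 320, 64, 64)
    feature_32 = torch.randn(bt, 640, 32, 32)
    feature_16 = torch.randn(bt, 1280, 16, 16)
    emb = torch.randn(bt, time_embed_dim)
    out = adapter(feature_64, feature_32, feature_16, emb=emb)
    assert out.shape == (bt, 6, 64, 64)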
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, token_tensors):
# input: (B,C,H,W)
x = token_tensors
h, w = x.shape[-2:]
identity_map = torch.ones((h, w), device=x.device)
y_embed = identity_map.cumsum(0, dtype=torch.float32)
x_embed = identity_map.cumsum(1, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
).flatten(2)
pos_y = torch.stack(
(pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
).flatten(2)
pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return batch_pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
super().__init__()
self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, token_tensors):
# input: (B,C,H,W)
x = token_tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
).permute(2, 0, 1)
batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return batch_pos
def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
if is_learned:
position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
else:
position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)
return position_embedding
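# Illustrative sketch (not part of the original module): both variants return a
# positional map of shape (B, 2 * num_pos_feats, H, W) for a (B, C, H, W) input.
# The feature-map size below is an arbitrary assumption (the learned variant
# supports H <= n_pos_y and W <= n_pos_x, 16 by default).
def _example_position_encoding_shapes():
    feat = torch.randn(2, 256, 16, 16)
    pos_sine = build_position_encoding(num_pos_feats=64)                      # sinusoidal
    pos_learned = build_position_encoding(num_pos_feats=64, is_learned=True)  # learned embeddings
    assert pos_sine(feat).shape == (2, 128, 16, 16)
    assert pos_learned(feat).shape == (2, 128, 16, 16)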
from functools import partial
from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat
import torch
from torch import nn, einsum
import torch.nn.functional as F
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
LayerIntermediates = namedtuple("LayerIntermediates", ["hiddens", "attn_intermediates"])
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
self.init_()
def init_(self):
nn.init.normal_(self.emb.weight, std=0.02)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
return self.emb(n)[None, :, :]
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ offset
)
sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, s):
    return s.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
self.fn = fn
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.value, *rest)
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.g, *rest)
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class Residual(nn.Module):
def forward(self, x, residual):
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
)
return gated_output.reshape_as(x)
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
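# Illustrative sketch (not part of the original module): with glu=True the
# feed-forward uses the GEGLU gate; the output keeps the input dimension.
# Sizes are arbitrary assumptions for demonstration.
def _example_feedforward_shapes():
    ff = FeedForward(dim=128, mult=4, glu=True, dropout=0.1)
    x = torch.randn(2, 10, 128)
    assert ff(x).shape == (2, 10, 128)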
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.0,
on_attn=False,
):
super().__init__()
if use_entmax15:
raise NotImplementedError(
"Check out entmax activation instead of softmax activation!"
)
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
inner_dim = dim_head * heads
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.talking_heads = talking_heads
if talking_heads:
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.sparse_topk = sparse_topk
self.attn_fn = F.softmax
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
self.attn_on_attn = on_attn
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None,
):
b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x)
q_input = x
k_input = kv_input
v_input = kv_input
if exists(mem):
k_input = torch.cat((mem, k_input), dim=-2)
v_input = torch.cat((mem, v_input), dim=-2)
if exists(sinusoidal_emb):
offset = k_input.shape[-2] - q_input.shape[-2]
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(
k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
)
q_mask = rearrange(q_mask, "b i -> b () i ()")
k_mask = rearrange(k_mask, "b j -> b () () j")
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(
lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
)
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
dots = dots + prev_attn
pre_softmax_attn = dots
if talking_heads:
dots = einsum(
"b h i j, h k -> b k i j", dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
if exists(input_mask):
dots.masked_fill_(~input_mask, mask_value)
del input_mask
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
top, _ = dots.topk(self.sparse_topk, dim=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
attn = self.attn_fn(dots, dim=-1)
post_softmax_attn = attn
attn = self.dropout(attn)
if talking_heads:
attn = einsum(
"b h i j, h k -> b k i j", attn, self.post_softmax_proj
).contiguous()
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates
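# Illustrative sketch (not part of the original module): Attention returns both
# the projected output and the pre/post-softmax attention maps. Sizes are
# arbitrary assumptions for demonstration.
def _example_attention_shapes():
    attn = Attention(dim=128, dim_head=32, heads=4)
    x = torch.randn(2, 10, 128)
    out, inter = attn(x)  # self-attention (context defaults to x)
    assert out.shape == (2, 10, 128)
    assert inter.post_softmax_attn.shape == (2, 4, 10, 10)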
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None)
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), "number of relative position buckets must be less than the relative position max distance"
self.rel_pos = None
self.pre_norm = pre_norm
self.residual_attn = residual_attn
self.cross_residual_attn = cross_residual_attn
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_fn = partial(norm_class, dim)
norm_fn = nn.Identity if use_rezero else norm_fn
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ("a", "c", "f")
elif cross_attend and only_cross:
default_block = ("c", "f")
else:
default_block = ("a", "f")
if macaron:
default_block = ("f",) + default_block
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, "par ratio out of range"
default_block = tuple(filter(not_equals("f"), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert (
len(default_block) <= par_width
), "default block is too large for par_ratio"
par_block = default_block + ("f",) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ("f",) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert (
sandwich_coef > 0 and sandwich_coef <= depth
), "sandwich coefficient should be less than the depth"
layer_types = (
("a",) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ("f",) * sandwich_coef
)
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
for layer_type in self.layer_types:
if layer_type == "a":
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == "c":
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == "f":
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f"invalid layer type {layer_type}")
if isinstance(layer, Attention) and exists(branch_fn):
layer = branch_fn(layer)
if gate_residual:
residual_fn = GRUGating(dim)
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False,
):
hiddens = []
intermediates = []
prev_attn = None
prev_cross_attn = None
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1)
if layer_type == "a":
hiddens.append(x)
layer_mem = mems.pop(0)
residual = x
if self.pre_norm:
x = norm(x)
if layer_type == "a":
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == "c":
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == "f":
out = block(x)
x = residual_fn(out, residual)
if layer_type in ("a", "c"):
intermediates.append(inter)
if layer_type == "a" and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == "c" and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if not self.pre_norm and not is_last:
x = norm(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens, attn_intermediates=intermediates
)
return x, intermediates
return x
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert "causal" not in kwargs, "cannot set causality on encoder"
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.0,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True,
):
super().__init__()
assert isinstance(
attn_layers, AttentionLayers
), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
if hasattr(attn_layers, "num_memory_tokens"):
attn_layers.num_memory_tokens = num_memory_tokens
def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs,
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x)
x = self.emb_dropout(x)
x = self.project_emb(x)
if num_mem > 0:
mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
out = self.to_logits(x) if not return_embeddings else x
if return_mems:
hiddens = intermediates.hiddens
new_mems = (
list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens)))
if exists(mems)
else hiddens
)
new_mems = list(
map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
)
return out, new_mems
if return_attn:
attn_maps = list(
map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
)
return out, attn_maps
return out
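# Illustrative sketch (not part of the original module): a tiny encoder-only
# TransformerWrapper. Vocabulary size, sequence length and model width are
# arbitrary assumptions for demonstration.
def _example_transformer_wrapper_usage():
    model = TransformerWrapper(
        num_tokens=1000,
        max_seq_len=64,
        attn_layers=Encoder(dim=128, depth=2, heads=4),
    )
    tokens = torch.randint(0, 1000, (2, 16))
    logits = model(tokens)                          # (2, 16, 1000)
    embeds = model(tokens, return_embeddings=True)  # (2, 16, 128)
    assert logits.shape == (2, 16, 1000) and embeds.shape == (2, 16, 128)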
from core.models.samplers.ddim import DDIMSampler
import glob
import json
import os
import sys
from collections import OrderedDict
import numpy as np
import torch
import torchvision
from PIL import Image
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
def batch_ddim_sampling(
model,
cond,
noise_shape,
n_samples=1,
ddim_steps=50,
ddim_eta=1.0,
cfg_scale=1.0,
temporal_cfg_scale=None,
use_cat_ucg=False,
**kwargs,
):
ddim_sampler = DDIMSampler(model)
uncond_type = model.uncond_type
batch_size = noise_shape[0]
# construct unconditional guidance
if cfg_scale != 1.0:
if uncond_type == "empty_seq":
prompts = batch_size * [""]
# prompts = N * T * [""] # if is_image_batch=True
uc_emb = model.get_learned_conditioning(prompts)
elif uncond_type == "zero_embed":
c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
uc_emb = torch.zeros_like(c_emb)
# process image condition
if hasattr(model, "embedder"):
uc_img = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device)
# img: b c h w >> b l c
uc_img = model.get_image_embeds(uc_img)
uc_emb = torch.cat([uc_emb, uc_img], dim=1)
if isinstance(cond, dict):
uc = {key: cond[key] for key in cond.keys()}
uc.update({"c_crossattn": [uc_emb]})
# special CFG for frame concatenation
if use_cat_ucg and hasattr(model, "cond_concat") and model.cond_concat:
uc_cat = torch.zeros(
noise_shape[0], model.cond_channels, *noise_shape[2:]
).to(model.device)
uc.update({"c_concat": [uc_cat]})
else:
uc = [uc_emb]
else:
uc = None
# uc.update({'fps': torch.tensor([-4]*batch_size).to(model.device).long()})
# sampling
noise = torch.randn(noise_shape, device=model.device)
# x_T = repeat(noise[:,:,:1,:,:], 'b c l h w -> b c (l t) h w', t=noise_shape[2])
# x_T = 0.2 * x_T + 0.8 * torch.randn(noise_shape, device=model.device)
x_T = None
batch_variants = []
# batch_variants1, batch_variants2 = [], []
for _ in range(n_samples):
if ddim_sampler is not None:
samples, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=noise_shape[0],
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
temporal_length=noise_shape[2],
conditional_guidance_scale_temporal=temporal_cfg_scale,
x_T=x_T,
**kwargs,
)
# reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants.append(batch_images)
"""
pred_x0_list, x_iter_list = _['pred_x0'], _['x_inter']
steps = [0, 15, 25, 30, 35, 40, 43, 46, 49, 50]
for nn in steps:
pred_x0 = pred_x0_list[nn]
x_iter = x_iter_list[nn]
batch_images_x0 = model.decode_first_stage(pred_x0)
batch_variants1.append(batch_images_x0)
batch_images_xt = model.decode_first_stage(x_iter)
batch_variants2.append(batch_images_xt)
"""
# batch, <samples>, c, t, h, w
batch_variants = torch.stack(batch_variants, dim=1)
# batch_variants1 = torch.stack(batch_variants1, dim=1)
# batch_variants2 = torch.stack(batch_variants2, dim=1)
# return batch_variants1, batch_variants2
return batch_variants
def batch_sliding_interpolation(
model,
cond,
base_videos,
base_stride,
noise_shape,
n_samples=1,
ddim_steps=50,
ddim_eta=1.0,
cfg_scale=1.0,
temporal_cfg_scale=None,
**kwargs,
):
"""
Current implementation has a flaw: the inter-episode keyframe is used as pre-last and cur-first, so keyframe repeated.
For example, cond_frames=[0,4,7], model.temporal_length=8, base_stride=4, then
base frame : 0 4 8 12 16 20 24 28
interplation: (0~7) (8~15) (16~23) (20~27)
"""
b, c, t, h, w = noise_shape
base_z0 = model.encode_first_stage(base_videos)
unit_length = model.temporal_length
n_base_frames = base_videos.shape[2]
n_refs = len(model.cond_frames)
sliding_steps = (n_base_frames - 1) // (n_refs - 1)
sliding_steps = (
sliding_steps + 1 if (n_base_frames - 1) % (n_refs - 1) > 0 else sliding_steps
)
cond_mask = model.cond_mask.to("cuda")
proxy_z0 = torch.zeros((b, c, unit_length, h, w), dtype=torch.float32).to("cuda")
batch_samples = None
last_offset = None
for idx in range(sliding_steps):
base_idx = idx * (n_refs - 1)
# check index overflow
if base_idx + n_refs > n_base_frames:
last_offset = base_idx - (n_base_frames - n_refs)
base_idx = n_base_frames - n_refs
cond_z0 = base_z0[:, :, base_idx : base_idx + n_refs, :, :]
proxy_z0[:, :, model.cond_frames, :, :] = cond_z0
if "c_concat" in cond:
c_cat, text_emb = cond["c_concat"][0], cond["c_crossattn"][0]
episode_idx = idx * unit_length
if last_offset is not None:
episode_idx = episode_idx - last_offset * base_stride
cond_idx = {
"c_concat": [
c_cat[:, :, episode_idx : episode_idx + unit_length, :, :]
],
"c_crossattn": [text_emb],
}
else:
cond_idx = cond
noise_shape_idx = [b, c, unit_length, h, w]
# batch, <samples>, c, t, h, w
batch_idx = batch_ddim_sampling(
model,
cond_idx,
noise_shape_idx,
n_samples,
ddim_steps,
ddim_eta,
cfg_scale,
temporal_cfg_scale,
mask=cond_mask,
x0=proxy_z0,
**kwargs,
)
if batch_samples is None:
batch_samples = batch_idx
else:
# b,s,c,t,h,w
if last_offset is None:
batch_samples = torch.cat(
[batch_samples[:, :, :, :-1, :, :], batch_idx], dim=3
)
else:
batch_samples = torch.cat(
[
batch_samples[:, :, :, :-1, :, :],
batch_idx[:, :, :, last_offset * base_stride :, :, :],
],
dim=3,
)
return batch_samples
def get_filelist(data_dir, ext="*"):
file_list = glob.glob(os.path.join(data_dir, "*.%s" % ext))
file_list.sort()
return file_list
def get_dirlist(path):
    dir_list = []
    if os.path.exists(path):
        files = os.listdir(path)
        for file in files:
            m = os.path.join(path, file)
            if os.path.isdir(m):
                dir_list.append(m)
    dir_list.sort()
    return dir_list
def load_model_checkpoint(model, ckpt, adapter_ckpt=None):
def load_checkpoint(model, ckpt, full_strict):
state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
try:
# deepspeed
new_pl_sd = OrderedDict()
for key in state_dict["module"].keys():
new_pl_sd[key[16:]] = state_dict["module"][key]
model.load_state_dict(new_pl_sd, strict=full_strict)
        except Exception:
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=full_strict)
return model
if adapter_ckpt:
# main model
load_checkpoint(model, ckpt, full_strict=False)
print(">>> model checkpoint loaded.")
# adapter
state_dict = torch.load(adapter_ckpt, map_location="cpu")
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
model.adapter.load_state_dict(state_dict, strict=True)
print(">>> adapter checkpoint loaded.")
else:
load_checkpoint(model, ckpt, full_strict=False)
print(">>> model checkpoint loaded.")
return model
def load_prompts(prompt_file):
    prompt_list = []
    with open(prompt_file, "r") as f:
        for line in f.readlines():
            l = line.strip()
            if len(l) != 0:
                prompt_list.append(l)
    return prompt_list
def load_camera_poses(filepath_list, video_frames=16):
pose_list = []
for filepath in filepath_list:
with open(filepath, "r") as f:
pose = json.load(f)
pose = np.array(pose) # [t, 12]
pose = torch.tensor(pose).float() # [t, 12]
        assert (
            pose.shape[0] == video_frames
        ), f"number of conditional pose frames ({pose.shape[0]}) does not match the target frame count ({video_frames})."
pose_list.append(pose)
batch_poses = torch.stack(pose_list, dim=0)
# shape [b,t,12,1]
return batch_poses[..., None]
def save_videos(
batch_tensors: torch.Tensor, save_dir: str, filenames: list[str], fps: int = 10
):
# b,samples,t,c,h,w
n_samples = batch_tensors.shape[1]
for idx, vid_tensor in enumerate(batch_tensors):
video = vid_tensor.detach().cpu()
video = torch.clamp(video.float(), -1.0, 1.0)
video = video.permute(1, 0, 2, 3, 4) # t,n,c,h,w
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
for framesheet in video
] # [3, 1*h, n*w]
# stack in temporal dim [t, 3, n*h, w]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
savepath = os.path.join(save_dir, f"{filenames[idx]}.mp4")
torchvision.io.write_video(
savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
)
import torch
import math
def slerp(R1, R2, alpha):
"""
Perform Spherical Linear Interpolation (SLERP) between two rotation matrices.
R1, R2: (3x3) rotation matrices.
alpha: interpolation factor, ranging from 0 to 1.
"""
# Convert the rotation matrices to quaternions
def rotation_matrix_to_quaternion(R):
w = torch.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2.0
w4 = 4.0 * w
x = (R[2, 1] - R[1, 2]) / w4
y = (R[0, 2] - R[2, 0]) / w4
z = (R[1, 0] - R[0, 1]) / w4
return torch.tensor([w, x, y, z]).float()
def quaternion_to_rotation_matrix(q):
w, x, y, z = q
return torch.tensor(
[
[
1 - 2 * y * y - 2 * z * z,
2 * x * y - 2 * w * z,
2 * x * z + 2 * w * y,
],
[
2 * x * y + 2 * w * z,
1 - 2 * x * x - 2 * z * z,
2 * y * z - 2 * w * x,
],
[
2 * x * z - 2 * w * y,
2 * y * z + 2 * w * x,
1 - 2 * x * x - 2 * y * y,
],
]
).float()
q1 = rotation_matrix_to_quaternion(R1)
q2 = rotation_matrix_to_quaternion(R2)
# Dot product of the quaternions
dot = torch.dot(q1, q2)
# If the dot product is negative, negate one quaternion to ensure the shortest path is taken
if dot < 0.0:
q2 = -q2
dot = -dot
# SLERP formula
if (
dot > 0.9995
): # If the quaternions are nearly identical, use linear interpolation
q_interp = (1 - alpha) * q1 + alpha * q2
else:
theta_0 = torch.acos(dot) # Angle between q1 and q2
sin_theta_0 = torch.sin(theta_0)
theta = theta_0 * alpha # Angle between q1 and interpolated quaternion
sin_theta = torch.sin(theta)
s1 = torch.sin((1 - alpha) * theta_0) / sin_theta_0
s2 = sin_theta / sin_theta_0
q_interp = s1 * q1 + s2 * q2
# Convert the interpolated quaternion back to a rotation matrix
R_interp = quaternion_to_rotation_matrix(q_interp)
return R_interp
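# Illustrative worked example (not part of the original module): interpolating
# halfway between the identity and a 90-degree rotation about z yields a
# 45-degree rotation about z.
def _example_slerp_halfway():
    c, s = math.cos(math.pi / 2), math.sin(math.pi / 2)
    R_identity = torch.eye(3)
    R_z90 = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    R_half = slerp(R_identity, R_z90, alpha=0.5)
    c45, s45 = math.cos(math.pi / 4), math.sin(math.pi / 4)
    expected = torch.tensor([[c45, -s45, 0.0], [s45, c45, 0.0], [0.0, 0.0, 1.0]])
    assert torch.allclose(R_half, expected, atol=1e-4)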
def interpolate_camera_poses(pose1, pose2, num_steps):
"""
Interpolate between two camera poses (3x4 matrices) over a number of steps.
pose1, pose2: (3x4) camera pose matrices (R|t), where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
num_steps: number of interpolation steps.
Returns:
A list of interpolated poses as (3x4) matrices.
"""
R1, t1 = pose1[:, :3], pose1[:, 3]
R2, t2 = pose2[:, :3], pose2[:, 3]
interpolated_poses = []
for i in range(num_steps):
alpha = i / (num_steps - 1) # Interpolation factor ranging from 0 to 1
# Interpolate rotation using SLERP
R_interp = slerp(R1, R2, alpha)
# Interpolate translation using linear interpolation (LERP)
t_interp = (1 - alpha) * t1 + alpha * t2
# Combine interpolated rotation and translation into a (3x4) pose matrix
pose_interp = torch.cat([R_interp, t_interp.unsqueeze(1)], dim=1)
interpolated_poses.append(pose_interp)
return interpolated_poses
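# Illustrative sketch: interpolate between the identity pose and a hypothetical pose rotated
# 45 degrees about y and shifted one unit along z; the endpoints reproduce the inputs.
def _demo_interpolate_camera_poses(num_steps=8):
    pose1 = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1)  # [3, 4]
    theta = math.pi / 4
    R2 = torch.tensor(
        [
            [math.cos(theta), 0.0, math.sin(theta)],
            [0.0, 1.0, 0.0],
            [-math.sin(theta), 0.0, math.cos(theta)],
        ]
    )
    pose2 = torch.cat([R2, torch.tensor([[0.0], [0.0], [1.0]])], dim=1)  # [3, 4]
    poses = interpolate_camera_poses(pose1, pose2, num_steps)
    print(len(poses), poses[0].shape)  # expected: 8 torch.Size([3, 4])
    return poses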
def rotation_matrix_from_xyz_angles(x_angle, y_angle, z_angle):
"""
Compute the rotation matrix from given x, y, z angles (in radians).
x_angle: Rotation around the x-axis (pitch).
y_angle: Rotation around the y-axis (yaw).
z_angle: Rotation around the z-axis (roll).
Returns:
A 3x3 rotation matrix.
"""
# Rotation matrices around each axis
Rx = torch.tensor(
[
[1, 0, 0],
[0, torch.cos(x_angle), -torch.sin(x_angle)],
[0, torch.sin(x_angle), torch.cos(x_angle)],
]
).float()
Ry = torch.tensor(
[
[torch.cos(y_angle), 0, torch.sin(y_angle)],
[0, 1, 0],
[-torch.sin(y_angle), 0, torch.cos(y_angle)],
]
).float()
Rz = torch.tensor(
[
[torch.cos(z_angle), -torch.sin(z_angle), 0],
[torch.sin(z_angle), torch.cos(z_angle), 0],
[0, 0, 1],
]
).float()
# Combined rotation matrix R = Rz * Ry * Rx
R_combined = Rz @ Ry @ Rx
return R_combined.float()
def move_pose(pose1, x_angle, y_angle, z_angle, translation):
"""
Calculate the second camera pose based on the first pose and given rotations (x, y, z) and translation.
pose1: The first camera pose (3x4 matrix).
x_angle, y_angle, z_angle: Rotation angles around the x, y, and z axes, in radians.
translation: Translation vector (3,).
Returns:
pose2: The second camera pose as a (3x4) matrix.
"""
# Extract the rotation (R1) and translation (t1) from the first pose
R1 = pose1[:, :3]
t1 = pose1[:, 3]
# Calculate the new rotation matrix from the given angles
R_delta = rotation_matrix_from_xyz_angles(x_angle, y_angle, z_angle)
# New rotation = R1 * R_delta
R2 = R1 @ R_delta
# New translation = t1 + translation
t2 = t1 + translation
# Combine R2 and t2 into the new pose (3x4 matrix)
pose2 = torch.cat([R2, t2.unsqueeze(1)], dim=1)
return pose2
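# Illustrative sketch: starting from an identity pose, rotate 30 degrees about y and step one
# unit along z. Angles are passed as tensors because rotation_matrix_from_xyz_angles calls
# torch.cos/torch.sin on them.
def _demo_move_pose():
    pose1 = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1)  # [3, 4] identity pose
    pose2 = move_pose(
        pose1,
        x_angle=torch.tensor(0.0),
        y_angle=torch.tensor(math.pi / 6),
        z_angle=torch.tensor(0.0),
        translation=torch.tensor([0.0, 0.0, 1.0]),
    )
    print(pose2.shape)  # expected: torch.Size([3, 4])
    return pose2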
def deg2rad(degrees):
"""Convert degrees to radians."""
return degrees * math.pi / 180.0
def generate_spherical_trajectory(end_angles, radius=1.0, num_steps=36):
"""
Generate a camera-to-world (C2W) trajectory interpolating angles on a sphere.
Args:
end_angles (tuple): The endpoint rotation angles in degrees (x, y, z).
(start is assumed to be (0, 0, 0)).
radius (float): Radius of the sphere.
num_steps (int): Number of steps in the trajectory.
Returns:
torch.Tensor: A tensor of shape [num_steps, 3, 4] with the C2W transformations.
"""
# Convert angles to radians
end_angles_rad = torch.tensor(
[deg2rad(angle) for angle in end_angles], dtype=torch.float32
)
# Interpolate angles linearly
interpolated_angles = (
torch.linspace(0, 1, num_steps).view(-1, 1) * end_angles_rad
) # Shape: [num_steps, 3]
poses = []
for angles in interpolated_angles:
# Extract interpolated angles
        x_angle, y_angle, _ = angles  # the z (roll) angle is not used by this look-at construction
# Compute camera position on the sphere
x = radius * math.sin(y_angle) * math.cos(x_angle)
y = radius * math.sin(x_angle)
z = radius * math.cos(y_angle) * math.cos(x_angle)
cam_position = torch.tensor([x, y, z], dtype=torch.float32)
# Camera's forward direction (looking at the origin)
look_at_dir = -cam_position / torch.norm(cam_position)
# Define the "up" vector
up = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32)
# Compute the right vector
        right = torch.cross(up, look_at_dir, dim=0)
        right = right / torch.norm(right)
        # Recompute the orthogonal up vector
        up = torch.cross(look_at_dir, right, dim=0)
# Build the rotation matrix
rotation_matrix = torch.stack([right, up, look_at_dir], dim=0) # [3, 3]
# Combine the rotation matrix with the translation (camera position)
c2w = torch.cat([rotation_matrix, cam_position.view(3, 1)], dim=1) # [3, 4]
# Append the pose
poses.append(c2w)
return poses
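# Illustrative sketch: a short orbit ending at 30 degrees of yaw; the result is a list of
# [3, 4] camera-to-world poses that all look at the origin.
def _demo_generate_spherical_trajectory():
    poses = generate_spherical_trajectory(
        end_angles=(0.0, 30.0, 0.0), radius=2.0, num_steps=8
    )
    print(len(poses), poses[0].shape)  # expected: 8 torch.Size([3, 4])
    return poses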
import torch
def process_inference_batch(cfg_scale, batch, model, with_uncondition_extra=False):
for k in batch.keys():
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(model.device, dtype=model.dtype)
z, cond, x_rec = model.get_batch_input(
batch,
random_drop_training_conditions=False,
return_reconstructed_target_images=True,
)
# batch_size = x_rec.shape[0]
# Get unconditioned embedding for classifier-free guidance sampling
if cfg_scale != 1.0:
uc = model.get_unconditional_dict_for_sampling(batch, cond, x_rec)
else:
uc = None
if with_uncondition_extra:
uc_extra = model.get_unconditional_dict_for_sampling(
batch, cond, x_rec, is_extra=True
)
return cond, uc, uc_extra, x_rec
else:
return cond, uc, x_rec
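# Illustrative call pattern (the `model` and `batch` objects are assumed to be provided by the
# surrounding inference pipeline; no new API is introduced here):
#     cond, uc, x_rec = process_inference_batch(cfg_scale=7.5, batch=batch, model=model)
#     cond, uc, uc_extra, x_rec = process_inference_batch(
#         7.5, batch, model, with_uncondition_extra=True
#     )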
from utils.utils import instantiate_from_config
import os
import sys
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
os.chdir(sys.path[0])
sys.path.append("..")
def t_range(name, tensor):
print(
f"{name}: shape={tensor.shape}, max={torch.max(tensor)}, min={torch.min(tensor)}."
)
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
train_img=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
test_max_n_samples=None,
**kwargs,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
# train image dataset
if train_img is not None:
img_data = instantiate_from_config(train_img)
self.img_loader = img_data.train_dataloader()
else:
self.img_loader = None
self.wrap = wrap
self.test_max_n_samples = test_max_n_samples
self.collate_fn = None
def prepare_data(self):
# for data_cfg in self.dataset_configs.values():
# instantiate_from_config(data_cfg)
pass
def setup(self, stage=None):
self.datasets = dict(
(k, instantiate_from_config(self.dataset_configs[k]))
for k in self.dataset_configs
)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
is_iterable_dataset = False
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
loader = DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False if is_iterable_dataset else True,
worker_init_fn=init_fn,
collate_fn=self.collate_fn,
)
if self.img_loader is not None:
return {"loader_video": loader, "loader_img": self.img_loader}
else:
return loader
def _val_dataloader(self, shuffle=False):
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
collate_fn=self.collate_fn,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = False
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
if self.test_max_n_samples is not None:
dataset = torch.utils.data.Subset(
self.datasets["test"], list(range(self.test_max_n_samples))
)
else:
dataset = self.datasets["test"]
return DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
collate_fn=self.collate_fn,
)
def _predict_dataloader(self, shuffle=False):
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn,
)
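# Illustrative sketch: the dataset entries are config dicts handed to instantiate_from_config.
# The target below is a hypothetical dataset class (and hypothetical params), assuming the
# usual {"target": ..., "params": ...} layout; replace it with a real dataset config.
def _demo_build_datamodule():
    train_cfg = {
        "target": "data.video_dataset.VideoDataset",  # hypothetical dataset class
        "params": {"data_root": "./data/train", "video_length": 16},  # hypothetical params
    }
    datamodule = DataModuleFromConfig(batch_size=2, train=train_cfg, num_workers=4)
    datamodule.setup()
    return datamodule.train_dataloader()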
pytorch_lightning
deepspeed
taming-transformers
scipy
einops
kornia
open_clip_torch
openai-clip
xformers
timm
av
gradio
FLAG_RUN_DEBUG = False
PATH_DIR_DEBUG = "./debug/"
from utils.utils import instantiate_from_config
import torch
import copy
from omegaconf import OmegaConf
import logging
main_logger = logging.getLogger("main_logger")
def expand_conv_kernel(pretrained_dict):
"""expand 2d conv parameters from 4D -> 5D"""
for k, v in pretrained_dict.items():
if v.dim() == 4 and not k.startswith("first_stage_model"):
v = v.unsqueeze(2)
pretrained_dict[k] = v
return pretrained_dict
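# Minimal check of the 4D -> 5D inflation (illustrative, with dummy tensors): a 2D conv weight
# [out_ch, in_ch, kH, kW] gains a singleton temporal axis -> [out_ch, in_ch, 1, kH, kW], while
# first_stage_model weights are left untouched.
def _demo_expand_conv_kernel():
    dummy = {
        "model.diffusion_model.conv.weight": torch.randn(8, 4, 3, 3),
        "first_stage_model.encoder.conv.weight": torch.randn(8, 4, 3, 3),
    }
    expanded = expand_conv_kernel(dummy)
    print(expanded["model.diffusion_model.conv.weight"].shape)  # torch.Size([8, 4, 1, 3, 3])
    print(expanded["first_stage_model.encoder.conv.weight"].shape)  # torch.Size([8, 4, 3, 3])
    return expanded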
def print_state_dict(state_dict):
print("====== Dumping State Dict ======")
for k, v in state_dict.items():
print(k, v.shape)
def load_from_pretrainedSD_checkpoint(
model,
pretained_ckpt,
expand_to_3d=True,
adapt_keyname=False,
echo_empty_params=False,
):
sd_state_dict = torch.load(pretained_ckpt, map_location="cpu")
if "state_dict" in list(sd_state_dict.keys()):
sd_state_dict = sd_state_dict["state_dict"]
model_state_dict = model.state_dict()
# delete ema_weights just for <precise param counting>
for k in list(sd_state_dict.keys()):
if k.startswith("model_ema"):
del sd_state_dict[k]
main_logger.info(
f"Num of model params of Source:{len(sd_state_dict.keys())} VS. Target:{len(model_state_dict.keys())}"
)
# print_state_dict(model_state_dict)
# print_state_dict(sd_state_dict)
if adapt_keyname:
        # adapt to the standard 2D network: rename keys because temporal-attention blocks are inserted
mapping_dict = {
"middle_block.2": "middle_block.3",
"output_blocks.5.2": "output_blocks.5.3",
"output_blocks.8.2": "output_blocks.8.3",
}
cnt = 0
for k in list(sd_state_dict.keys()):
for src_word, dst_word in mapping_dict.items():
if src_word in k:
new_key = k.replace(src_word, dst_word)
sd_state_dict[new_key] = sd_state_dict[k]
del sd_state_dict[k]
cnt += 1
main_logger.info(f"[renamed {cnt} Source keys to match Target model]")
pretrained_dict = {
k: v for k, v in sd_state_dict.items() if k in model_state_dict
} # drop extra keys
empty_paras = [
k for k, v in model_state_dict.items() if k not in pretrained_dict
] # log no pretrained keys
assert len(empty_paras) + len(pretrained_dict.keys()) == len(
model_state_dict.keys()
)
if expand_to_3d:
# adapting to 2d inflated network
pretrained_dict = expand_conv_kernel(pretrained_dict)
# overwrite entries in the existing state dict
model_state_dict.update(pretrained_dict)
# load the new state dict
try:
model.load_state_dict(model_state_dict)
    except RuntimeError:
skipped = []
model_dict_ori = model.state_dict()
for n, p in model_state_dict.items():
if p.shape != model_dict_ori[n].shape:
# skip by using original empty paras
model_state_dict[n] = model_dict_ori[n]
main_logger.info(
f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_state_dict[n].shape} in current model"
)
skipped.append(n)
        main_logger.info(
            f"[INFO] Skipped {len(skipped)} parameters because of size mismatch!"
        )
model.load_state_dict(model_state_dict)
empty_paras += skipped
# only count Unet part of depth estimation model
unet_empty_paras = [
name for name in empty_paras if name.startswith("model.diffusion_model")
]
main_logger.info(
f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet:{len(unet_empty_paras)}]"
)
if echo_empty_params:
print("Printing empty parameters:")
for k in empty_paras:
print(k)
return model, empty_paras
# Below: written by Yingqing --------------------------------------------------------
def load_model_from_config(config, ckpt, verbose=False):
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
main_logger.info("missing keys:")
main_logger.info(m)
if len(u) > 0 and verbose:
main_logger.info("unexpected keys:")
main_logger.info(u)
model.eval()
return model
def init_and_load_ldm_model(config_path, ckpt_path, device=None):
assert config_path.endswith(".yaml"), f"config_path = {config_path}"
assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
config = OmegaConf.load(config_path)
model = load_model_from_config(config, ckpt_path)
if device is not None:
model = model.to(device)
return model
def load_img_model_to_video_model(
model,
device=None,
expand_to_3d=True,
adapt_keyname=False,
config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
ckpt_path="models/ldm/text2img-large/model.ckpt",
):
pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
model, empty_paras = load_partial_weights(
model,
pretrained_ldm.state_dict(),
expand_to_3d=expand_to_3d,
adapt_keyname=adapt_keyname,
)
return model, empty_paras
def load_partial_weights(
model, pretrained_dict, expand_to_3d=True, adapt_keyname=False
):
model2 = copy.deepcopy(model)
model_dict = model.state_dict()
model_dict_ori = copy.deepcopy(model_dict)
main_logger.info(f"[Load pretrained LDM weights]")
main_logger.info(
f"Num of parameters of source model:{len(pretrained_dict.keys())} VS. target model:{len(model_dict.keys())}"
)
if adapt_keyname:
        # adapt to menghan's standard 2D network: rename keys because temporal-attention blocks are inserted
mapping_dict = {
"middle_block.2": "middle_block.3",
"output_blocks.5.2": "output_blocks.5.3",
"output_blocks.8.2": "output_blocks.8.3",
}
cnt = 0
newpretrained_dict = copy.deepcopy(pretrained_dict)
for k, v in newpretrained_dict.items():
for src_word, dst_word in mapping_dict.items():
if src_word in k:
new_key = k.replace(src_word, dst_word)
pretrained_dict[new_key] = v
pretrained_dict.pop(k)
cnt += 1
main_logger.info(f"--renamed {cnt} source keys to match target model.")
pretrained_dict = {
k: v for k, v in pretrained_dict.items() if k in model_dict
} # drop extra keys
empty_paras = [
k for k, v in model_dict.items() if k not in pretrained_dict
] # log no pretrained keys
main_logger.info(
f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
)
# disable info
# main_logger.info(f'Empty parameters: {empty_paras} ')
assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())
if expand_to_3d:
# adapting to yingqing's 2d inflation network
pretrained_dict = expand_conv_kernel(pretrained_dict)
# overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# load the new state dict
try:
model2.load_state_dict(model_dict)
    except RuntimeError:
        # if parameter sizes mismatch, skip those parameters
skipped = []
for n, p in model_dict.items():
if p.shape != model_dict_ori[n].shape:
# skip by using original empty paras
model_dict[n] = model_dict_ori[n]
main_logger.info(
f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_dict[n].shape} in current model"
)
skipped.append(n)
        main_logger.info(
            f"[INFO] Skipped {len(skipped)} parameters because of size mismatch!"
        )
model2.load_state_dict(model_dict)
empty_paras += skipped
main_logger.info(f"Empty parameters: {len(empty_paras)} ")
main_logger.info(f"Finished.")
return model2, empty_paras
def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
if config_path is None:
config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
if ckpt_path is None:
ckpt_path = "models/ldm/text2img-large/model.ckpt"
pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
autoencoder_dict = {}
for n, p in pretrained_ldm.state_dict().items():
if n.startswith("first_stage_model"):
autoencoder_dict[n] = p
model_dict = model.state_dict()
model_dict.update(autoencoder_dict)
main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")
model.load_state_dict(model_dict)
return model
import numpy as np
import torch
import torch.optim as optim
def build_LR_scheduler(
optimizer, scheduler_name, lr_decay_ratio, max_epochs, start_epoch=0
):
# print("-LR scheduler:%s"%scheduler_name)
if scheduler_name == "LambdaLR":
decay_ratio = lr_decay_ratio
decay_epochs = max_epochs
def polynomial_decay(epoch):
return (
1 + (decay_ratio - 1) * ((epoch + start_epoch) / decay_epochs)
if (epoch + start_epoch) < decay_epochs
else decay_ratio
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=polynomial_decay
)
elif scheduler_name == "CosineAnnealingLR":
last_epoch = -1 if start_epoch == 0 else start_epoch
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=max_epochs, last_epoch=last_epoch
)
elif scheduler_name == "ReduceLROnPlateau":
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, threshold=0.01, patience=5
)
else:
raise NotImplementedError
return lr_scheduler
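# Illustrative sketch: attach a LambdaLR decay schedule to a throwaway optimizer; the
# parameter, learning rate, and epoch count below are arbitrary placeholders.
def _demo_build_lr_scheduler():
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.SGD(params, lr=1e-3)
    scheduler = build_LR_scheduler(
        optimizer, scheduler_name="LambdaLR", lr_decay_ratio=0.1, max_epochs=100
    )
    for _ in range(3):
        optimizer.step()
        scheduler.step()
    print(optimizer.param_groups[0]["lr"])
    return scheduler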
class LambdaLRScheduler:
# target: torch.optim.lr_scheduler.LambdaLR
def __init__(self, start_step, final_decay_ratio, decay_steps):
self.final_decay_ratio = final_decay_ratio
self.decay_steps = decay_steps
self.start_step = start_step
def schedule(self, step):
if step + self.start_step < self.decay_steps:
return 1.0 + (self.final_decay_ratio - 1) * (
(step + self.start_step) / self.decay_steps
)
else:
return self.final_decay_ratio
    def __call__(self, step):
        return self.schedule(step)
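# Illustrative sketch: per its target comment, LambdaLRScheduler is meant to be used as the
# lr_lambda of torch.optim.lr_scheduler.LambdaLR; the multiplier decays linearly from 1.0 to
# final_decay_ratio. The optimizer and values below are placeholders.
def _demo_lambda_lr_scheduler():
    schedule = LambdaLRScheduler(start_step=0, final_decay_ratio=0.1, decay_steps=1000)
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.Adam(params, lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
    print(schedule(0), schedule(500), schedule(2000))  # 1.0, 0.55, 0.1
    return lr_scheduler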
class CosineAnnealingLRScheduler:
    # target: torch.optim.lr_scheduler.CosineAnnealingLR
    def __init__(self, start_step, decay_steps):
        self.decay_steps = decay_steps
        self.start_step = start_step
    def __call__(self, step):
        # Assumed intent (the original body was an empty stub): return a cosine-annealed
        # multiplier that falls from 1.0 to 0.0 over decay_steps, mirroring
        # torch.optim.lr_scheduler.CosineAnnealingLR.
        t = min((step + self.start_step) / self.decay_steps, 1.0)
        return 0.5 * (1.0 + np.cos(t * np.pi))
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
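# Illustrative sketch: warm up linearly for 1k steps, then cosine-decay the multiplier; as the
# docstring notes, this is meant to be combined with a base_lr of 1.0. Values are placeholders.
def _demo_warmup_cosine():
    sched = LambdaWarmUpCosineScheduler(
        warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=10000
    )
    print(sched(0), sched(1000), sched(10000))  # 0.0, 1.0, 0.01
    return sched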
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f
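# Illustrative sketch of the cyclic variants: per-cycle lists configure warm-up and decay; the
# single-cycle values below are arbitrary placeholders.
def _demo_lambda_linear_scheduler():
    sched = LambdaLinearScheduler(
        warm_up_steps=[100],
        f_min=[0.0],
        f_max=[1.0],
        f_start=[1e-6],
        cycle_lengths=[10000],
    )
    print(sched(0), sched(100), sched(10000))  # 1e-6, 0.99, 0.0
    return sched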