Commit 5e2c95b7 authored by wuxk1

init commit for comui
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
import torch
import torch.nn as nn
from torch import Tensor
import logging
try:
from torchaudio.transforms import MelScale
except Exception:
logging.warning("torchaudio missing, ACE model will be broken")
import comfy.model_management
class LinearSpectrogram(nn.Module):
def __init__(
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.mode = mode
self.register_buffer("window", torch.hann_window(win_length))
def forward(self, y: Tensor) -> Tensor:
if y.ndim == 3:
y = y.squeeze(1)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
(self.win_length - self.hop_length) // 2,
(self.win_length - self.hop_length + 1) // 2,
),
mode="reflect",
).squeeze(1)
dtype = y.dtype
spec = torch.stft(
y.float(),
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.view_as_real(spec)
if self.mode == "pow2_sqrt":
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = spec.to(dtype)
return spec
class LogMelSpectrogram(nn.Module):
def __init__(
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.n_mels = n_mels
self.f_min = f_min
self.f_max = f_max or sample_rate // 2
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
"slaney",
"slaney",
)
def compress(self, x: Tensor) -> Tensor:
return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: Tensor) -> Tensor:
return torch.exp(x)
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
linear = self.spectrogram(x)
x = self.mel_scale(linear)
x = self.compress(x)
# print(x.shape)
if return_linear:
return x, self.compress(linear)
return x
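# Usage sketch (illustrative, not part of the module): with the defaults above, a 44.1 kHz
# waveform of shape (B, T) maps to a log-mel tensor of shape (B, n_mels, ~T // hop_length):
#
#     mel_fn = LogMelSpectrogram(sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128)
#     mel = mel_fn(torch.randn(1, 44100 * 2))   # -> (1, 128, 172)
#
# compress()/decompress() form a log/exp pair; because of the clamp at 1e-5,
# decompress(compress(x)) only recovers x exactly for x >= 1e-5.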
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
import torch
from torch import nn
from functools import partial
from math import prod
from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from .music_log_mel import LogMelSpectrogram
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
""" # noqa: E501
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
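# Behaviour sketch for drop_path (illustrative): each sample's residual branch is either
# zeroed out or rescaled by 1/keep_prob, and the function is the identity at inference:
#
#     x = torch.ones(4, 8, 16)
#     y = drop_path(x, drop_prob=0.5, training=True)   # rows are all 0.0 or all 2.0
#     z = drop_path(x, drop_prob=0.5, training=False)  # z is x, unchanged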
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class LayerNorm(nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
""" # noqa: E501
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
return x
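# Shape convention (sketch): with data_format="channels_first" this LayerNorm expects
# (N, C, L) tensors and normalizes over the channel dim C, matching the Conv1d layers in
# this file; with the default "channels_last" it expects (..., C) and defers to
# F.layer_norm. E.g. LayerNorm(96, data_format="channels_first") accepts a (2, 96, 1024) map.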
class ConvNeXtBlock(nn.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
kernel_size (int): Kernel size for depthwise conv. Default: 7.
dilation (int): Dilation for depthwise conv. Default: 1.
""" # noqa: E501
def __init__(
self,
dim: int,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-6,
mlp_ratio: float = 4.0,
kernel_size: int = 7,
dilation: int = 1,
):
super().__init__()
self.dwconv = ops.Conv1d(
dim,
dim,
kernel_size=kernel_size,
padding=int(dilation * (kernel_size - 1) / 2),
groups=dim,
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = ops.Linear(
dim, int(mlp_ratio * dim)
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
self.gamma = (
nn.Parameter(torch.empty((dim)), requires_grad=False)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x, apply_residual: bool = True):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
x = self.drop_path(x)
if apply_residual:
x = input + x
return x
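# Sketch of the block above (shapes only): input and output are both (N, C, L); the two
# permutes wrap the channels-last LayerNorm/Linear stack, i.e. implementation (2) from the
# docstring:
#
#     block = ConvNeXtBlock(dim=96)
#     out = block(torch.randn(2, 96, 1024))   # -> (2, 96, 1024)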
class ParallelConvNeXtBlock(nn.Module):
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
super().__init__()
self.blocks = nn.ModuleList(
[
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
for kernel_size in kernel_sizes
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.stack(
[block(x, apply_residual=False) for block in self.blocks] + [x],
dim=1,
).sum(dim=1)
class ConvNeXtEncoder(nn.Module):
def __init__(
self,
input_channels=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.0,
layer_scale_init_value=1e-6,
kernel_sizes: Tuple[int] = (7,),
):
super().__init__()
assert len(depths) == len(dims)
self.channel_layers = nn.ModuleList()
stem = nn.Sequential(
ops.Conv1d(
input_channels,
dims[0],
kernel_size=7,
padding=3,
padding_mode="replicate",
),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.channel_layers.append(stem)
for i in range(len(depths) - 1):
mid_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
)
self.channel_layers.append(mid_layer)
block_fn = (
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
if len(kernel_sizes) == 1
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
)
self.stages = nn.ModuleList()
drop_path_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(len(depths)):
stage = nn.Sequential(
*[
block_fn(
dim=dims[i],
drop_path=drop_path_rates[cur + j],
layer_scale_init_value=layer_scale_init_value,
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
for channel_layer, stage in zip(self.channel_layers, self.stages):
x = channel_layer(x)
x = stage(x)
return self.norm(x)
def get_padding(kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2
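# Worked examples: get_padding(7) == 3 and get_padding(3, dilation=5) == 5, i.e. the
# "same"-length padding for the stride-1 dilated convolutions in ResBlock1 below.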
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.silu(x)
xt = c1(xt)
xt = F.silu(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for conv in self.convs1:
remove_weight_norm(conv)
for conv in self.convs2:
remove_weight_norm(conv)
class HiFiGANGenerator(nn.Module):
def __init__(
self,
*,
hop_length: int = 512,
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 128,
upsample_initial_channel: int = 512,
use_template: bool = True,
pre_conv_kernel_size: int = 7,
post_conv_kernel_size: int = 7,
post_activation: Callable = partial(nn.SiLU, inplace=True),
):
super().__init__()
assert (
prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
num_mels,
upsample_initial_channel,
pre_conv_kernel_size,
1,
padding=get_padding(pre_conv_kernel_size),
)
)
self.num_upsamples = len(upsample_rates)
self.num_kernels = len(resblock_kernel_sizes)
self.noise_convs = nn.ModuleList()
self.use_template = use_template
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
torch.nn.utils.parametrizations.weight_norm(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
if not use_template:
continue
if i + 1 < len(upsample_rates):
stride_f0 = np.prod(upsample_rates[i + 1:])
self.noise_convs.append(
ops.Conv1d(
1,
c_cur,
kernel_size=stride_f0 * 2,
stride=stride_f0,
padding=stride_f0 // 2,
)
)
else:
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation()
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
ch,
1,
post_conv_kernel_size,
1,
padding=get_padding(post_conv_kernel_size),
)
)
def forward(self, x, template=None):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.silu(x, inplace=True)
x = self.ups[i](x)
if self.use_template:
x = x + self.noise_convs[i](template)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for up in self.ups:
remove_weight_norm(up)
for block in self.resblocks:
block.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class ADaMoSHiFiGANV1(nn.Module):
def __init__(
self,
input_channels: int = 128,
depths: List[int] = [3, 3, 9, 3],
dims: List[int] = [128, 256, 384, 512],
drop_path_rate: float = 0.0,
kernel_sizes: Tuple[int] = (7,),
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 512,
upsample_initial_channel: int = 1024,
use_template: bool = False,
pre_conv_kernel_size: int = 13,
post_conv_kernel_size: int = 13,
sampling_rate: int = 44100,
n_fft: int = 2048,
win_length: int = 2048,
hop_length: int = 512,
f_min: int = 40,
f_max: int = 16000,
n_mels: int = 128,
):
super().__init__()
self.backbone = ConvNeXtEncoder(
input_channels=input_channels,
depths=depths,
dims=dims,
drop_path_rate=drop_path_rate,
kernel_sizes=kernel_sizes,
)
self.head = HiFiGANGenerator(
hop_length=hop_length,
upsample_rates=upsample_rates,
upsample_kernel_sizes=upsample_kernel_sizes,
resblock_kernel_sizes=resblock_kernel_sizes,
resblock_dilation_sizes=resblock_dilation_sizes,
num_mels=num_mels,
upsample_initial_channel=upsample_initial_channel,
use_template=use_template,
pre_conv_kernel_size=pre_conv_kernel_size,
post_conv_kernel_size=post_conv_kernel_size,
)
self.sampling_rate = sampling_rate
self.mel_transform = LogMelSpectrogram(
sample_rate=sampling_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
)
self.eval()
@torch.no_grad()
def decode(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y
@torch.no_grad()
def encode(self, x):
return self.mel_transform(x)
def forward(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
from torch import nn
from typing import Literal
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def vae_sample(mean, scale):
stdev = nn.functional.softplus(scale) + 1e-4
var = stdev * stdev
logvar = torch.log(var)
latents = torch.randn_like(mean) * stdev + mean
kl = (mean * mean + var - logvar - 1).sum(1).mean()
return latents, kl
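# Reading of vae_sample (sketch): softplus maps `scale` to a strictly positive stdev,
# latents are drawn with the reparameterization trick (latents = mean + stdev * eps,
# eps ~ N(0, I)), and `kl` is the closed-form KL(N(mean, var) || N(0, I)) summed over the
# channel dim; note the conventional 0.5 factor is not applied here.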
class VAEBottleneck(nn.Module):
def __init__(self):
super().__init__()
self.is_discrete = False
def encode(self, x, return_info=False, **kwargs):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["kl"] = kl
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def snake_beta(x, alpha, beta):
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
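# Snake-beta activation (sketch): f(x) = x + (1/beta) * sin^2(alpha * x), a periodic
# activation where alpha sets the frequency and 1/beta the magnitude of the periodic term;
# the small epsilon guards against division by zero when beta == 0.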
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
# self.alpha.requires_grad = alpha_trainable
# self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
def WNConv1d(*args, **kwargs):
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":
act = torch.nn.ELU()
elif activation == "snake":
act = SnakeBeta(channels)
elif activation == "none":
act = torch.nn.Identity()
else:
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act) # noqa: F821 Activation1d is not defined
return act
class ResidualUnit(nn.Module):
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
super().__init__()
self.dilation = dilation
padding = (dilation * (7-1)) // 2
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=7, dilation=dilation, padding=padding),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=out_channels, out_channels=out_channels,
kernel_size=1)
)
def forward(self, x):
res = x
#x = checkpoint(self.layers, x)
x = self.layers(x)
return x + res
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
super().__init__()
self.layers = nn.Sequential(
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=9, use_snake=use_snake),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
)
def forward(self, x):
return self.layers(x)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
super().__init__()
if use_nearest_upsample:
upsample_layer = nn.Sequential(
nn.Upsample(scale_factor=stride, mode="nearest"),
WNConv1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride,
stride=1,
bias=False,
padding='same')
)
else:
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
upsample_layer,
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=9, use_snake=use_snake),
)
def forward(self, x):
return self.layers(x)
class OobleckEncoder(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False
):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
]
for i in range(self.depth-1):
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class OobleckDecoder(nn.Module):
def __init__(self,
out_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=True):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
]
for i in range(self.depth-1, 0, -1):
layers += [DecoderBlock(
in_channels=c_mults[i]*channels,
out_channels=c_mults[i-1]*channels,
stride=strides[i-1],
use_snake=use_snake,
antialias_activation=antialias_activation,
use_nearest_upsample=use_nearest_upsample
)
]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
nn.Tanh() if final_tanh else nn.Identity()
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class AudioOobleckVAE(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=64,
c_mults = [1, 2, 4, 8, 16],
strides = [2, 4, 4, 8, 8],
use_snake=True,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=False):
super().__init__()
self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
self.bottleneck = VAEBottleneck()
def encode(self, x):
return self.bottleneck.encode(self.encoder(x))
def decode(self, x):
return self.decoder(self.bottleneck.decode(x))
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
from comfy.ldm.modules.attention import optimized_attention
import typing as tp
import torch
from einops import rearrange
from einops.layers.torch import Rearrange  # layer form of rearrange, usable inside nn.Sequential
from torch import nn
from torch.nn import functional as F
import math
import comfy.ops
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
super().__init__()
assert out_features % 2 == 0
self.weight = nn.Parameter(torch.empty(
[out_features // 2, in_features], dtype=dtype, device=device))
def forward(self, input):
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
return torch.cat([f.cos(), f.sin()], dim=-1)
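# Sketch: for input of shape (B, 1) and out_features=256, f has shape (B, 128) and the
# returned cat([cos, sin]) has shape (B, 256); the weight is expected to come from the
# checkpoint, hence torch.empty rather than an explicit init.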
# norms
class LayerNorm(nn.Module):
def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
"""
Bias-less LayerNorm has been shown to be more stable. Most newer models have moved towards RMSNorm, which is also bias-less.
"""
super().__init__()
self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
if bias:
self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
else:
self.beta = None
def forward(self, x):
beta = self.beta
if beta is not None:
beta = comfy.ops.cast_to_input(beta, x)
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
class GLU(nn.Module):
def __init__(
self,
dim_in,
dim_out,
activation,
use_conv = False,
conv_kernel_size = 3,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.act = activation
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
self.use_conv = use_conv
def forward(self, x):
if self.use_conv:
x = rearrange(x, 'b n d -> b d n')
x = self.proj(x)
x = rearrange(x, 'b d n -> b n d')
else:
x = self.proj(x)
x, gate = x.chunk(2, dim = -1)
return x * self.act(gate)
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.max_seq_len = max_seq_len
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
pos_emb = self.emb(pos)
pos_emb = pos_emb * self.scale
return pos_emb
class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
assert (dim % 2) == 0, 'dimension must be divisible by 2'
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
half_dim = dim // 2
freq_seq = torch.arange(half_dim).float() / half_dim
inv_freq = theta ** -freq_seq
self.register_buffer('inv_freq', inv_freq, persistent = False)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = pos - seq_start_pos[..., None]
emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
return emb * self.scale
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
use_xpos = False,
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.,
dtype=None,
device=None,
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
# inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
if not use_xpos:
self.register_buffer('scale', None)
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = scale_base
self.register_buffer('scale', scale)
def forward_from_seq_len(self, seq_len, device, dtype):
# device = self.inv_freq.device
t = torch.arange(seq_len, device=device, dtype=dtype)
return self.forward(t)
def forward(self, t):
# device = self.inv_freq.device
device = t.device
# t = t.to(torch.float32)
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None:
return freqs, 1.
seq_len = t.shape[-1]
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)
def apply_rotary_pos_emb(t, freqs, scale = 1):
out_dtype = t.dtype
# cast to float32 if necessary for numerical stability
dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
freqs, t = freqs.to(dtype), t.to(dtype)
freqs = freqs[-seq_len:, :]
if t.ndim == 4 and freqs.ndim == 3:
freqs = rearrange(freqs, 'b n d -> b 1 n d')
# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
return torch.cat((t, t_unrotated), dim = -1)
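# Sketch of the rotary application above: freqs has shape (seq_len, rot_dim) and only the
# first rot_dim channels of each head are rotated (partial rotary embeddings); the rest
# pass through unchanged. E.g. with q of shape (b, h, n, 64) and freqs of shape (n, 32),
# apply_rotary_pos_emb(q, freqs) returns (b, h, n, 64) with the first 32 dims rotated.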
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out = None,
mult = 4,
no_bias = False,
glu = True,
use_conv = False,
conv_kernel_size = 3,
zero_init_output = True,
dtype=None,
device=None,
operations=None,
):
super().__init__()
inner_dim = int(dim * mult)
# Default to SwiGLU
activation = nn.SiLU()
dim_out = dim if dim_out is None else dim_out
if glu:
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
# # init last linear layer to 0
# if zero_init_output:
# nn.init.zeros_(linear_out.weight)
# if not no_bias:
# nn.init.zeros_(linear_out.bias)
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
return self.ff(x)
class Attention(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
dim_context = None,
causal = False,
zero_init_output=True,
qk_norm = False,
natten_kernel_size = None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.causal = causal
dim_kv = dim_context if dim_context is not None else dim
self.num_heads = dim // dim_heads
self.kv_heads = dim_kv // dim_heads
if dim_context is not None:
self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
else:
self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
# if zero_init_output:
# nn.init.zeros_(self.to_out.weight)
self.qk_norm = qk_norm
def forward(
self,
x,
context = None,
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
kv_input = context if has_context else x
if hasattr(self, 'to_q'):
# Use separate linear projections for q and k/v
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
else:
# Use fused linear projection
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# Normalize q and k for cosine sim attention
if self.qk_norm:
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
if rotary_pos_emb is not None and not has_context:
freqs, _ = rotary_pos_emb
q_dtype = q.dtype
k_dtype = k.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
freqs = freqs.to(torch.float32)
q = apply_rotary_pos_emb(q, freqs)
k = apply_rotary_pos_emb(k, freqs)
q = q.to(q_dtype)
k = k.to(k_dtype)
input_mask = context_mask
if input_mask is None and not has_context:
input_mask = mask
# determine masking
masks = []
if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask)
# Other masks will be added here later
n = q.shape[-2]
causal = self.causal if causal is None else causal
if n == 1 and causal:
causal = False
if h != kv_h:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
out = optimized_attention(q, k, v, h, skip_reshape=True)
out = self.to_out(out)
if mask is not None:
mask = rearrange(mask, 'b n -> b n 1')
out = out.masked_fill(~mask, 0.)
return out
class ConformerModule(nn.Module):
def __init__(
self,
dim,
norm_kwargs = {},
):
super().__init__()
self.dim = dim
self.in_norm = LayerNorm(dim, **norm_kwargs)
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
self.glu = GLU(dim, dim, nn.SiLU())
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
self.swish = nn.SiLU()
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
x = self.in_norm(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.glu(x)
x = rearrange(x, 'b n d -> b d n')
x = self.depthwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.mid_norm(x)
x = self.swish(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv_2(x)
x = rearrange(x, 'b d n -> b n d')
return x
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
cross_attend = False,
dim_context = None,
global_cond_dim = None,
causal = False,
zero_init_branch_outputs = True,
conformer = False,
layer_ix = -1,
remove_norms = False,
attn_kwargs = {},
ff_kwargs = {},
norm_kwargs = {},
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.cross_attend = cross_attend
self.dim_context = dim_context
self.causal = causal
self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
self.self_attn = Attention(
dim,
dim_heads = dim_heads,
causal = causal,
zero_init_output=zero_init_branch_outputs,
dtype=dtype,
device=device,
operations=operations,
**attn_kwargs
)
if cross_attend:
self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
self.cross_attn = Attention(
dim,
dim_heads = dim_heads,
dim_context=dim_context,
causal = causal,
zero_init_output=zero_init_branch_outputs,
dtype=dtype,
device=device,
operations=operations,
**attn_kwargs
)
self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
self.layer_ix = layer_ix
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
self.global_cond_dim = global_cond_dim
if global_cond_dim is not None:
self.to_scale_shift_gate = nn.Sequential(
nn.SiLU(),
nn.Linear(global_cond_dim, dim * 6, bias=False)
)
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
#nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
def forward(
self,
x,
context = None,
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
# self-attention with adaLN
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
# feedforward with adaLN
residual = x
x = self.ff_norm(x)
x = x * (1 + scale_ff) + shift_ff
x = self.ff(x)
x = x * torch.sigmoid(1 - gate_ff)
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
x = x + self.ff(self.ff_norm(x))
return x
class ContinuousTransformer(nn.Module):
def __init__(
self,
dim,
depth,
*,
dim_in = None,
dim_out = None,
dim_heads = 64,
cross_attend=False,
cond_token_dim=None,
global_cond_dim=None,
causal=False,
rotary_pos_emb=True,
zero_init_branch_outputs=True,
conformer=False,
use_sinusoidal_emb=False,
use_abs_pos_emb=False,
abs_pos_emb_max_length=10000,
dtype=None,
device=None,
operations=None,
**kwargs
):
super().__init__()
self.dim = dim
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
else:
self.rotary_pos_emb = None
self.use_sinusoidal_emb = use_sinusoidal_emb
if use_sinusoidal_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim)
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
for i in range(depth):
self.layers.append(
TransformerBlock(
dim,
dim_heads = dim_heads,
cross_attend = cross_attend,
dim_context = cond_token_dim,
global_cond_dim = global_cond_dim,
causal = causal,
zero_init_branch_outputs = zero_init_branch_outputs,
conformer=conformer,
layer_ix=i,
dtype=dtype,
device=device,
operations=operations,
**kwargs
)
)
def forward(
self,
x,
mask = None,
prepend_embeds = None,
prepend_mask = None,
global_cond = None,
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
info = {
"hidden_states": [],
}
x = self.project_in(x)
if prepend_embeds is not None:
prepend_length, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
x = torch.cat((prepend_embeds, x), dim = -2)
if prepend_mask is not None or mask is not None:
mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
mask = torch.cat((prepend_mask, mask), dim = -1)
# Attention layers
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
else:
rotary_pos_emb = None
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers
for i, layer in enumerate(self.layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
info["hidden_states"].append(x)
x = self.project_out(x)
if return_info:
return x, info
return x
class AudioDiffusionTransformer(nn.Module):
def __init__(self,
io_channels=64,
patch_size=1,
embed_dim=1536,
cond_token_dim=768,
project_cond_tokens=False,
global_cond_dim=1536,
project_global_cond=True,
input_concat_dim=0,
prepend_cond_dim=0,
depth=24,
num_heads=24,
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
audio_model="",
dtype=None,
device=None,
operations=None,
**kwargs):
super().__init__()
self.dtype = dtype
self.cond_token_dim = cond_token_dim
# Timestep embeddings
timestep_features_dim = 256
self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
self.to_timestep_embed = nn.Sequential(
operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
)
if cond_token_dim > 0:
# Conditioning tokens
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
self.to_cond_embed = nn.Sequential(
operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
)
else:
cond_embed_dim = 0
if global_cond_dim > 0:
# Global conditioning
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
self.to_global_embed = nn.Sequential(
operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
)
if prepend_cond_dim > 0:
# Prepend conditioning
self.to_prepend_embed = nn.Sequential(
operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
)
self.input_concat_dim = input_concat_dim
dim_in = io_channels + self.input_concat_dim
self.patch_size = patch_size
# Transformer
self.transformer_type = transformer_type
self.global_cond_type = global_cond_type
if self.transformer_type == "continuous_transformer":
global_dim = None
if self.global_cond_type == "adaLN":
# The global conditioning is projected to the embed_dim already at this point
global_dim = embed_dim
self.transformer = ContinuousTransformer(
dim=embed_dim,
depth=depth,
dim_heads=embed_dim // num_heads,
dim_in=dim_in * patch_size,
dim_out=io_channels * patch_size,
cross_attend = cond_token_dim > 0,
cond_token_dim = cond_embed_dim,
global_cond_dim=global_dim,
dtype=dtype,
device=device,
operations=operations,
**kwargs
)
else:
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
def _forward(
self,
x,
t,
mask=None,
cross_attn_cond=None,
cross_attn_cond_mask=None,
input_concat_cond=None,
global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
return_info=False,
**kwargs):
if cross_attn_cond is not None:
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
if global_embed is not None:
# Project the global conditioning to the embedding dimension
global_embed = self.to_global_embed(global_embed)
prepend_inputs = None
prepend_mask = None
prepend_length = 0
if prepend_cond is not None:
# Project the prepend conditioning to the embedding dimension
prepend_cond = self.to_prepend_embed(prepend_cond)
prepend_inputs = prepend_cond
if prepend_cond_mask is not None:
prepend_mask = prepend_cond_mask
if input_concat_cond is not None:
# Interpolate input_concat_cond to the same length as x
if input_concat_cond.shape[2] != x.shape[2]:
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
x = torch.cat([x, input_concat_cond], dim=1)
# Get the batch of timestep embeddings
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
if global_embed is not None:
global_embed = global_embed + timestep_embed
else:
global_embed = timestep_embed
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
if self.global_cond_type == "prepend":
if prepend_inputs is None:
# Prepend inputs are just the global embed, and the mask is all ones
prepend_inputs = global_embed.unsqueeze(1)
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
else:
# Prepend inputs are the prepend conditioning + the global embed
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
prepend_length = prepend_inputs.shape[1]
x = self.preprocess_conv(x) + x
x = rearrange(x, "b c t -> b t c")
extra_args = {}
if self.global_cond_type == "adaLN":
extra_args["global_cond"] = global_embed
if self.patch_size > 1:
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
if self.transformer_type == "x-transformers":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
elif self.transformer_type == "continuous_transformer":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
if return_info:
output, info = output
elif self.transformer_type == "mm_transformer":
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
if self.patch_size > 1:
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
output = self.postprocess_conv(output) + output
if return_info:
return output, info
return output
def forward(
self,
x,
timestep,
context=None,
context_mask=None,
input_concat_cond=None,
global_embed=None,
negative_global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
mask=None,
return_info=False,
control=None,
**kwargs):
return self._forward(
x,
timestep,
cross_attn_cond=context,
cross_attn_cond_mask=context_mask,
input_concat_cond=input_concat_cond,
global_embed=global_embed,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
mask=mask,
return_info=return_info,
**kwargs
)
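# Flow sketch for AudioDiffusionTransformer (as implemented above): x is a latent of shape
# (B, io_channels, T); the timestep embedding is treated as a global embedding and, with
# global_cond_type="prepend" (the default), is prepended as an extra token before the
# transformer and stripped from the output again via prepend_length.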
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Union
from einops import rearrange
import math
import comfy.ops
class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.empty(half_dim))
def forward(self, x: Tensor) -> Tensor:
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential(
LearnedPositionalEmbedding(dim),
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
)
class NumberEmbedder(nn.Module):
def __init__(
self,
features: int,
dim: int = 256,
):
super().__init__()
self.features = features
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
if not torch.is_tensor(x):
device = next(self.embedding.parameters()).device
x = torch.tensor(x, device=device)
assert isinstance(x, Tensor)
shape = x.shape
x = rearrange(x, "... -> (...)")
embedding = self.embedding(x)
x = embedding.view(*shape, self.features)
return x # type: ignore
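# Sketch: NumberEmbedder maps arbitrarily shaped float input to Fourier-style embeddings,
# e.g. NumberEmbedder(features=768)([0.25, 0.75]) returns a tensor of shape (2, 768).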
class Conditioner(nn.Module):
def __init__(
self,
dim: int,
output_dim: int,
project_out: bool = False
):
super().__init__()
self.dim = dim
self.output_dim = output_dim
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
def forward(self, x):
raise NotImplementedError()
class NumberConditioner(Conditioner):
'''
Conditioner that takes a list of floats, normalizes them to a given range, and returns a list of embeddings.
'''
def __init__(self,
output_dim: int,
min_val: float=0,
max_val: float=1
):
super().__init__(output_dim, output_dim)
self.min_val = min_val
self.max_val = max_val
self.embedder = NumberEmbedder(features=output_dim)
def forward(self, floats, device=None):
# Cast the inputs to floats
floats = [float(x) for x in floats]
if device is None:
device = next(self.embedder.parameters()).device
floats = torch.tensor(floats).to(device)
floats = floats.clamp(self.min_val, self.max_val)
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
# Cast floats to same type as embedder
embedder_dtype = next(self.embedder.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
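# Usage sketch (values illustrative): a seconds-style conditioner such as
# NumberConditioner(output_dim=768, min_val=0, max_val=512) called on [30.0, 120.0]
# returns [embeddings of shape (2, 1, 768), mask of shape (2, 1)].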
# AuraFlow MMDiT
# Originally written by the AuraFlow Authors
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.patcher_extension
import comfy.ldm.common_dit
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
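# adaLN modulation (sketch): x is (B, T, D) and shift/scale are (B, D); unsqueeze(1)
# broadcasts them over the sequence dim, so the result equals
# x * (1 + scale[:, None, :]) + shift[:, None, :].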
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
n_hidden = int(2 * hidden_dim / 3)
n_hidden = find_multiple(n_hidden, 256)
self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
class MultiHeadLayerNorm(nn.Module):
def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
# Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
super().__init__()
self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(
variance + self.variance_epsilon
)
hidden_states = self.weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
class SingleAttention(nn.Module):
def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
# this is for cond
self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.q_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
#@torch.compile()
def forward(self, c):
bsz, seqlen1, _ = c.shape
q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c = self.w1o(output)
return c
class DoubleAttention(nn.Module):
def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
# this is for cond
self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
# this is for x
self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.q_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.q_norm2 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm2 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
#@torch.compile()
def forward(self, c, x):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
cq, ck = self.q_norm1(cq), self.k_norm1(ck)
xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
xq, xk = self.q_norm2(xq), self.k_norm2(xk)
# concat all
q, k, v = (
torch.cat([cq, xq], dim=1),
torch.cat([ck, xk], dim=1),
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
x = self.w2o(x)
return c, x
class MMDiTBlock(nn.Module):
def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
super().__init__()
self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
if not is_last:
self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
self.modC = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
else:
self.modC = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
)
self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
self.modX = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
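# adaLN-style conditioning: global_cond is projected (modC / modX) into shift/scale/gate
# triplets for the attention and MLP sub-blocks of each stream; the gates scale the residual updates.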
cres, xres = c, x
cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
self.modC(global_cond).chunk(6, dim=1)
)
c = modulate(self.normC1(c), cshift_msa, cscale_msa)
# x path
xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
self.modX(global_cond).chunk(6, dim=1)
)
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
c = cres + c
x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
x = xres + x
return c, x
class DiTBlock(nn.Module):
# like MMDiTBlock, but it only has X
def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.modCX = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
cx = cxres + cx
return cx
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = 1000 * torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
).to(t.device)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
#@torch.compile()
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
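# Hedged sketch (not part of the original file): the static sinusoidal embedding can be
# evaluated standalone; it concatenates the cos and sin halves along the last dim.
def _timestep_embedding_demo():
    t = torch.tensor([0.0, 0.5, 1.0])
    emb = TimestepEmbedder.timestep_embedding(t, dim=256)
    assert emb.shape == (3, 256)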
class MMDiT(nn.Module):
def __init__(
self,
in_channels=4,
out_channels=4,
patch_size=2,
dim=3072,
n_layers=36,
n_double_layers=4,
n_heads=12,
global_conddim=3072,
cond_seq_dim=2048,
max_seq=32 * 32,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)
self.cond_seq_linear = operations.Linear(
cond_seq_dim, dim, bias=False, dtype=dtype, device=device
) # projection for the conditioning (e.g. text) sequence.
self.init_x_linear = operations.Linear(
patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
) # init linear for patchified image.
self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))
self.double_layers = nn.ModuleList([])
self.single_layers = nn.ModuleList([])
for idx in range(n_double_layers):
self.double_layers.append(
MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
)
for idx in range(n_double_layers, n_layers):
self.single_layers.append(
DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
)
self.final_linear = operations.Linear(
dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
)
self.modF = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
)
self.out_channels = out_channels
self.patch_size = patch_size
self.n_double_layers = n_double_layers
self.n_layers = n_layers
self.h_max = round(max_seq**0.5)
self.w_max = round(max_seq**0.5)
@torch.no_grad()
def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
# extend pe
pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]
pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
# extend to target_dim with bilinear interpolation (F.interpolate)
pe_as_2d = F.interpolate(
pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
)
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
self.h_max, self.w_max = target_dim
def pe_selection_index_based_on_dim(self, h, w):
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
starth = self.h_max // 2 - h_p // 2
endh = starth + h_p
startw = self.w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[
starth:endh, startw:endw
]
return original_pe_indexes.flatten()
def unpatchify(self, x, h, w):
c = self.out_channels
p = self.patch_size
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def patchify(self, x):
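# Pads H and W up to multiples of patch_size, then folds each p x p patch into the
# channel dim, producing tokens of shape (B, H/p * W/p, C * p * p).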
B, C, H, W = x.size()
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
x = x.view(
B,
C,
(H + 1) // self.patch_size,
self.patch_size,
(W + 1) // self.patch_size,
self.patch_size,
)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
return x
def apply_pos_embeds(self, x, h, w):
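# The learned positional grid (h_max x w_max) is bilinearly upsampled if the current
# token grid is larger, then center-cropped to (h/p, w/p) and added to the tokens.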
h = (h + 1) // self.patch_size
w = (w + 1) // self.patch_size
max_dim = max(h, w)
cur_dim = self.h_max
pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
if max_dim > cur_dim:
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
cur_dim = max_dim
from_h = (cur_dim - h) // 2
from_w = (cur_dim - w) // 2
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, transformer_options, **kwargs)
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
# pe_indexes = self.pe_selection_index_based_on_dim(h, w)
# print(pe_indexes, pe_indexes.shape)
x = self.init_x_linear(self.patchify(x)) # B, T_x, D
x = self.apply_pos_embeds(x, h, w)
# x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
# x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)
# process conditions for MMDiT Blocks
c_seq = context # B, T_c, D_c
t = timestep
c = self.cond_seq_linear(c_seq) # B, T_c, D
c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
global_cond = self.t_embedder(t, x.dtype) # B, D
blocks_replace = patches_replace.get("dit", {})
if len(self.double_layers) > 0:
for i, layer in enumerate(self.double_layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
cx = torch.cat([c, x], dim=1)
for i, layer in enumerate(self.single_layers):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:]
fshift, fscale = self.modF(global_cond).chunk(2, dim=1)
x = modulate(x, fshift, fscale)
x = self.final_linear(x)
x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
return x
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
class Attention2D(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
def LayerNorm2d_op(operations):
class LayerNorm2d(operations.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return LayerNorm2d
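# Hedged usage sketch (not part of the original file): LayerNorm2d normalizes the channel
# dim of an NCHW tensor by permuting to NHWC and back. comfy.ops.disable_weight_init is
# used here only to mirror how the blocks below construct it.
def _layer_norm_2d_demo():
    import comfy.ops
    ln = LayerNorm2d_op(comfy.ops.disable_weight_init)(8, elementwise_affine=False, eps=1e-6)
    x = torch.randn(1, 8, 4, 4)
    assert ln(x).shape == x.shape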
class GlobalResponseNorm(nn.Module):
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
def __init__(self, dim, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
super().__init__()
self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
# self.depthwise = SAMBlock(c, num_heads, expansion)
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.channelwise = nn.Sequential(
operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
nn.GELU(),
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
nn.Dropout(dropout),
operations.Linear(c * 4, c, dtype=dtype, device=device)
)
def forward(self, x, x_skip=None):
x_res = x
x = self.norm(self.depthwise(x))
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x + x_res
class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = self_attn
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
self.kv_mapper = nn.Sequential(
nn.SiLU(),
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
return x
class FeedForwardBlock(nn.Module):
def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.channelwise = nn.Sequential(
operations.Linear(c, c * 4, dtype=dtype, device=device),
nn.GELU(),
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
nn.Dropout(dropout),
operations.Linear(c * 4, c, dtype=dtype, device=device)
)
def forward(self, x):
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
super().__init__()
self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
def forward(self, x, t):
t = t.chunk(len(self.conds) + 1, dim=1)
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
for i, c in enumerate(self.conds):
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc
return x * (1 + a) + b
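# Hedged usage sketch (not part of the original file): TimestepBlock expects t to hold one
# c_timestep-sized chunk per condition plus the base chunk; each chunk is mapped to an
# (a, b) pair and applied FiLM-style as x * (1 + a) + b. Weights created through
# comfy.ops.disable_weight_init are uninitialized, so only shapes are meaningful here.
def _timestep_block_demo():
    import comfy.ops
    block = TimestepBlock(c=8, c_timestep=4, conds=['sca'],
                          operations=comfy.ops.disable_weight_init)
    x = torch.randn(1, 8, 16, 16)
    t = torch.randn(1, 4 * 2)    # base chunk + one 'sca' chunk
    assert block(x, t).shape == x.shape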
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torchvision
from torch import nn
from .common import LayerNorm2d_op
class CNetResBlock(nn.Module):
def __init__(self, c, dtype=None, device=None, operations=None):
super().__init__()
self.blocks = nn.Sequential(
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c, c, kernel_size=3, padding=1),
LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c, c, kernel_size=3, padding=1),
)
def forward(self, x):
return x + self.blocks(x)
class ControlNet(nn.Module):
def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
super().__init__()
if bottleneck_mode is None:
bottleneck_mode = 'effnet'
self.proj_blocks = proj_blocks
if bottleneck_mode == 'effnet':
embd_channels = 1280
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
if c_in != 3:
in_weights = self.backbone[0][0].weight.data
self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
if c_in > 3:
# nn.init.constant_(self.backbone[0][0].weight, 0)
self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
else:
self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
elif bottleneck_mode == 'simple':
embd_channels = c_in
self.backbone = nn.Sequential(
operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
)
elif bottleneck_mode == 'large':
self.backbone = nn.Sequential(
operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
*[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
)
embd_channels = 1280
else:
raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
self.projections = nn.ModuleList()
for _ in range(len(proj_blocks)):
self.projections.append(nn.Sequential(
operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
nn.LeakyReLU(0.2, inplace=True),
operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
))
# nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
self.xl = False
self.input_channels = c_in
self.unshuffle_amount = 8
def forward(self, x):
x = self.backbone(x)
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
for i, idx in enumerate(self.proj_blocks):
proj_outputs[idx] = self.projections[i](x)
return {"input": proj_outputs[::-1]}
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from torch import nn
from torch.autograd import Function
import comfy.ops
ops = comfy.ops.disable_weight_init
class vector_quantize(Function):
@staticmethod
def forward(ctx, x, codebook):
with torch.no_grad():
codebook_sqr = torch.sum(codebook ** 2, dim=1)
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
_, indices = dist.min(dim=1)
ctx.save_for_backward(indices, codebook)
ctx.mark_non_differentiable(indices)
nn = torch.index_select(codebook, 0, indices)
return nn, indices
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
grad_inputs = grad_output.clone()
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output)
return (grad_inputs, grad_codebook)
class VectorQuantize(nn.Module):
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
"""
Takes an input of any shape whose last dimension matches the embedding size.
Returns the nearest-neighbour codebook embeddings for each input vector (same shape
as the input), the VQ and commitment components of the loss as a tuple in the
second output, and the indices of the selected codebook vectors in the third:
quantized, (vq_loss, commit_loss), indices
"""
super(VectorQuantize, self).__init__()
self.codebook = nn.Embedding(k, embedding_size)
self.codebook.weight.data.uniform_(-1./k, 1./k)
self.vq = vector_quantize.apply
self.ema_decay = ema_decay
self.ema_loss = ema_loss
if ema_loss:
self.register_buffer('ema_element_count', torch.ones(k))
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
def _laplace_smoothing(self, x, epsilon):
n = torch.sum(x)
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices):
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx)
if dim != -1:
q_idx = q_idx.movedim(-1, dim)
return q_idx
def forward(self, x, get_losses=True, dim=-1):
if dim != -1:
x = x.movedim(dim, -1)
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
vq_loss, commit_loss = None, None
if self.ema_loss and self.training:
self._updateEMA(z_e_x.detach(), indices.detach())
# re-select the embeddings from the live (gradient-carrying) codebook after any codebook update, for a more accurate commitment loss
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
if get_losses:
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
z_q_x = z_q_x.view(x.shape)
if dim != -1:
z_q_x = z_q_x.movedim(-1, dim)
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
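# Hedged usage sketch (not part of the original file): quantize a batch of vectors against
# the codebook; the codebook is a regular nn.Embedding, so this runs as-is.
def _vector_quantize_demo():
    vq = VectorQuantize(embedding_size=4, k=16)
    x = torch.randn(2, 8, 4)                    # trailing dim == embedding_size
    quantized, (vq_loss, commit_loss), indices = vq(x)
    assert quantized.shape == x.shape
    assert indices.shape == (2, 8)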
class ResBlock(nn.Module):
def __init__(self, c, c_hidden):
super().__init__()
# depthwise/attention
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1),
ops.Conv2d(c, c, kernel_size=3, groups=c)
)
# channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
ops.Linear(c, c_hidden),
nn.GELU(),
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
def forward(self, x):
mods = self.gammas
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
try:
x = x + self.depthwise(x_temp) * mods[2]
except Exception:  # the replication pad is not implemented for bf16; run it in float32
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
x = x + self.depthwise[1](x_temp) * mods[2]
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x
class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
super().__init__()
self.c_latent = c_latent
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(2),
ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
)
down_blocks = []
for i in range(levels):
if i > 0:
down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(nn.Sequential(
ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
))
self.down_blocks = nn.Sequential(*down_blocks)
self.codebook_size = codebook_size
self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
# Decoder blocks
up_blocks = [nn.Sequential(
ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1))
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2),
)
def encode(self, x, quantize=False):
x = self.in_block(x)
x = self.down_blocks(x)
if quantize:
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
return qe, x, indices, vq_loss + commit_loss * 0.25
else:
return x
def decode(self, x):
x = self.up_blocks(x)
x = self.out_block(x)
return x
def forward(self, x, quantize=False):
qe, x, _, vq_loss = self.encode(x, quantize)
x = self.decode(qe)
return x, vq_loss
class Discriminator(nn.Module):
def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
super().__init__()
d = max(depth - 3, 3)
layers = [
nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):
x = self.encoder(x)
if cond is not None:
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
x = torch.cat([x, cond], dim=1)
x = self.shuffle(x)
x = self.logits(x)
return x
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import math
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
class StageB(nn.Module):
def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
super().__init__()
self.dtype = dtype
self.c_r = c_r
self.t_conds = t_conds
self.c_clip_seq = c_clip_seq
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
if not isinstance(self_attn, list):
self_attn = [self_attn] * len(c_hidden)
# CONDITIONING
self.effnet_mapper = nn.Sequential(
operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
self.pixels_mapper = nn.Sequential(
operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
nn.GELU(),
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == 'C':
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'A':
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'F':
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'T':
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
else:
raise Exception(f'Block type {block_type} not supported')
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(c_hidden)):
if i > 0:
self.down_downscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
))
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(blocks[0][i]):
for block_type in level_config[i]:
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
down_block.append(block)
self.down_blocks.append(down_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[0][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
if i > 0:
self.up_upscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
))
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[1][::-1][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
# self.apply(self._init_weights) # General init
# nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
# nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
# nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
# elif isinstance(block, TimestepBlock):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
def gen_r_embedding(self, r, max_positions=10000):
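# Sinusoidal embedding of the scalar noise level r (scaled by max_positions), using the
# standard transformer frequency schedule; zero-padded when c_r is odd.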
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def gen_c_embeddings(self, clip):
if len(clip.shape) == 2:
clip = clip.unsqueeze(1)
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
align_corners=True)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
t_cond = kwargs.get(c, torch.zeros_like(r))
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
clip = self.gen_c_embeddings(clip)
# Model Blocks
x = self.embedding(x)
x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from torch import nn
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer
class UpDownBlock2d(nn.Module):
def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
super().__init__()
assert mode in ['up', 'down']
interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
align_corners=True) if enabled else nn.Identity()
mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class StageC(nn.Module):
def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
dtype=None, device=None, operations=None):
super().__init__()
self.dtype = dtype
self.c_r = c_r
self.t_conds = t_conds
self.c_clip_seq = c_clip_seq
if not isinstance(dropout, list):
dropout = [dropout] * len(c_hidden)
if not isinstance(self_attn, list):
self_attn = [self_attn] * len(c_hidden)
# CONDITIONING
self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
)
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == 'C':
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'A':
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'F':
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
elif block_type == 'T':
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
else:
raise Exception(f'Block type {block_type} not supported')
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(c_hidden)):
if i > 0:
self.down_downscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
))
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(blocks[0][i]):
for block_type in level_config[i]:
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
down_block.append(block)
self.down_blocks.append(down_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[0][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(c_hidden))):
if i > 0:
self.up_upscalers.append(nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
))
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(block_repeat[1][::-1][i] - 1):
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
nn.PixelShuffle(patch_size),
)
# --- WEIGHT INIT ---
# self.apply(self._init_weights) # General init
# nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
# nn.init.constant_(self.clf[1].weight, 0) # outputs
#
# # blocks
# for level_block in self.down_blocks + self.up_blocks:
# for block in level_block:
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
# elif isinstance(block, TimestepBlock):
# for layer in block.modules():
# if isinstance(layer, nn.Linear):
# nn.init.constant_(layer.weight, 0)
#
# def _init_weights(self, m):
# if isinstance(m, (nn.Conv2d, nn.Linear)):
# torch.nn.init.xavier_uniform_(m.weight)
# if m.bias is not None:
# nn.init.constant_(m.bias, 0)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
clip_txt = self.clip_txt_mapper(clip_txt)
if len(clip_txt_pooled.shape) == 2:
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
if len(clip_img.shape) == 2:
clip_img = clip_img.unsqueeze(1)
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
if cnet is not None:
next_cnet = cnet.pop()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
align_corners=True).to(x.dtype)
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, ResBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
ResBlock)):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
align_corners=True)
if cnet is not None:
next_cnet = cnet.pop()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
align_corners=True).to(x.dtype)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
t_cond = kwargs.get(c, torch.zeros_like(r))
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
if control is not None:
cnet = control.get("input")
else:
cnet = None
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
import comfy.ops
ops = comfy.ops.disable_weight_init
# EfficientNet
class EfficientNetEncoder(nn.Module):
def __init__(self, c_latent=16):
super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential(
ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
)
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
def forward(self, x):
x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
o = self.mapper(self.backbone(x))
return o
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):
return (self.blocks(x) - 0.5) * 2.0
class StageC_coder(nn.Module):
def __init__(self):
super().__init__()
self.previewer = Previewer()
self.encoder = EfficientNetEncoder()
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.previewer(x)
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
class ChromaModulationOut(ModulationOut):
@classmethod
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
return cls(
shift=tensor[:, offset : offset + 1, :],
scale=tensor[:, offset + 1 : offset + 2, :],
gate=tensor[:, offset + 2 : offset + 3, :],
)
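# Hedged sketch (not part of the original file): from_offset slices three consecutive rows
# out of a stacked modulation tensor of shape (B, N, D) into shift / scale / gate, each (B, 1, D).
def _chroma_modulation_demo():
    mod_bank = torch.randn(1, 6, 8)
    m = ChromaModulationOut.from_offset(mod_bank, offset=3)
    assert m.shift.shape == m.scale.shape == m.gate.shape == (1, 1, 8)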
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def forward(self, x: Tensor) -> Tensor:
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
x = self.out_proj(x)
return x
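# Hedged shape sketch (not part of the original file): the Approximator is a stack of
# residual MLPEmbedder blocks with RMSNorm that maps conditioning vectors into the
# modulation space. comfy.ops.disable_weight_init leaves weights uninitialized, so only
# the output shape is meaningful here.
def _approximator_shape_demo():
    import comfy.ops
    approx = Approximator(in_dim=16, out_dim=32, hidden_dim=64, n_layers=2,
                          operations=comfy.ops.disable_weight_init)
    vec = torch.randn(1, 10, 16)
    assert approx(vec).shape == (1, 10, 32)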
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt blocks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
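# Shape sketch for the fused projection above (assumed hidden_size=3072 and
# mlp_ratio=4.0, so mlp_hidden_dim=12288; these values are illustrative only):
#   >>> fused = torch.randn(1, 256, 3 * 3072 + 12288)
#   >>> qkv, mlp = torch.split(fused, [3 * 3072, 12288], dim=-1)
#   >>> qkv.shape, mlp.shape
#   (torch.Size([1, 256, 9216]), torch.Size([1, 256, 12288]))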
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
x = self.linear(x)
return x
# Original code can be found at: https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@dataclass
class ChromaParams:
in_channels: int
out_channels: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
in_dim: int
out_dim: int
hidden_dim: int
n_layers: int
class Chroma(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = ChromaParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.in_dim = params.in_dim
self.out_dim = params.out_dim
self.hidden_dim = params.hidden_dim
self.n_layers = params.n_layers
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
# approximator that maps the concatenated timestep/guidance/modulation-index embeddings to the per-block modulation vectors
self.distilled_guidance_layer = Approximator(
in_dim=self.in_dim,
hidden_dim=self.hidden_dim,
out_dim=self.out_dim,
n_layers=self.n_layers,
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
self.skip_mmdit = []
self.skip_dit = []
self.lite = False
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
# This function slices up the modulations tensor which has the following layout:
# single : num_single_blocks * 3 elements
# double_img : num_double_blocks * 6 elements
# double_txt : num_double_blocks * 6 elements
# final : 2 elements
if block_type == "final":
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
single_block_count = self.params.depth_single_blocks
double_block_count = self.params.depth
offset = 3 * idx
if block_type == "single":
return ChromaModulationOut.from_offset(tensor, offset)
# Double block modulations are 6 elements so we double 3 * idx.
offset *= 2
if block_type in {"double_img", "double_txt"}:
# Advance past the single block modulations.
offset += 3 * single_block_count
if block_type == "double_txt":
# Advance past the double block img modulations.
offset += 6 * double_block_count
return (
ChromaModulationOut.from_offset(tensor, offset),
ChromaModulationOut.from_offset(tensor, offset + 3),
)
raise ValueError("Bad block_type")
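# Worked example of the layout described above, assuming the stock Chroma
# configuration of 19 double blocks and 38 single blocks (which is what
# mod_index_length = 344 in forward_orig corresponds to):
#   single     : 38 * 3 = 114 vectors, ("single", idx)     -> offset 3 * idx
#   double_img : 19 * 6 = 114 vectors, ("double_img", idx) -> offset 114 + 6 * idx
#   double_txt : 19 * 6 = 114 vectors, ("double_txt", idx) -> offset 228 + 6 * idx
#   final      : the last 2 vectors
#   total      : 114 + 114 + 114 + 2 = 344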
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
guidance: Tensor = None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
# distilled vector guidance
mod_index_length = 344
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
# get all modulation indices
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
# broadcast the modulation index so every batch element gets the full set of indices
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
# broadcast the timestep and guidance embeddings along the modulation index dimension as well
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.device, img.dtype)
# only then can the two parts be concatenated into the approximator input
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
mod_vectors = self.distilled_guidance_layer(input_vec)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if i not in self.skip_mmdit:
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
self.get_modulations(mod_vectors, "double_txt", idx=i),
)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=double_mod,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
final_mod = self.get_modulations(mod_vectors, "final")
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
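# A minimal shape sketch of the patchify step in _forward above (assumed values:
# a 16-channel latent of spatial size 64x64 with patch_size=2):
#   >>> x = torch.randn(1, 16, 64, 64)
#   >>> img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
#   >>> img.shape   # 32 * 32 = 1024 tokens, each of dim 16 * 2 * 2 = 64
#   torch.Size([1, 1024, 64])
# forward_orig returns tokens in the same layout, which the final rearrange
# unpatchifies back to (1, out_channels, 64, 64) before cropping to the original h, w.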
import torch
import comfy.rmsnorm
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad = ()
for i in range(img.ndim - 2):
pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
return torch.nn.functional.pad(img, pad, mode=padding_mode)
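# A worked example of the padding arithmetic above (assumed input shape):
#   >>> img = torch.randn(1, 3, 30, 31)
#   >>> pad_to_patch_size(img, patch_size=(2, 2)).shape   # width 31 -> 32, height 30 already even
#   torch.Size([1, 3, 30, 32])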
rms_norm = comfy.rmsnorm.rms_norm
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import logging
import numpy as np
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from comfy.ldm.modules.attention import optimized_attention
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()
elif name == "R":
return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else:
raise ValueError(f"Normalization {name} not found")
class BaseAttentionOp(nn.Module):
def __init__(self):
super().__init__()
class Attention(nn.Module):
"""
Generalized attention implementation.
Allows both self-attention and cross-attention configurations, depending on whether a `context_dim` is provided.
If `context_dim` is None, self-attention is assumed.
Parameters:
query_dim (int): Dimension of each query vector.
context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
heads (int, optional): Number of attention heads. Defaults to 8.
dim_head (int, optional): Dimension of each head. Defaults to 64.
dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
Defaults to "SSI".
qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
Defaults to 'per_head'. Only support 'per_head'.
Examples:
>>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
>>> query = torch.randn(10, 128) # Batch size of 10
>>> context = torch.randn(10, 256) # Batch size of 10
>>> output = attn(query, context) # Perform the attention operation
Note:
https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
"""
def __init__(
self,
query_dim: int,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_op: Optional[BaseAttentionOp] = None,
qkv_bias: bool = False,
out_bias: bool = False,
qkv_norm: str = "SSI",
qkv_norm_mode: str = "per_head",
backend: str = "transformer_engine",
qkv_format: str = "bshd",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.is_selfattn = context_dim is None # self attention
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.heads = heads
self.dim_head = dim_head
self.qkv_norm_mode = qkv_norm_mode
self.qkv_format = qkv_format
if self.qkv_norm_mode == "per_head":
norm_dim = dim_head
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
self.backend = backend
self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
nn.Dropout(dropout),
)
def cal_qkv(
self, x, context=None, mask=None, rope_emb=None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
del kwargs
"""
self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
Before 07/24/2024, these modules normalize across all heads.
After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
we support normalizing per head.
To keep checkpoint compatibility with the previous code,
we keep the nn.Sequential but call the projection and the normalization layers separately.
We use a flag `self.qkv_norm_mode` to control the normalization behavior.
The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
"""
if self.qkv_norm_mode == "per_head":
q = self.to_q[0](x)
context = x if context is None else context
k = self.to_k[0](context)
v = self.to_v[0](context)
q, k, v = map(
lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
(q, k, v),
)
else:
raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")
q = self.to_q[1](q)
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
# apply_rotary_pos_emb inlined
q_shape = q.shape
q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
# apply_rotary_pos_emb inlined
k_shape = k.shape
k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
return q, k, v
def forward(
self,
x,
context=None,
mask=None,
rope_emb=None,
**kwargs,
):
"""
Args:
x (Tensor): The query tensor of shape [B, Mq, K]
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
class FeedForward(nn.Module):
"""
Transformer FFN with optional gating
Parameters:
d_model (int): Dimensionality of input features.
d_ff (int): Dimensionality of the hidden layer.
dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1.
activation (callable, optional): The activation function applied after the first linear layer.
Defaults to nn.ReLU().
is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer.
Defaults to False.
bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True.
Example:
>>> ff = FeedForward(d_model=512, d_ff=2048)
>>> x = torch.randn(64, 10, 512) # Example input tensor
>>> output = ff(x)
>>> print(output.shape) # Expected shape: (64, 10, 512)
"""
def __init__(
self,
d_model: int,
d_ff: int,
dropout: float = 0.1,
activation=nn.ReLU(),
is_gated: bool = False,
bias: bool = False,
weight_args={},
operations=None,
) -> None:
super().__init__()
self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
self.dropout = nn.Dropout(dropout)
self.activation = activation
self.is_gated = is_gated
if is_gated:
self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)
def forward(self, x: torch.Tensor):
g = self.activation(self.layer1(x))
if self.is_gated:
x = g * self.linear_gate(x)
else:
x = g
assert self.dropout.p == 0.0, "we skip dropout"
return self.layer2(x)
class GPT2FeedForward(FeedForward):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
super().__init__(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=nn.GELU(),
is_gated=False,
bias=bias,
weight_args=weight_args,
operations=operations,
)
def forward(self, x: torch.Tensor):
assert self.dropout.p == 0.0, "we skip dropout"
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class Timesteps(nn.Module):
def __init__(self, num_channels):
super().__init__()
self.num_channels = num_channels
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
emb = torch.cat([cos_emb, sin_emb], dim=-1)
return emb
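# A minimal sketch of the sinusoidal embedding above (assumed num_channels=8);
# cosine components come first, then sine components:
#   >>> ts = Timesteps(num_channels=8)
#   >>> ts(torch.tensor([0.0]))   # cos(0) = 1, sin(0) = 0
#   tensor([[1., 1., 1., 1., 0., 0., 0., 0.]])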
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
super().__init__()
logging.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
self.activation = nn.SiLU()
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
else:
self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(sample)
emb = self.activation(emb)
emb = self.linear_2(emb)
if self.use_adaln_lora:
adaln_lora_B_3D = emb
emb_B_D = sample
else:
emb_B_D = emb
adaln_lora_B_3D = None
return emb_B_D, adaln_lora_B_3D
class FourierFeatures(nn.Module):
"""
Implements a layer that generates Fourier features from input tensors, based on randomly sampled
frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.
[B] -> [B, D]
Parameters:
num_channels (int): The number of Fourier features to generate.
bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize
the variance of the features. Defaults to False.
Example:
>>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
>>> x = torch.randn(10, 256) # Example input tensor
>>> output = layer(x)
>>> print(output.shape) # Expected shape: (10, 256)
"""
def __init__(self, num_channels, bandwidth=1, normalize=False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
def forward(self, x, gain: float = 1.0):
"""
Apply the Fourier feature transformation to the input tensor.
Args:
x (torch.Tensor): The input tensor.
gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.
Returns:
torch.Tensor: The transformed tensor, with Fourier features applied.
"""
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
class PatchEmbed(nn.Module):
"""
PatchEmbed is a module for embedding patches from an input tensor by rearranging spatio-temporal
patches and projecting them with a linear layer. This module can process inputs with temporal (video) and spatial (image) dimensions,
making it suitable for video and image processing tasks. It supports dividing the input into patches
and embedding each patch into a vector of size `out_channels`.
Parameters:
- spatial_patch_size (int): The size of each spatial patch.
- temporal_patch_size (int): The size of each temporal patch.
- in_channels (int): Number of input channels. Default: 3.
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
"""
def __init__(
self,
spatial_patch_size,
temporal_patch_size,
in_channels=3,
out_channels=768,
bias=True,
weight_args={},
operations=None,
):
super().__init__()
self.spatial_patch_size = spatial_patch_size
self.temporal_patch_size = temporal_patch_size
self.proj = nn.Sequential(
Rearrange(
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
r=temporal_patch_size,
m=spatial_patch_size,
n=spatial_patch_size,
),
operations.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args
),
)
self.out = nn.Identity()
def forward(self, x):
"""
Forward pass of the PatchEmbed module.
Parameters:
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
B is the batch size,
C is the number of channels,
T is the temporal dimension,
H is the height, and
W is the width of the input.
Returns:
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
"""
assert x.dim() == 5
_, _, T, H, W = x.shape
assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
assert T % self.temporal_patch_size == 0
x = self.proj(x)
return self.out(x)
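# A shape sketch for the patch embedding above (assumed sizes: spatial_patch_size=2,
# temporal_patch_size=1, in_channels=3, out_channels=768, input of shape (1, 3, 8, 32, 32)):
#   rearrange: (1, 3, 8, 32, 32) -> (1, 8, 16, 16, 3 * 1 * 2 * 2) = (1, 8, 16, 16, 12)
#   linear:    (1, 8, 16, 16, 12) -> (1, 8, 16, 16, 768)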
class FinalLayer(nn.Module):
"""
The final layer of video DiT.
"""
def __init__(
self,
hidden_size,
spatial_patch_size,
temporal_patch_size,
out_channels,
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None,
):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
)
self.hidden_size = hidden_size
self.n_adaln_chunks = 2
self.use_adaln_lora = use_adaln_lora
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
)
def forward(
self,
x_BT_HW_D,
emb_B_D,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_3D is not None
shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk(
2, dim=1
)
else:
shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1)
B = emb_B_D.shape[0]
T = x_BT_HW_D.shape[0] // B
shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T)
x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
x_BT_HW_D = self.linear(x_BT_HW_D)
return x_BT_HW_D
class VideoAttn(nn.Module):
"""
Implements video attention with optional cross-attention capabilities.
This module processes video features while maintaining their spatio-temporal structure. It can perform
self-attention within the video features or cross-attention with external context features.
Parameters:
x_dim (int): Dimension of input feature vectors
context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention
num_heads (int): Number of attention heads
bias (bool): Whether to include bias in attention projections. Default: False
qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head"
x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD"
Input shape:
- x: (T, H, W, B, D) video features
- context (optional): (M, B, D) context features for cross-attention
where:
T: temporal dimension
H: height
W: width
B: batch size
D: feature dimension
M: context sequence length
"""
def __init__(
self,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
bias: bool = False,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
weight_args={},
operations=None,
) -> None:
super().__init__()
self.x_format = x_format
self.attn = Attention(
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_bias=bias,
qkv_norm="RRI",
out_bias=bias,
qkv_norm_mode=qkv_norm_mode,
qkv_format="sbhd",
weight_args=weight_args,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for video attention.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data.
context (Tensor): Context tensor of shape (B, M, D) or (M, B, D),
where M is the sequence length of the context.
crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor with applied attention, maintaining the input shape.
"""
x_T_H_W_B_D = x
context_M_B_D = context
T, H, W, B, D = x_T_H_W_B_D.shape
x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d")
x_THW_B_D = self.attn(
x_THW_B_D,
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
def adaln_norm_state(norm_state, x, scale, shift):
normalized = norm_state(x)
return normalized * (1 + scale) + shift
class DITBuildingBlock(nn.Module):
"""
A building block for the DiT (Diffusion Transformer) architecture that supports different types of
attention and MLP operations with adaptive layer normalization.
Parameters:
block_type (str): Type of block - one of:
- "cross_attn"/"ca": Cross-attention
- "full_attn"/"fa": Full self-attention
- "mlp"/"ff": MLP/feedforward block
x_dim (int): Dimension of input features
context_dim (Optional[int]): Dimension of context features for cross-attention
num_heads (int): Number of attention heads
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
bias (bool): Whether to use bias in layers. Default: False
mlp_dropout (float): Dropout rate for MLP. Default: 0.0
qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
"""
def __init__(
self,
block_type: str,
x_dim: int,
context_dim: Optional[int],
num_heads: int,
mlp_ratio: float = 4.0,
bias: bool = False,
mlp_dropout: float = 0.0,
qkv_norm_mode: str = "per_head",
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
) -> None:
block_type = block_type.lower()
super().__init__()
self.x_format = x_format
if block_type in ["cross_attn", "ca"]:
self.block = VideoAttn(
x_dim,
context_dim,
num_heads,
bias=bias,
qkv_norm_mode=qkv_norm_mode,
x_format=self.x_format,
weight_args=weight_args,
operations=operations,
)
elif block_type in ["full_attn", "fa"]:
self.block = VideoAttn(
x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
)
elif block_type in ["mlp", "ff"]:
self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
else:
raise ValueError(f"Unknown block type: {block_type}")
self.block_type = block_type
self.use_adaln_lora = use_adaln_lora
self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.n_adaln_chunks = 3
if use_adaln_lora:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
)
else:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
Args:
x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
crossattn_emb (Tensor): Tensor for cross-attention blocks.
crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
rope_emb_L_1_1_D (Optional[Tensor]):
Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
Returns:
Tensor: The output tensor after processing through the configured block and adaptive normalization.
"""
if self.use_adaln_lora:
shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
self.n_adaln_chunks, dim=1
)
else:
shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)
shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
)
if self.block_type in ["mlp", "ff"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
)
elif self.block_type in ["full_attn", "fa"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
return x
class GeneralDITTransformerBlock(nn.Module):
"""
A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer.
Each block in the sequence is specified by a block configuration string.
Parameters:
x_dim (int): Dimension of input features
context_dim (int): Dimension of context features for cross-attention blocks
num_heads (int): Number of attention heads
block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention,
full-attention, then MLP)
mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
x_format (str): Input tensor format. Default: "BTHWD"
use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
The block_config string uses "-" to separate block types:
- "ca"/"cross_attn": Cross-attention block
- "fa"/"full_attn": Full self-attention block
- "mlp"/"ff": MLP/feedforward block
Example:
block_config = "ca-fa-mlp" creates a sequence of:
1. Cross-attention block
2. Full self-attention block
3. MLP block
"""
def __init__(
self,
x_dim: int,
context_dim: int,
num_heads: int,
block_config: str,
mlp_ratio: float = 4.0,
x_format: str = "BTHWD",
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
weight_args={},
operations=None
):
super().__init__()
self.blocks = nn.ModuleList()
self.x_format = x_format
for block_type in block_config.split("-"):
self.blocks.append(
DITBuildingBlock(
block_type,
x_dim,
context_dim,
num_heads,
mlp_ratio,
x_format=self.x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
)
def forward(
self,
x: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for block in self.blocks:
x = block(
x,
emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x
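# A minimal construction sketch for the wrapper above (argument values are
# illustrative assumptions; `operations` would normally be supplied by the caller,
# e.g. comfy.ops.disable_weight_init):
#   >>> import comfy.ops
#   >>> blk = GeneralDITTransformerBlock(
#   ...     x_dim=1024, context_dim=1024, num_heads=16, block_config="ca-fa-mlp",
#   ...     operations=comfy.ops.disable_weight_init)
#   >>> [b.block_type for b in blk.blocks]
#   ['ca', 'fa', 'mlp']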
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The model definition for 3D layers
Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/
9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889
[MIT License Copyright (c) 2023 Phil Wang]
https://github.com/lucidrains/magvit2-pytorch/blob/
9f49074179c912736e617d61b32be367eb5f993a/LICENSE
"""
import math
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from .patching import (
Patcher,
Patcher3D,
UnPatcher,
UnPatcher3D,
)
from .utils import (
CausalNormalize,
batch2space,
batch2time,
cast_tuple,
is_odd,
nonlinearity,
replication_pad,
space2batch,
time2batch,
)
import comfy.ops
ops = comfy.ops.disable_weight_init
_LEGACY_NUM_GROUPS = 32
class CausalConv3d(nn.Module):
def __init__(
self,
chan_in: int = 1,
chan_out: int = 1,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
pad_mode: str = "constant",
**kwargs,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop("dilation", 1)
stride = kwargs.pop("stride", 1)
time_stride = kwargs.pop("time_stride", 1)
time_dilation = kwargs.pop("time_dilation", 1)
padding = kwargs.pop("padding", 1)
self.pad_mode = pad_mode
time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride)
self.time_pad = time_pad
self.spatial_pad = (padding, padding, padding, padding)
stride = (time_stride, stride, stride)
dilation = (time_dilation, dilation, dilation)
self.conv3d = ops.Conv3d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs,
)
def _replication_pad(self, x: torch.Tensor) -> torch.Tensor:
x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1)
x = torch.cat([x_prev, x], dim=2)
padding = self.spatial_pad + (0, 0)
return F.pad(x, padding, mode=self.pad_mode, value=0.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._replication_pad(x)
return self.conv3d(x)
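# A worked example of the causal padding above (assumed kernel_size=3, stride=1,
# dilation=1): time_pad = 1 * (3 - 1) + (1 - 1) = 2, so the first frame is
# replicated twice and prepended before the conv, which keeps the layer causal
# and leaves the temporal length unchanged for stride 1.
#   >>> conv = CausalConv3d(chan_in=4, chan_out=4, kernel_size=3, padding=1)
#   >>> conv(torch.randn(1, 4, 5, 16, 16)).shape
#   torch.Size([1, 4, 5, 16, 16])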
class CausalUpsample3d(nn.Module):
def __init__(self, in_channels: int) -> None:
super().__init__()
self.conv = CausalConv3d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
x = x.repeat_interleave(int(time_factor), dim=2)
# TODO(freda): Check if this causes temporal inconsistency.
# Should reverse the order of the following two ops for
# better perf and better temporal smoothness.
x = self.conv(x)
return x[..., int(time_factor - 1) :, :, :]
class CausalDownsample3d(nn.Module):
def __init__(self, in_channels: int) -> None:
super().__init__()
self.conv = CausalConv3d(
in_channels,
in_channels,
kernel_size=3,
stride=2,
time_stride=2,
padding=0,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad = (0, 1, 0, 1, 0, 0)
x = F.pad(x, pad, mode="constant", value=0)
x = replication_pad(x)
x = self.conv(x)
return x
class CausalHybridUpsample3d(nn.Module):
def __init__(
self,
in_channels: int,
spatial_up: bool = True,
temporal_up: bool = True,
**kwargs,
) -> None:
super().__init__()
self.spatial_up = spatial_up
self.temporal_up = temporal_up
if not self.spatial_up and not self.temporal_up:
return
self.conv1 = CausalConv3d(
in_channels,
in_channels,
kernel_size=(3, 1, 1),
stride=1,
time_stride=1,
padding=0,
)
self.conv2 = CausalConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=1,
time_stride=1,
padding=1,
)
self.conv3 = CausalConv3d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
time_stride=1,
padding=0,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.spatial_up and not self.temporal_up:
return x
# hybrid upsample temporally.
if self.temporal_up:
time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
x = x.repeat_interleave(int(time_factor), dim=2)
x = x[..., int(time_factor - 1) :, :, :]
x = self.conv1(x) + x
# hybrid upsample spatially.
if self.spatial_up:
x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
x = self.conv2(x) + x
# final 1x1x1 conv.
x = self.conv3(x)
return x
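# A shape sketch for the hybrid upsampler above (assumed input (1, 16, 4, 8, 8)
# with spatial_up and temporal_up both enabled): repeat_interleave(2) on the time
# axis followed by dropping the first duplicated frame gives 2 * T - 1 = 7 frames,
# and the spatial repeat_interleave doubles H and W, so the output is
# (1, 16, 7, 16, 16). With a single input frame (T == 1) the temporal factor stays
# 1 and only the spatial dimensions are doubled.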
class CausalHybridDownsample3d(nn.Module):
def __init__(
self,
in_channels: int,
spatial_down: bool = True,
temporal_down: bool = True,
**kwargs,
) -> None:
super().__init__()
self.spatial_down = spatial_down
self.temporal_down = temporal_down
if not self.spatial_down and not self.temporal_down:
return
self.conv1 = CausalConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=2,
time_stride=1,
padding=0,
)
self.conv2 = CausalConv3d(
in_channels,
in_channels,
kernel_size=(3, 1, 1),
stride=1,
time_stride=2,
padding=0,
)
self.conv3 = CausalConv3d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
time_stride=1,
padding=0,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.spatial_down and not self.temporal_down:
return x
# hybrid downsample spatially.
if self.spatial_down:
pad = (0, 1, 0, 1, 0, 0)
x = F.pad(x, pad, mode="constant", value=0)
x1 = self.conv1(x)
x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2))
x = x1 + x2
# hybrid downsample temporally.
if self.temporal_down:
x = replication_pad(x)
x1 = self.conv2(x)
x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1))
x = x1 + x2
# final 1x1x1 conv.
x = self.conv3(x)
return x
class CausalResnetBlock3d(nn.Module):
def __init__(
self,
*,
in_channels: int,
out_channels: int = None,
dropout: float,
num_groups: int,
) -> None:
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.norm1 = CausalNormalize(in_channels, num_groups=num_groups)
self.conv1 = CausalConv3d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = CausalConv3d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.nin_shortcut = (
CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
x = self.nin_shortcut(x)
return x + h
class CausalResnetBlockFactorized3d(nn.Module):
def __init__(
self,
*,
in_channels: int,
out_channels: int = None,
dropout: float,
num_groups: int,
) -> None:
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.norm1 = CausalNormalize(in_channels, num_groups=1)
self.conv1 = nn.Sequential(
CausalConv3d(
in_channels,
out_channels,
kernel_size=(1, 3, 3),
stride=1,
padding=1,
),
CausalConv3d(
out_channels,
out_channels,
kernel_size=(3, 1, 1),
stride=1,
padding=0,
),
)
self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = nn.Sequential(
CausalConv3d(
out_channels,
out_channels,
kernel_size=(1, 3, 3),
stride=1,
padding=1,
),
CausalConv3d(
out_channels,
out_channels,
kernel_size=(3, 1, 1),
stride=1,
padding=0,
),
)
self.nin_shortcut = (
CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
x = self.nin_shortcut(x)
return x + h
class CausalAttnBlock(nn.Module):
def __init__(self, in_channels: int, num_groups: int) -> None:
super().__init__()
self.norm = CausalNormalize(in_channels, num_groups=num_groups)
self.q = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.optimized_attention = vae_attention()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
q, batch_size = time2batch(q)
k, batch_size = time2batch(k)
v, batch_size = time2batch(v)
b, c, h, w = q.shape
h_ = self.optimized_attention(q, k, v)
h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_)
return x + h_
class CausalTemporalAttnBlock(nn.Module):
def __init__(self, in_channels: int, num_groups: int) -> None:
super().__init__()
self.norm = CausalNormalize(in_channels, num_groups=num_groups)
self.q = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = CausalConv3d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
q, batch_size, height = space2batch(q)
k, _, _ = space2batch(k)
v, _, _ = space2batch(v)
bhw, c, t = q.shape
q = q.permute(0, 2, 1) # (bhw, t, c)
k = k.permute(0, 2, 1) # (bhw, t, c)
v = v.permute(0, 2, 1) # (bhw, t, c)
w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t)
w_ = w_ * (int(c) ** (-0.5))
# Apply causal mask
mask = torch.tril(torch.ones_like(w_))
w_ = w_.masked_fill(mask == 0, float("-inf"))
w_ = F.softmax(w_, dim=2)
# attend to values
h_ = torch.bmm(w_, v) # (bhw, t, c)
h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t)
h_ = batch2space(h_, batch_size, height)
h_ = self.proj_out(h_)
return x + h_
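# A minimal sketch of the causal mask used above (assumed T = 3): torch.tril keeps
# only the positions where a frame attends to itself or to earlier frames.
#   >>> w = torch.zeros(1, 3, 3)
#   >>> w.masked_fill(torch.tril(torch.ones_like(w)) == 0, float("-inf"))
#   tensor([[[0., -inf, -inf],
#            [0., 0., -inf],
#            [0., 0., 0.]]])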
class EncoderBase(nn.Module):
def __init__(
self,
in_channels: int,
channels: int,
channels_mult: list[int],
num_res_blocks: int,
attn_resolutions: list[int],
dropout: float,
resolution: int,
z_channels: int,
**ignore_kwargs,
) -> None:
super().__init__()
self.num_resolutions = len(channels_mult)
self.num_res_blocks = num_res_blocks
# Patcher.
patch_size = ignore_kwargs.get("patch_size", 1)
self.patcher = Patcher(
patch_size, ignore_kwargs.get("patch_method", "rearrange")
)
in_channels = in_channels * patch_size * patch_size
# downsampling
self.conv_in = CausalConv3d(
in_channels, channels, kernel_size=3, stride=1, padding=1
)
# num of groups for GroupNorm, num_groups=1 for LayerNorm.
num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS)
curr_res = resolution // patch_size
in_ch_mult = (1,) + tuple(channels_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = channels * in_ch_mult[i_level]
block_out = channels * channels_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(
CausalResnetBlock3d(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
num_groups=num_groups,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(CausalAttnBlock(block_in, num_groups=num_groups))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = CausalDownsample3d(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = CausalResnetBlock3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=num_groups,
)
self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups)
self.mid.block_2 = CausalResnetBlock3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=num_groups,
)
# end
self.norm_out = CausalNormalize(block_in, num_groups=num_groups)
self.conv_out = CausalConv3d(
block_in, z_channels, kernel_size=3, stride=1, padding=1
)
def patcher3d(self, x: torch.Tensor) -> torch.Tensor:
x, batch_size = time2batch(x)
x = self.patcher(x)
x = batch2time(x, batch_size)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patcher3d(x)
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
else:
# temporal downsample (last level)
time_factor = 1 + 1 * (hs[-1].shape[2] > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
hs[-1] = replication_pad(hs[-1])
hs.append(
F.avg_pool3d(
hs[-1],
kernel_size=[time_factor, 1, 1],
stride=[2, 1, 1],
)
)
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class DecoderBase(nn.Module):
def __init__(
self,
out_channels: int,
channels: int,
channels_mult: list[int],
num_res_blocks: int,
attn_resolutions: list[int],
dropout: float,
resolution: int,
z_channels: int,
**ignore_kwargs,
):
super().__init__()
self.num_resolutions = len(channels_mult)
self.num_res_blocks = num_res_blocks
# UnPatcher.
patch_size = ignore_kwargs.get("patch_size", 1)
self.unpatcher = UnPatcher(
patch_size, ignore_kwargs.get("patch_method", "rearrange")
)
out_ch = out_channels * patch_size * patch_size
block_in = channels * channels_mult[self.num_resolutions - 1]
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logging.debug(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
self.conv_in = CausalConv3d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# num of groups for GroupNorm, num_groups=1 for LayerNorm.
num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS)
# middle
self.mid = nn.Module()
self.mid.block_1 = CausalResnetBlock3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=num_groups,
)
self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups)
self.mid.block_2 = CausalResnetBlock3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=num_groups,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = channels * channels_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(
CausalResnetBlock3d(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
num_groups=num_groups,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(CausalAttnBlock(block_in, num_groups=num_groups))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = CausalUpsample3d(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = CausalNormalize(block_in, num_groups=num_groups)
self.conv_out = CausalConv3d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor:
x, batch_size = time2batch(x)
x = self.unpatcher(x)
x = batch2time(x, batch_size)
return x
def forward(self, z):
h = self.conv_in(z)
# middle block.
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# decoder blocks.
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
else:
# temporal upsample (last level)
time_factor = 1.0 + 1.0 * (h.shape[2] > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
h = h.repeat_interleave(int(time_factor), dim=2)
h = h[..., int(time_factor - 1) :, :, :]
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = self.unpatcher3d(h)
return h
class EncoderFactorized(nn.Module):
def __init__(
self,
in_channels: int,
channels: int,
channels_mult: list[int],
num_res_blocks: int,
attn_resolutions: list[int],
dropout: float,
resolution: int,
z_channels: int,
spatial_compression: int = 8,
temporal_compression: int = 8,
**ignore_kwargs,
) -> None:
super().__init__()
self.num_resolutions = len(channels_mult)
self.num_res_blocks = num_res_blocks
# Patcher.
patch_size = ignore_kwargs.get("patch_size", 1)
self.patcher3d = Patcher3D(
patch_size, ignore_kwargs.get("patch_method", "haar")
)
in_channels = in_channels * patch_size * patch_size * patch_size
# calculate the number of downsample operations
self.num_spatial_downs = int(math.log2(spatial_compression)) - int(
math.log2(patch_size)
)
assert (
self.num_spatial_downs <= self.num_resolutions
), f"Spatially downsample {self.num_resolutions} times at most"
self.num_temporal_downs = int(math.log2(temporal_compression)) - int(
math.log2(patch_size)
)
assert (
self.num_temporal_downs <= self.num_resolutions
), f"Temporally downsample {self.num_resolutions} times at most"
# downsampling
self.conv_in = nn.Sequential(
CausalConv3d(
in_channels,
channels,
kernel_size=(1, 3, 3),
stride=1,
padding=1,
),
CausalConv3d(
channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0
),
)
curr_res = resolution // patch_size
in_ch_mult = (1,) + tuple(channels_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = channels * in_ch_mult[i_level]
block_out = channels * channels_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(
CausalResnetBlockFactorized3d(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
num_groups=1,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(
nn.Sequential(
CausalAttnBlock(block_in, num_groups=1),
CausalTemporalAttnBlock(block_in, num_groups=1),
)
)
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
spatial_down = i_level < self.num_spatial_downs
temporal_down = i_level < self.num_temporal_downs
down.downsample = CausalHybridDownsample3d(
block_in,
spatial_down=spatial_down,
temporal_down=temporal_down,
)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = CausalResnetBlockFactorized3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=1,
)
self.mid.attn_1 = nn.Sequential(
CausalAttnBlock(block_in, num_groups=1),
CausalTemporalAttnBlock(block_in, num_groups=1),
)
self.mid.block_2 = CausalResnetBlockFactorized3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=1,
)
# end
self.norm_out = CausalNormalize(block_in, num_groups=1)
self.conv_out = nn.Sequential(
CausalConv3d(
block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1
),
CausalConv3d(
z_channels,
z_channels,
kernel_size=(3, 1, 1),
stride=1,
padding=0,
),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patcher3d(x)
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
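# A worked example of the downsampling bookkeeping above (assumed
# spatial_compression=8, temporal_compression=8, patch_size=2):
#   num_spatial_downs  = log2(8) - log2(2) = 3 - 1 = 2
#   num_temporal_downs = log2(8) - log2(2) = 2
# so at most the first two resolution transitions (i_level 0 and 1) apply a hybrid
# spatial+temporal downsample, and the Haar Patcher3D contributes the remaining
# factor of 2 in every axis (2 from patching * 2 * 2 from the downsamples = 8 overall).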
class DecoderFactorized(nn.Module):
def __init__(
self,
out_channels: int,
channels: int,
channels_mult: list[int],
num_res_blocks: int,
attn_resolutions: list[int],
dropout: float,
resolution: int,
z_channels: int,
spatial_compression: int = 8,
temporal_compression: int = 8,
**ignore_kwargs,
):
super().__init__()
self.num_resolutions = len(channels_mult)
self.num_res_blocks = num_res_blocks
# UnPatcher.
patch_size = ignore_kwargs.get("patch_size", 1)
self.unpatcher3d = UnPatcher3D(
patch_size, ignore_kwargs.get("patch_method", "haar")
)
out_ch = out_channels * patch_size * patch_size * patch_size
# calculate the number of upsample operations
self.num_spatial_ups = int(math.log2(spatial_compression)) - int(
math.log2(patch_size)
)
assert (
self.num_spatial_ups <= self.num_resolutions
), f"Spatially upsample {self.num_resolutions} times at most"
self.num_temporal_ups = int(math.log2(temporal_compression)) - int(
math.log2(patch_size)
)
assert (
self.num_temporal_ups <= self.num_resolutions
), f"Temporally upsample {self.num_resolutions} times at most"
block_in = channels * channels_mult[self.num_resolutions - 1]
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logging.debug(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
self.conv_in = nn.Sequential(
CausalConv3d(
z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1
),
CausalConv3d(
block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0
),
)
# middle
self.mid = nn.Module()
self.mid.block_1 = CausalResnetBlockFactorized3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=1,
)
self.mid.attn_1 = nn.Sequential(
CausalAttnBlock(block_in, num_groups=1),
CausalTemporalAttnBlock(block_in, num_groups=1),
)
self.mid.block_2 = CausalResnetBlockFactorized3d(
in_channels=block_in,
out_channels=block_in,
dropout=dropout,
num_groups=1,
)
legacy_mode = ignore_kwargs.get("legacy_mode", False)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = channels * channels_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(
CausalResnetBlockFactorized3d(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
num_groups=1,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(
nn.Sequential(
CausalAttnBlock(block_in, num_groups=1),
CausalTemporalAttnBlock(block_in, num_groups=1),
)
)
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
# The layer index for temporal/spatial downsampling performed
# in the encoder should correspond, in reverse order, to the
# layer index where upsampling is performed in the decoder.
# If you have a pre-trained model, you can simply fine-tune it.
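# For a concrete reading of this mapping: with num_resolutions=4, num_spatial_ups=3,
# num_temporal_ups=2 and legacy_mode=False, the loop visits i_level = 3, 2, 1, i.e.
# i_level_reverse = 0, 1, 2; spatial upsampling is enabled at all three levels while
# temporal upsampling is skipped at the deepest one (i_level_reverse = 0), giving
# 2**3 = 8x spatial and 2**2 = 4x temporal expansion overall.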
i_level_reverse = self.num_resolutions - i_level - 1
if legacy_mode:
temporal_up = i_level_reverse < self.num_temporal_ups
else:
temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1
spatial_up = temporal_up or (
i_level_reverse < self.num_spatial_ups
and self.num_spatial_ups > self.num_temporal_ups
)
up.upsample = CausalHybridUpsample3d(
block_in, spatial_up=spatial_up, temporal_up=temporal_up
)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = CausalNormalize(block_in, num_groups=1)
self.conv_out = nn.Sequential(
CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1),
CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0),
)
def forward(self, z):
h = self.conv_in(z)
# middle block.
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# decoder blocks.
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = self.unpatcher3d(h)
return h
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The patcher and unpatcher implementation for 2D and 3D data.
The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
One on the rows and one on the columns.
For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2.
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
as we need to support downsampling for more than 2x.
For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be.
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
"""
import torch
import torch.nn.functional as F
from einops import rearrange
_WAVELETS = {
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
"rearrange": torch.tensor([1.0, 1.0]),
}
_PERSISTENT = False
class Patcher(torch.nn.Module):
"""A module to convert image tensors into patches using torch operations.
The main difference from `class Patching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Patching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._haar(x)
elif self.patch_method == "rearrange":
return self._arrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _dwt(self, x, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
out = torch.cat([xll, xlh, xhl, xhh], dim=1)
if rescale:
out = out / 2
return out
def _haar(self, x):
for _ in self.range:
x = self._dwt(x, rescale=True)
return x
def _arrange(self, x):
x = rearrange(
x,
"b c (h p1) (w p2) -> b (c p1 p2) h w",
p1=self.patch_size,
p2=self.patch_size,
).contiguous()
return x
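# Illustrative usage (shapes only, as a sketch): with patch_size=2 the Haar patcher halves
# each spatial dimension and multiplies the channel count by 4, e.g.
# Patcher(patch_size=2, patch_method="haar")(torch.randn(1, 3, 256, 256)) should have shape
# (1, 12, 128, 128); with patch_size=4 (two Haar steps) the shape becomes (1, 48, 64, 64).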
class Patcher3D(Patcher):
"""A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
self.register_buffer(
"patch_size_buffer",
patch_size * torch.ones([1], dtype=torch.int32),
persistent=_PERSISTENT,
)
def _dwt(self, x, wavelet, mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1]
hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
# Handles temporal axis.
x = F.pad(
x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode
).to(dtype)
xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
# Handles spatial axes.
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
if rescale:
out = out / (2 * torch.sqrt(torch.tensor(2.0)))
return out
def _haar(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
for _ in self.range:
x = self._dwt(x, "haar", rescale=True)
return x
def _arrange(self, x):
xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
x = rearrange(
x,
"b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
).contiguous()
return x
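# Illustrative usage for the 3D variant (shapes only, as a sketch): the first frame is
# replicated patch_size times before the transform, so a (1, 3, 9, 256, 256) clip with
# patch_size=2 is first extended to 10 frames, and a single 3D Haar step should then
# produce shape (1, 24, 5, 128, 128): a 2x reduction in T, H and W for an 8x channel increase.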
class UnPatcher(torch.nn.Module):
"""A module to convert patches into image tensorsusing torch operations.
The main difference from `class Unpatching` is that this module implements
all operations using torch, rather than python or numpy, for efficiency purpose.
It's bit-wise identical to the Unpatching module outputs, with the added
benefit of being torch.jit scriptable.
"""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer(
"wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT
)
self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=_PERSISTENT,
)
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
if self.patch_method == "haar":
return self._ihaar(x)
elif self.patch_method == "rearrange":
return self._iarrange(x)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
n = h.shape[0]
g = x.shape[1] // 4
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
# Inverse transform.
yl = torch.nn.functional.conv_transpose2d(
xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yl += torch.nn.functional.conv_transpose2d(
xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh = torch.nn.functional.conv_transpose2d(
xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
yh += torch.nn.functional.conv_transpose2d(
xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
)
y = torch.nn.functional.conv_transpose2d(
yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
y += torch.nn.functional.conv_transpose2d(
yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
)
if rescale:
y = y * 2
return y
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
)
return x
class UnPatcher3D(UnPatcher):
"""A 3D inverse discrete wavelet transform for video wavelet decompositions."""
def __init__(self, patch_size=1, patch_method="haar"):
super().__init__(patch_method=patch_method, patch_size=patch_size)
def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
dtype = x.dtype
h = self.wavelets.to(device=x.device)
g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tensors.
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1)
hl = hl.to(dtype=dtype)
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
del x
# Handles width (last spatial axis) transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlll
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhl
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhll
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhl
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
del xhhh
# Handles height transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xll
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xlh
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhl
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
del xl
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
if rescale:
x = x * (2 * torch.sqrt(torch.tensor(2.0)))
return x
def _ihaar(self, x):
for _ in self.range:
x = self._idwt(x, "haar", rescale=True)
x = x[:, :, self.patch_size - 1 :, ...]
return x
def _iarrange(self, x):
x = rearrange(
x,
"b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
p1=self.patch_size,
p2=self.patch_size,
p3=self.patch_size,
)
x = x[:, :, self.patch_size - 1 :, ...]
return x
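# Round-trip sketch (for illustration): UnPatcher3D undoes Patcher3D, with _ihaar trimming
# the replicated leading frames via x[:, :, self.patch_size - 1:, ...], so for
# p, u = Patcher3D(patch_size=2), UnPatcher3D(patch_size=2) and a video of shape
# (1, 3, 9, 64, 64), u(p(video)) should again have shape (1, 3, 9, 64, 64) and match the
# input up to floating-point error.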
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for the networks module."""
from typing import Any
import torch
from einops import rearrange
import comfy.ops
ops = comfy.ops.disable_weight_init
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size = x.shape[0]
return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
batch_size, height = x.shape[0], x.shape[-2]
return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
def cast_tuple(t: Any, length: int = 1) -> Any:
return t if isinstance(t, tuple) else ((t,) * length)
def replication_pad(x):
return torch.cat([x[:, :, :1, ...], x], dim=2)
def divisible_by(num: int, den: int) -> bool:
return (num % den) == 0
def is_odd(n: int) -> bool:
return not divisible_by(n, 2)
def nonlinearity(x):
# x * sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class CausalNormalize(torch.nn.Module):
def __init__(self, in_channels, num_groups=1):
super().__init__()
self.norm = ops.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
)
self.num_groups = num_groups
def forward(self, x):
# If num_groups != 1, we apply a spatio-temporal group norm for backward compatibility purposes.
# All new models should use num_groups=1; otherwise causality is not guaranteed.
if self.num_groups == 1:
x, batch_size = time2batch(x)
return batch2time(self.norm(x), batch_size)
return self.norm(x)
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
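# Illustration: the forward value is the rounded tensor while gradients flow through
# unchanged, e.g. for z = torch.tensor([0.2, 1.7], requires_grad=True),
# round_ste(z) evaluates to [0., 2.] and round_ste(z).sum().backward() should leave
# z.grad == [1., 1.].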
def log(t, eps=1e-5):
return t.clamp(min=eps).log()
def entropy(prob):
return (-prob * log(prob)).sum(dim=-1)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A general implementation of an AdaLN-modulated, ViT-like (DiT) transformer for video processing.
"""
from typing import Optional, Tuple
import torch
from einops import rearrange
from torch import nn
from torchvision import transforms
from enum import Enum
import logging
import comfy.patcher_extension
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
PatchEmbed,
TimestepEmbedding,
Timesteps,
)
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
class DataType(Enum):
IMAGE = "image"
VIDEO = "video"
class GeneralDIT(nn.Module):
"""
A general implementation of an AdaLN-modulated, ViT-like (DiT) transformer for video processing.
Args:
max_img_h (int): Maximum height of the input images.
max_img_w (int): Maximum width of the input images.
max_frames (int): Maximum number of frames in the video sequence.
in_channels (int): Number of input channels (e.g., RGB channels for color images).
out_channels (int): Number of output channels.
patch_spatial (tuple): Spatial resolution of patches for input processing.
patch_temporal (int): Temporal resolution of patches for input processing.
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
block_config (str): Configuration of the transformer block. See Notes for supported block types.
model_channels (int): Base number of channels used throughout the model.
num_blocks (int): Number of transformer blocks.
num_heads (int): Number of heads in the multi-head attention layers.
mlp_ratio (float): Expansion ratio for MLP blocks.
block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
use_cross_attn_mask (bool): Whether to use mask in cross-attention.
pos_emb_cls (str): Type of positional embeddings.
pos_emb_learnable (bool): Whether positional embeddings are learnable.
pos_emb_interpolation (str): Method for interpolating positional embeddings.
affline_emb_norm (bool): Whether to normalize affine embeddings.
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
Notes:
Supported block types in block_config:
* cross_attn, ca: Cross attention
* full_attn: Full attention on all flattened tokens
* mlp, ff: Feed forward block
"""
def __init__(
self,
max_img_h: int,
max_img_w: int,
max_frames: int,
in_channels: int,
out_channels: int,
patch_spatial: tuple,
patch_temporal: int,
concat_padding_mask: bool = True,
# attention settings
block_config: str = "FA-CA-MLP",
model_channels: int = 768,
num_blocks: int = 10,
num_heads: int = 16,
mlp_ratio: float = 4.0,
block_x_format: str = "BTHWD",
# cross attention settings
crossattn_emb_channels: int = 1024,
use_cross_attn_mask: bool = False,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
pos_emb_interpolation: str = "crop",
affline_emb_norm: bool = False, # whether or not to normalize the affine embedding
use_adaln_lora: bool = False,
adaln_lora_dim: int = 256,
rope_h_extrapolation_ratio: float = 1.0,
rope_w_extrapolation_ratio: float = 1.0,
rope_t_extrapolation_ratio: float = 1.0,
extra_per_block_abs_pos_emb: bool = False,
extra_per_block_abs_pos_emb_type: str = "sincos",
extra_h_extrapolation_ratio: float = 1.0,
extra_w_extrapolation_ratio: float = 1.0,
extra_t_extrapolation_ratio: float = 1.0,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.max_img_h = max_img_h
self.max_img_w = max_img_w
self.max_frames = max_frames
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_spatial = patch_spatial
self.patch_temporal = patch_temporal
self.num_heads = num_heads
self.num_blocks = num_blocks
self.model_channels = model_channels
self.use_cross_attn_mask = use_cross_attn_mask
self.concat_padding_mask = concat_padding_mask
# positional embedding settings
self.pos_emb_cls = pos_emb_cls
self.pos_emb_learnable = pos_emb_learnable
self.pos_emb_interpolation = pos_emb_interpolation
self.affline_emb_norm = affline_emb_norm
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
self.dtype = dtype
weight_args = {"device": device, "dtype": dtype}
in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.x_embedder = PatchEmbed(
spatial_patch_size=patch_spatial,
temporal_patch_size=patch_temporal,
in_channels=in_channels,
out_channels=model_channels,
bias=False,
weight_args=weight_args,
operations=operations,
)
self.build_pos_embed(device=device, dtype=dtype)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.ModuleList(
[Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
)
self.blocks = nn.ModuleDict()
for idx in range(num_blocks):
self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
x_dim=model_channels,
context_dim=crossattn_emb_channels,
num_heads=num_heads,
block_config=block_config,
mlp_ratio=mlp_ratio,
x_format=self.block_x_format,
use_adaln_lora=use_adaln_lora,
adaln_lora_dim=adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
else:
self.affline_norm = nn.Identity()
self.final_layer = FinalLayer(
hidden_size=self.model_channels,
spatial_patch_size=self.patch_spatial,
temporal_patch_size=self.patch_temporal,
out_channels=self.out_channels,
use_adaln_lora=self.use_adaln_lora,
adaln_lora_dim=self.adaln_lora_dim,
weight_args=weight_args,
operations=operations,
)
def build_pos_embed(self, device=None, dtype=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
device=device,
)
self.pos_embedder = cls_type(
**kwargs,
)
if self.extra_per_block_abs_pos_emb:
assert self.extra_per_block_abs_pos_emb_type in [
"learnable",
], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
def prepare_embedded_sequence(
self,
x_B_C_T_H_W: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
Args:
x_B_C_T_H_W (torch.Tensor): video
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
If None, a default value (`self.base_fps`) will be used.
padding_mask (Optional[torch.Tensor]): currently not used.
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
- An optional positional embedding tensor, returned only if the positional embedding class
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
Notes:
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
the `self.pos_embedder` with the shape [T, H, W].
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
`self.pos_embedder` with the fps tensor.
- Otherwise, the positional embeddings are generated without considering fps.
"""
if self.concat_padding_mask:
if padding_mask is not None:
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
else:
padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
if "rope" in self.pos_emb_cls.lower():
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
if "fps_aware" in self.pos_emb_cls:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
else:
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
return x_B_T_H_W_D, None, extra_pos_emb
def decoder_head(
self,
x_B_T_H_W_D: torch.Tensor,
emb_B_D: torch.Tensor,
crossattn_emb: torch.Tensor,
origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W]
crossattn_mask: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
) -> torch.Tensor:
del crossattn_emb, crossattn_mask
B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
# This is to ensure x_BT_HW_D has the correct shape because
# when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
x_BT_HW_D = x_BT_HW_D.view(
B * T_before_patchify // self.patch_temporal,
H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
-1,
)
x_B_D_T_H_W = rearrange(
x_BT_HW_D,
"(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
p1=self.patch_spatial,
p2=self.patch_spatial,
H=H_before_patchify // self.patch_spatial,
W=W_before_patchify // self.patch_spatial,
t=self.patch_temporal,
B=B,
)
return x_B_D_T_H_W
def forward_before_blocks(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
crossattn_emb: torch.Tensor,
crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
"""
del kwargs
assert isinstance(
data_type, DataType
), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
original_shape = x.shape
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
x,
fps=fps,
padding_mask=padding_mask,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
)
# logging affline scale information
affline_scale_log_info = {}
timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
affline_emb_B_D = timesteps_B_D
affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()
if scalar_feature is not None:
raise NotImplementedError("Scalar feature is not implemented yet.")
affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
affline_emb_B_D = self.affline_norm(affline_emb_B_D)
if self.use_cross_attn_mask:
if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length]
else:
crossattn_mask = None
if self.blocks["block0"].x_format == "THWBD":
x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
)
crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")
if crossattn_mask is not None:
crossattn_mask = rearrange(crossattn_mask, "B M -> M B")
elif self.blocks["block0"].x_format == "BTHWD":
x = x_B_T_H_W_D
else:
raise ValueError(f"Unknown x_format {self.blocks['block0'].x_format}")
output = {
"x": x,
"affline_emb_B_D": affline_emb_B_D,
"crossattn_emb": crossattn_emb,
"crossattn_mask": crossattn_mask,
"rope_emb_L_1_1_D": rope_emb_L_1_1_D,
"adaln_lora_B_3D": adaln_lora_B_3D,
"original_shape": original_shape,
"extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
}
return output
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x,
timesteps,
context,
attention_mask,
fps,
image_size,
padding_mask,
scalar_feature,
data_type,
latent_condition,
latent_condition_sigma,
condition_video_augment_sigma,
**kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
x: (B, C, T, H, W) tensor of spatial-temp inputs
timesteps: (B, ) tensor of timesteps
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
crossattn_mask: (B, N) tensor of cross-attention masks
condition_video_augment_sigma: (B,) used in LVG (long video generation); we add noise with this sigma to
    augment the conditioning input, and the LVG model conditions on the condition_video_augment_sigma value.
    It is passed through to the forward_before_blocks function.
"""
crossattn_emb = context
crossattn_mask = attention_mask
inputs = self.forward_before_blocks(
x=x,
timesteps=timesteps,
crossattn_emb=crossattn_emb,
crossattn_mask=crossattn_mask,
fps=fps,
image_size=image_size,
padding_mask=padding_mask,
scalar_feature=scalar_feature,
data_type=data_type,
latent_condition=latent_condition,
latent_condition_sigma=latent_condition_sigma,
condition_video_augment_sigma=condition_video_augment_sigma,
**kwargs,
)
x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
inputs["x"],
inputs["affline_emb_B_D"],
inputs["crossattn_emb"],
inputs["crossattn_mask"],
inputs["rope_emb_L_1_1_D"],
inputs["adaln_lora_B_3D"],
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.to(x.dtype)
del inputs
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
crossattn_emb,
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
x_B_D_T_H_W = self.decoder_head(
x_B_T_H_W_D=x_B_T_H_W_D,
emb_B_D=affline_emb_B_D,
crossattn_emb=None,
origin_shape=original_shape,
crossattn_mask=None,
adaln_lora_B_3D=adaln_lora_B_3D,
)
return x_B_D_T_H_W
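# A hedged usage sketch with illustrative values (the concrete sizes and the `operations`
# module below are assumptions for illustration, not a verified configuration; assumes
# torch and comfy.ops are importable):
#   model = GeneralDIT(
#       max_img_h=240, max_img_w=240, max_frames=128,
#       in_channels=16, out_channels=16,
#       patch_spatial=2, patch_temporal=1,
#       model_channels=4096, num_blocks=28, num_heads=32,
#       crossattn_emb_channels=1024, pos_emb_cls="rope3d",
#       extra_per_block_abs_pos_emb=True, extra_per_block_abs_pos_emb_type="learnable",
#       dtype=torch.float32, operations=comfy.ops.disable_weight_init,
#   )
#   latent = torch.randn(1, 16, 16, 88, 160)                   # (B, C, T, H, W)
#   sigma_t = torch.rand(1)                                     # (B,) timesteps
#   context = torch.randn(1, 512, 1024)                         # (B, N, D) cross-attention embeddings
#   out = model(latent, timesteps=sigma_t, context=context)     # (B, out_channels, T, H, W)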