Commit 08a21d59 authored by chenpangpang

feat: initial commit

parent 1a6b26f1
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
if TYPE_CHECKING:
from .attention import Attention
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: "Attention",
hidden_states: torch.FloatTensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
) -> torch.Tensor:
residual = hidden_states
if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
# B, L, C
assert hidden_states.ndim == 3, f"Hidden states must be 3-dimensional, got {hidden_states.ndim}"
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2))
hidden_states = hidden_states.transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
hidden_states = attn.to_out(hidden_states)
hidden_states = attn.dropout(hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
def __call__(
self,
attn: "Attention",
hidden_states: torch.FloatTensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
# B, L, C
assert hidden_states.ndim == 3, f"Hidden states must be 3-dimensional, got {hidden_states.ndim}"
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.nheads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2))
hidden_states = hidden_states.transpose(1, 2)
query: torch.Tensor = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key: torch.Tensor = attn.to_k(encoder_hidden_states)
value: torch.Tensor = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.nheads
query = query.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.nheads * head_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out(hidden_states)
hidden_states = attn.dropout(hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
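# Minimal sketch (hypothetical helper, not part of the original API) of the
# head-splitting layout AttnProcessor2_0 uses around F.scaled_dot_product_attention.
def _demo_sdpa_layout():
    batch_size, seq_len, nheads, head_dim = 2, 16, 4, 8
    # Projections come in as (B, L, nheads * head_dim) ...
    q = torch.randn(batch_size, seq_len, nheads * head_dim)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # ... and are viewed as (B, nheads, L, head_dim) before the SDPA call.
    q, k, v = (t.view(batch_size, -1, nheads, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # The output is folded back to (B, L, nheads * head_dim).
    out = out.transpose(1, 2).reshape(batch_size, -1, nheads * head_dim)
    assert out.shape == (batch_size, seq_len, nheads * head_dim)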
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .activations import get_activation
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
class CausalConv3d(nn.Conv3d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3, # : int | tuple[int, int, int],
stride=1, # : int | tuple[int, int, int] = 1,
padding=1, # : int | tuple[int, int, int], # TODO: change it to 0.
dilation=1, # : int | tuple[int, int, int] = 1,
**kwargs,
):
kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
stride = stride if isinstance(stride, tuple) else (stride,) * 3
assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
t_ks, h_ks, w_ks = kernel_size
self.t_stride, h_stride, w_stride = stride
t_dilation, h_dilation, w_dilation = dilation
t_pad = (t_ks - 1) * t_dilation
# TODO: align with SD
if padding is None:
h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
elif isinstance(padding, int):
h_pad = w_pad = padding
else:
            raise NotImplementedError(f"Unsupported padding: {padding}")
self.temporal_padding = t_pad
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * t_dilation + (1 - self.t_stride)) / 2)
self.padding_flag = 0
self.prev_features = None
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=(0, h_pad, w_pad),
**kwargs,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, C, T, H, W)
dtype = x.dtype
x = x.float()
if self.padding_flag == 0:
x = F.pad(
x,
pad=(0, 0, 0, 0, self.temporal_padding, 0),
mode="replicate", # TODO: check if this is necessary
)
x = x.to(dtype=dtype)
return super().forward(x)
elif self.padding_flag == 5:
x = F.pad(
x,
pad=(0, 0, 0, 0, self.temporal_padding, 0),
mode="replicate", # TODO: check if this is necessary
)
x = x.to(dtype=dtype)
self.prev_features = x[:, :, -self.temporal_padding:]
return super().forward(x)
elif self.padding_flag == 6:
if self.t_stride == 2:
                x = torch.concat([self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim=2)
else:
                x = torch.concat([self.prev_features, x], dim=2)
self.prev_features = x[:, :, -self.temporal_padding:]
x = x.to(dtype=dtype)
return super().forward(x)
else:
x = F.pad(
x,
pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
)
x = x.to(dtype=dtype)
return super().forward(x)
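# Minimal sanity-check sketch (hypothetical helper): with the default
# padding_flag == 0, CausalConv3d replicate-pads only past frames, so a
# stride-1 layer preserves the (T, H, W) extent of its input.
def _demo_causal_conv3d():
    conv = CausalConv3d(in_channels=4, out_channels=4, kernel_size=3)
    x = torch.randn(1, 4, 8, 16, 16)  # (B, C, T, H, W)
    assert conv(x).shape == (1, 4, 8, 16, 16)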
class ResidualBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
non_linearity: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
):
super().__init__()
self.output_scale_factor = output_scale_factor
self.norm1 = nn.GroupNorm(
num_groups=norm_num_groups,
num_channels=in_channels,
eps=norm_eps,
affine=True,
)
self.nonlinearity = get_activation(non_linearity)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.norm2 = nn.GroupNorm(
num_groups=norm_num_groups,
num_channels=out_channels,
eps=norm_eps,
affine=True,
)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
else:
self.shortcut = nn.Identity()
self.set_3dgroupnorm = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = self.shortcut(x)
if self.set_3dgroupnorm:
batch_size = x.shape[0]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm1(x)
x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
else:
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
if self.set_3dgroupnorm:
batch_size = x.shape[0]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm2(x)
x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
else:
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.dropout(x)
x = self.conv2(x)
return (x + shortcut) / self.output_scale_factor
class ResidualBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
non_linearity: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
):
super().__init__()
self.output_scale_factor = output_scale_factor
self.norm1 = nn.GroupNorm(
num_groups=norm_num_groups,
num_channels=in_channels,
eps=norm_eps,
affine=True,
)
self.nonlinearity = get_activation(non_linearity)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = nn.GroupNorm(
num_groups=norm_num_groups,
num_channels=out_channels,
eps=norm_eps,
affine=True,
)
self.dropout = nn.Dropout(dropout)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3)
if in_channels != out_channels:
self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.shortcut = nn.Identity()
self.set_3dgroupnorm = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = self.shortcut(x)
if self.set_3dgroupnorm:
batch_size = x.shape[0]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm1(x)
x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
else:
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
if self.set_3dgroupnorm:
batch_size = x.shape[0]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm2(x)
x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
else:
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.dropout(x)
x = self.conv2(x)
return (x + shortcut) / self.output_scale_factor
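# Shape-check sketch (hypothetical helper): ResidualBlock3D maps
# (B, C_in, T, H, W) to (B, C_out, T, H, W); a 1x1x1 Conv3d shortcut
# reconciles the channel change on the residual path.
def _demo_residual_block3d():
    block = ResidualBlock3D(in_channels=32, out_channels=64)
    x = torch.randn(1, 32, 4, 16, 16)
    assert block(x).shape == (1, 64, 4, 16, 16)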
class SpatialNorm2D(nn.Module):
"""
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
Args:
f_channels (`int`):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
):
super().__init__()
self.norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.set_3dgroupnorm = False
def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
f_size = f.shape[-2:]
zq = F.interpolate(zq, size=f_size, mode="nearest")
if self.set_3dgroupnorm:
batch_size = f.shape[0]
f = rearrange(f, "b c t h w -> (b t) c h w")
norm_f = self.norm(f)
norm_f = rearrange(norm_f, "(b t) c h w -> b c t h w", b=batch_size)
else:
norm_f = self.norm(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class SpatialNorm3D(SpatialNorm2D):
def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
batch_size = f.shape[0]
f = rearrange(f, "b c t h w -> (b t) c h w")
zq = rearrange(zq, "b c t h w -> (b t) c h w")
x = super().forward(f, zq)
x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
return x
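# Shape-check sketch (hypothetical helper): SpatialNorm3D folds time into the
# batch, resizes zq to f's spatial size per frame, and modulates the
# group-normalized features, so the output keeps f's shape.
def _demo_spatial_norm3d():
    norm = SpatialNorm3D(f_channels=64, zq_channels=4)
    f = torch.randn(1, 64, 3, 16, 16)
    zq = torch.randn(1, 4, 3, 8, 8)  # same T as f, smaller H and W
    assert norm(f, zq).shape == f.shape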
import math
import torch
import torch.nn as nn
from .downsamplers import BlurPooling2D, BlurPooling3D
class DiscriminatorBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.output_scale_factor = output_scale_factor
self.norm1 = nn.BatchNorm2d(in_channels)
self.nonlinearity = nn.LeakyReLU(0.2)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
if add_downsample:
self.downsampler = BlurPooling2D(out_channels, out_channels)
else:
self.downsampler = nn.Identity()
self.norm2 = nn.BatchNorm2d(out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
if add_downsample:
self.shortcut = nn.Sequential(
BlurPooling2D(in_channels, in_channels),
nn.Conv2d(in_channels, out_channels, kernel_size=1),
)
        elif in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
self.spatial_downsample_factor = 2
self.temporal_downsample_factor = 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = self.shortcut(x)
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.dropout(x)
x = self.downsampler(x)
x = self.conv2(x)
return (x + shortcut) / self.output_scale_factor
class Discriminator2D(nn.Module):
def __init__(
self,
in_channels: int = 3,
block_out_channels = (64,),
):
super().__init__()
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, out_channels in enumerate(block_out_channels):
input_channels = output_channels
output_channels = out_channels
is_final_block = i == len(block_out_channels) - 1
self.blocks.append(
DiscriminatorBlock2D(
in_channels=input_channels,
out_channels=output_channels,
output_scale_factor=math.sqrt(2),
add_downsample=not is_final_block,
)
)
self.conv_norm_out = nn.BatchNorm2d(block_out_channels[-1])
self.conv_act = nn.LeakyReLU(0.2)
self.conv_out = nn.Conv2d(block_out_channels[-1], 1, kernel_size=3, padding=1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W)
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
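# Shape-check sketch (hypothetical helper): with two blocks, only the first
# (non-final) block downsamples via BlurPooling2D, so a 64x64 input yields a
# 32x32 single-channel logit map.
def _demo_discriminator2d():
    disc = Discriminator2D(in_channels=3, block_out_channels=(32, 32))
    x = torch.randn(2, 3, 64, 64)
    assert disc(x).shape == (2, 1, 32, 32)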
class DiscriminatorBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.output_scale_factor = output_scale_factor
self.norm1 = nn.GroupNorm(32, in_channels)
self.nonlinearity = nn.LeakyReLU(0.2)
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
if add_downsample:
self.downsampler = BlurPooling3D(out_channels, out_channels)
else:
self.downsampler = nn.Identity()
self.norm2 = nn.GroupNorm(32, out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
if add_downsample:
self.shortcut = nn.Sequential(
BlurPooling3D(in_channels, in_channels),
nn.Conv3d(in_channels, out_channels, kernel_size=1),
)
else:
self.shortcut = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size=1),
)
self.spatial_downsample_factor = 2
self.temporal_downsample_factor = 2
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = self.shortcut(x)
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.dropout(x)
x = self.downsampler(x)
x = self.conv2(x)
return (x + shortcut) / self.output_scale_factor
class Discriminator3D(nn.Module):
def __init__(
self,
in_channels: int = 3,
block_out_channels = (64,),
):
super().__init__()
self.conv_in = nn.Conv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1, stride=2)
self.blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, out_channels in enumerate(block_out_channels):
input_channels = output_channels
output_channels = out_channels
is_final_block = i == len(block_out_channels) - 1
self.blocks.append(
DiscriminatorBlock3D(
in_channels=input_channels,
out_channels=output_channels,
output_scale_factor=math.sqrt(2),
add_downsample=not is_final_block,
)
)
self.conv_norm_out = nn.GroupNorm(32, block_out_channels[-1])
self.conv_act = nn.LeakyReLU(0.2)
self.conv_out = nn.Conv3d(block_out_channels[-1], 1, kernel_size=3, padding=1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
import torch
import torch.nn as nn
from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .downsamplers import (SpatialDownsampler3D, SpatialTemporalDownsampler3D,
TemporalDownsampler3D)
from .gc_block import GlobalContextBlock
def get_down_block(
down_block_type: str,
in_channels: int,
out_channels: int,
num_layers: int,
act_fn: str,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
) -> nn.Module:
if down_block_type == "DownBlock3D":
return DownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
)
elif down_block_type == "SpatialDownBlock3D":
return SpatialDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "SpatialAttnDownBlock3D":
return SpatialAttnDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
attention_head_dim=out_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "TemporalDownBlock3D":
return TemporalDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "TemporalAttnDownBlock3D":
return TemporalAttnDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
attention_head_dim=out_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "SpatialTemporalDownBlock3D":
return SpatialTemporalDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
else:
raise ValueError(f"Unknown down block type: {down_block_type}")
class DownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
return x
class SpatialDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
self.spatial_downsample_factor = 2
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x
class TemporalDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
self.temporal_downsample_factor = 2
else:
self.downsampler = None
self.temporal_downsample_factor = 1
self.spatial_downsample_factor = 1
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x
class SpatialTemporalDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = SpatialTemporalDownsampler3D(out_channels, out_channels)
self.spatial_downsample_factor = 2
self.temporal_downsample_factor = 2
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x
class SpatialAttnDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
self.attentions = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
self.attentions.append(
SpatialAttention(
out_channels,
nheads=out_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
self.spatial_downsample_factor = 2
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv, attn in zip(self.convs, self.attentions):
x = conv(x)
x = attn(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x
class TemporalAttnDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
self.attentions = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
self.attentions.append(
TemporalAttention(
out_channels,
nheads=out_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
self.temporal_downsample_factor = 2
else:
self.downsampler = None
self.temporal_downsample_factor = 1
self.spatial_downsample_factor = 1
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv, attn in zip(self.convs, self.attentions):
x = conv(x)
x = attn(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import CausalConv3d
class Downsampler(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
spatial_downsample_factor: int = 1,
temporal_downsample_factor: int = 1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.spatial_downsample_factor = spatial_downsample_factor
self.temporal_downsample_factor = temporal_downsample_factor
class SpatialDownsampler3D(Downsampler):
def __init__(self, in_channels: int, out_channels):
if out_channels is None:
out_channels = in_channels
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
spatial_downsample_factor=2,
temporal_downsample_factor=1,
)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(1, 2, 2),
padding=0,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, (0, 1, 0, 1))
return self.conv(x)
class TemporalDownsampler3D(Downsampler):
def __init__(self, in_channels: int, out_channels):
if out_channels is None:
out_channels = in_channels
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
spatial_downsample_factor=1,
temporal_downsample_factor=2,
)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 1, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class SpatialTemporalDownsampler3D(Downsampler):
def __init__(self, in_channels: int, out_channels):
if out_channels is None:
out_channels = in_channels
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
spatial_downsample_factor=2,
temporal_downsample_factor=2,
)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=0,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, (0, 1, 0, 1))
return self.conv(x)
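# Shape-check sketch (hypothetical helper): stride (2, 2, 2) plus causal
# temporal padding maps T frames to ceil(T / 2) while halving H and W.
def _demo_spatial_temporal_downsampler():
    down = SpatialTemporalDownsampler3D(32, 32)
    x = torch.randn(1, 32, 8, 16, 16)
    assert down(x).shape == (1, 32, 4, 8, 8)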
class BlurPooling2D(Downsampler):
def __init__(self, in_channels: int, out_channels):
if out_channels is None:
out_channels = in_channels
assert in_channels == out_channels
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
spatial_downsample_factor=2,
temporal_downsample_factor=1,
)
filt = torch.tensor([1, 2, 1], dtype=torch.float32)
filt = torch.einsum("i,j -> ij", filt, filt)
filt = filt / filt.sum()
filt = filt[None, None].repeat(out_channels, 1, 1, 1)
self.register_buffer("filt", filt)
self.filt: torch.Tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, C, H, W)
return F.conv2d(x, self.filt, stride=2, padding=1, groups=self.in_channels)
class BlurPooling3D(Downsampler):
def __init__(self, in_channels: int, out_channels):
if out_channels is None:
out_channels = in_channels
assert in_channels == out_channels
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
spatial_downsample_factor=2,
temporal_downsample_factor=2,
)
filt = torch.tensor([1, 2, 1], dtype=torch.float32)
filt = torch.einsum("i,j,k -> ijk", filt, filt, filt)
filt = filt / filt.sum()
filt = filt[None, None].repeat(out_channels, 1, 1, 1, 1)
self.register_buffer("filt", filt)
self.filt: torch.Tensor
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, C, T, H, W)
return F.conv3d(x, self.filt, stride=2, padding=1, groups=self.in_channels)
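# Shape-check sketch (hypothetical helper): BlurPooling3D applies a fixed
# [1, 2, 1]^3 binomial kernel per channel with stride 2, halving T, H and W
# with anti-aliasing.
def _demo_blur_pooling3d():
    pool = BlurPooling3D(16, 16)
    x = torch.randn(1, 16, 8, 16, 16)
    assert pool(x).shape == (1, 16, 4, 8, 8)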
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class GlobalContextBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
min_channels: int = 16,
init_bias: float = -10.,
fusion_type: str = "mul",
):
super().__init__()
assert fusion_type in ("mul", "add"), f"Unsupported fusion type: {fusion_type}"
self.fusion_type = fusion_type
self.conv_ctx = nn.Conv2d(in_channels, 1, kernel_size=1)
num_channels = max(min_channels, out_channels // 2)
if fusion_type == "mul":
self.conv_mul = nn.Sequential(
nn.Conv2d(in_channels, num_channels, kernel_size=1),
nn.LayerNorm([num_channels, 1, 1]), # TODO: LayerNorm or GroupNorm?
nn.LeakyReLU(0.1),
nn.Conv2d(num_channels, out_channels, kernel_size=1),
nn.Sigmoid(),
)
nn.init.zeros_(self.conv_mul[-2].weight)
nn.init.constant_(self.conv_mul[-2].bias, init_bias)
else:
self.conv_add = nn.Sequential(
nn.Conv2d(in_channels, num_channels, kernel_size=1),
nn.LayerNorm([num_channels, 1, 1]), # TODO: LayerNorm or GroupNorm?
nn.LeakyReLU(0.1),
nn.Conv2d(num_channels, out_channels, kernel_size=1),
)
nn.init.zeros_(self.conv_add[-1].weight)
nn.init.constant_(self.conv_add[-1].bias, init_bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
is_image = x.ndim == 4
if is_image:
x = rearrange(x, "b c h w -> b c 1 h w")
# x: (B, C, T, H, W)
orig_x = x
batch_size = x.shape[0]
x = rearrange(x, "b c t h w -> (b t) c h w")
ctx = self.conv_ctx(x)
ctx = rearrange(ctx, "b c h w -> b c (h w)")
ctx = F.softmax(ctx, dim=-1)
flattened_x = rearrange(x, "b c h w -> b c (h w)")
x = torch.einsum("b c1 n, b c2 n -> b c2 c1", ctx, flattened_x)
x = rearrange(x, "... -> ... 1")
if self.fusion_type == "mul":
mul_term = self.conv_mul(x)
mul_term = rearrange(mul_term, "(b t) c h w -> b c t h w", b=batch_size)
x = orig_x * mul_term
else:
add_term = self.conv_add(x)
add_term = rearrange(add_term, "(b t) c h w -> b c t h w", b=batch_size)
x = orig_x + add_term
if is_image:
x = rearrange(x, "b c 1 h w -> b c h w")
return x
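# Usage sketch (hypothetical helper): the block pools a global context vector
# with a softmax over spatial positions, then gates the input with a sigmoid
# ("mul" fusion); 4D inputs are treated as single-frame videos.
def _demo_global_context_block():
    gc_block = GlobalContextBlock(in_channels=32, out_channels=32, fusion_type="mul")
    x = torch.randn(1, 32, 4, 8, 8)  # (B, C, T, H, W)
    assert gc_block(x).shape == x.shape
    img = torch.randn(1, 32, 8, 8)  # (B, C, H, W)
    assert gc_block(img).shape == img.shape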
import torch
import torch.nn as nn
from .attention import Attention3D, SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
def get_mid_block(
mid_block_type: str,
in_channels: int,
num_layers: int,
act_fn: str,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
add_attention: bool = True,
attention_type: str = "3d",
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
) -> nn.Module:
if mid_block_type == "MidBlock3D":
return MidBlock3D(
in_channels=in_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
add_attention=add_attention,
attention_type=attention_type,
attention_head_dim=in_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
)
else:
raise ValueError(f"Unknown mid block type: {mid_block_type}")
class MidBlock3D(nn.Module):
"""
A 3D UNet mid-block [`MidBlock3D`] with multiple residual blocks and optional attention blocks.
Args:
in_channels (`int`): The number of input channels.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
        norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon for the normalization layers of the resnet blocks.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_type (`str`, *optional*, defaults to `3d`): The type of attention to use.
attention_head_dim (`int`, *optional*, defaults to 1):
Dimension of a single attention head. The number of attention heads is determined based on this value and
the number of input channels.
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, temporal_length, height, width)`.
"""
def __init__(
self,
in_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
add_attention: bool = True,
attention_type: str = "3d",
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
self.attention_type = attention_type
norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)
self.convs = nn.ModuleList([
ResidualBlock3D(
in_channels=in_channels,
out_channels=in_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
])
self.attentions = nn.ModuleList([])
for _ in range(num_layers - 1):
if add_attention:
if attention_type == "3d":
self.attentions.append(
Attention3D(
in_channels,
nheads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
elif attention_type == "spatial_temporal":
self.attentions.append(
nn.ModuleList([
SpatialAttention(
in_channels,
nheads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
),
TemporalAttention(
in_channels,
nheads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
),
])
)
elif attention_type == "spatial":
self.attentions.append(
SpatialAttention(
in_channels,
nheads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
elif attention_type == "temporal":
self.attentions.append(
TemporalAttention(
in_channels,
nheads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
else:
raise ValueError(f"Unknown attention type: {attention_type}")
else:
self.attentions.append(None)
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=in_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.convs[0](hidden_states)
for attn, resnet in zip(self.attentions, self.convs[1:]):
if attn is not None:
if self.attention_type == "spatial_temporal":
spatial_attn, temporal_attn = attn
hidden_states = spatial_attn(hidden_states)
hidden_states = temporal_attn(hidden_states)
else:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
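# Usage sketch (hypothetical helper, assuming the attention modules preserve
# shapes): MidBlock3D interleaves residual blocks with attention and keeps the
# input shape end to end.
def _demo_mid_block3d():
    mid = get_mid_block(
        "MidBlock3D",
        in_channels=64,
        num_layers=2,
        act_fn="silu",
        attention_type="spatial_temporal",
    )
    x = torch.randn(1, 64, 4, 8, 8)
    assert mid(x).shape == x.shape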
import torch
import torch.nn as nn
from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .gc_block import GlobalContextBlock
from .upsamplers import (SpatialTemporalUpsampler3D, SpatialUpsampler3D,
TemporalUpsampler3D)
def get_up_block(
up_block_type: str,
in_channels: int,
out_channels: int,
num_layers: int,
act_fn: str,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_upsample: bool = True,
) -> nn.Module:
if up_block_type == "SpatialUpBlock3D":
return SpatialUpBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_upsample=add_upsample,
)
elif up_block_type == "SpatialAttnUpBlock3D":
return SpatialAttnUpBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
attention_head_dim=out_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_upsample=add_upsample,
)
elif up_block_type == "TemporalUpBlock3D":
return TemporalUpBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_upsample=add_upsample,
)
elif up_block_type == "TemporalAttnUpBlock3D":
return TemporalAttnUpBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
attention_head_dim=out_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_upsample=add_upsample,
)
elif up_block_type == "SpatialTemporalUpBlock3D":
return SpatialTemporalUpBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_upsample=add_upsample,
)
else:
raise ValueError(f"Unknown up block type: {up_block_type}")
class SpatialUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_upsample: bool = True,
):
super().__init__()
if add_upsample:
self.upsampler = SpatialUpsampler3D(in_channels, in_channels)
else:
self.upsampler = None
if add_gc_block:
self.gc_block = GlobalContextBlock(in_channels, in_channels, fusion_type="mul")
else:
self.gc_block = None
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.upsampler is not None:
x = self.upsampler(x)
return x
class SpatialAttnUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_upsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
self.attentions = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
self.attentions.append(
SpatialAttention(
out_channels,
nheads=out_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_upsample:
self.upsampler = SpatialUpsampler3D(out_channels, out_channels)
else:
self.upsampler = None
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv, attn in zip(self.convs, self.attentions):
x = conv(x)
x = attn(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.upsampler is not None:
x = self.upsampler(x)
return x
class TemporalUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_upsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_upsample:
self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
else:
self.upsampler = None
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.upsampler is not None:
x = self.upsampler(x)
return x
class TemporalAttnUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_upsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
self.attentions = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
self.attentions.append(
TemporalAttention(
out_channels,
nheads=out_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_upsample:
self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
else:
self.upsampler = None
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv, attn in zip(self.convs, self.attentions):
x = conv(x)
x = attn(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.upsampler is not None:
x = self.upsampler(x)
return x
class SpatialTemporalUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_upsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_upsample:
self.upsampler = SpatialTemporalUpsampler3D(out_channels, out_channels)
else:
self.upsampler = None
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.upsampler is not None:
x = self.upsampler(x)
return x
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .common import CausalConv3d
class Upsampler(nn.Module):
def __init__(
self,
spatial_upsample_factor: int = 1,
temporal_upsample_factor: int = 1,
):
super().__init__()
self.spatial_upsample_factor = spatial_upsample_factor
self.temporal_upsample_factor = temporal_upsample_factor
class SpatialUpsampler3D(Upsampler):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(spatial_upsample_factor=2)
if out_channels is None:
out_channels = in_channels
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
x = self.conv(x)
return x
class SpatialUpsamplerD2S3D(Upsampler):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(spatial_upsample_factor=2)
if out_channels is None:
out_channels = in_channels
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=3,
)
o, i, t, h, w = self.conv.weight.shape
conv_weight = torch.empty(o // 4, i, t, h, w)
nn.init.kaiming_normal_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
self.conv.weight.data.copy_(conv_weight)
nn.init.zeros_(self.conv.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = rearrange(x, "b (c p1 p2) t h w -> b c t (h p1) (w p2)", p1=2, p2=2)
return x
class TemporalUpsampler3D(Upsampler):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(
spatial_upsample_factor=1,
temporal_upsample_factor=2,
)
if out_channels is None:
out_channels = in_channels
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.shape[2] > 1:
first_frame, x = x[:, :, :1], x[:, :, 1:]
x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
x = torch.cat([first_frame, x], dim=2)
x = self.conv(x)
return x
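# Frame-count sketch (hypothetical helper): the first frame is kept as-is and
# the remaining T - 1 frames are doubled, giving 2 * T - 1 output frames.
def _demo_temporal_upsampler3d():
    up = TemporalUpsampler3D(16, 16)
    x = torch.randn(1, 16, 5, 8, 8)
    assert up(x).shape == (1, 16, 9, 8, 8)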
class TemporalUpsamplerD2S3D(Upsampler):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(
spatial_upsample_factor=1,
temporal_upsample_factor=2,
)
if out_channels is None:
out_channels = in_channels
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels * 2,
kernel_size=3,
)
o, i, t, h, w = self.conv.weight.shape
conv_weight = torch.empty(o // 2, i, t, h, w)
nn.init.kaiming_normal_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
self.conv.weight.data.copy_(conv_weight)
nn.init.zeros_(self.conv.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = rearrange(x, "b (c p1) t h w -> b c (t p1) h w", p1=2)
x = x[:, :, 1:]
return x
class SpatialTemporalUpsampler3D(Upsampler):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(
spatial_upsample_factor=2,
temporal_upsample_factor=2,
)
if out_channels is None:
out_channels = in_channels
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
)
self.padding_flag = 0
self.set_3dgroupnorm = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
x = self.conv(x)
if self.padding_flag == 0:
if x.shape[2] > 1:
first_frame, x = x[:, :, :1], x[:, :, 1:]
x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest")
x = torch.cat([first_frame, x], dim=2)
elif self.padding_flag == 2 or self.padding_flag == 5 or self.padding_flag == 6:
x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest")
return x
class SpatialTemporalUpsamplerD2S3D(Upsampler):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(
spatial_upsample_factor=2,
temporal_upsample_factor=2,
)
if out_channels is None:
out_channels = in_channels
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels * 8,
kernel_size=3,
)
o, i, t, h, w = self.conv.weight.shape
conv_weight = torch.empty(o // 8, i, t, h, w)
nn.init.kaiming_normal_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 8) ...")
self.conv.weight.data.copy_(conv_weight)
nn.init.zeros_(self.conv.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = rearrange(x, "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", p1=2, p2=2, p3=2)
x = x[:, :, 1:]
return x
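# Depth-to-space sketch (hypothetical helper): the conv expands channels 8x,
# rearrange turns them into 2x factors along T, H and W, and the first of the
# doubled frames is dropped to keep the causal frame count at 2 * T - 1.
def _demo_spatial_temporal_upsampler_d2s():
    up = SpatialTemporalUpsamplerD2S3D(16, 16)
    x = torch.randn(1, 16, 4, 8, 8)
    assert up(x).shape == (1, 16, 7, 16, 16)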
import importlib
import multiprocessing as mp
from collections import abc
from functools import partial
from inspect import isfunction
from queue import Queue
from threading import Thread
import numpy as np
import torch
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
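# Usage sketch (hypothetical config): `target` is a dotted import path and
# `params` holds keyword arguments for the constructor.
def _demo_instantiate_from_config():
    config = {"target": "collections.OrderedDict", "params": {}}
    obj = instantiate_from_config(config)
    assert type(obj).__name__ == "OrderedDict"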
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
# create dummy dataset instance
# run prefetching
if idx_to_fn:
res = func(data, worker_id=idx)
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
            print(
                'WARNING: "data" argument passed to parallel_data_prefetch is a dict: using only its values and disregarding keys.'
            )
data = list(data.values())
if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
Q = mp.Queue(1000)
proc = mp.Process
else:
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
]
else:
step = (
int(len(data) / n_proc + 1)
if len(data) % n_proc != 0
else int(len(data) / n_proc)
)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i: i + step] for i in range(0, len(data), step)]
)
]
processes = []
for i in range(n_proc):
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
processes += [p]
# start processes
print(f"Start prefetching...")
import time
start = time.time()
gather_res = [[] for _ in range(n_proc)]
try:
for p in processes:
p.start()
k = 0
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
for p in processes:
p.terminate()
raise e
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
elif target_data_type == 'list':
out = []
for r in gather_res:
out.extend(r)
return out
else:
return gather_res
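# Usage sketch (hypothetical helper): with cpu_intensive=False the workers are
# threads, so the mapped function does not need to be importable at module
# level; results are gathered in worker order.
def _demo_parallel_data_prefetch():
    def square_all(chunk):
        return [v * v for v in chunk]
    out = parallel_data_prefetch(
        square_all, list(range(10)), n_proc=2, target_data_type="list", cpu_intensive=False
    )
    assert sorted(out) == [v * v for v in range(10)]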
from setuptools import find_packages, setup
setup(
name='latent-diffusion',
version='0.0.1',
description='',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)
from huggingface_hub import hf_hub_download
ckpt = hf_hub_download(
repo_id="TencentARC/NVComposer", filename="NVComposer-V0.1.ckpt", repo_type="model", local_dir="./models"
)
#!/bin/bash
cd /root/Ruyi-Models
python app.py
{
"cells": [
{
"cell_type": "markdown",
"id": "e5c5a211-2ccd-4341-af10-ac546484b91f",
"metadata": {
"tags": []
},
"source": [
"## 项目介绍\n",
"- 原项目地址:https://huggingface.co/IamCreateAI/Ruyi-Mini-7B\n",
"- Ruyi-Mini-7B是一种开源图像转视频生成模型。从输入图像开始,Ruyi生成分辨率从360p到720p的后续视频帧,支持各种宽高比,最长持续时间为5秒。通过运动和摄像头控制增强,Ruyi在视频生成方面提供了更大的灵活性和创造力。\n",
"- 项目在L20显卡,cuda12.2上进行适配\n",
"## 使用说明\n",
"- 启动和重启 Notebook 点上方工具栏中的「重启并运行所有单元格」。出现如下内容就算成功了:\n",
" - `Running on local URL: http://0.0.0.0:7860`\n",
" - `Running on public URL: https://xxxxxxxxxxxxxxx.gradio.live`\n",
"- 通过以下方式开启页面:\n",
" - 控制台打开「自定义服务」了,访问自定义服务端口号设置为7860\n",
" - 直接打开显示的公开链接`public URL`\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53a96614-e2d2-4710-a82b-0d5ca9cb9872",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# 启动\n",
"!sh start.sh"
]
},
{
"cell_type": "markdown",
"source": [
"---\n",
"**扫码关注公众号,获取更多资讯**<br>\n",
"<div align=center>\n",
"<img src=\"assets/二维码.jpeg\" width = 20% />\n",
"</div>\n"
],
"metadata": {
"collapsed": false
},
"id": "2f54158c2967bc25"
},
{
"cell_type": "code",
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "6dc59fbbcf222b6b"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}