Commit 5c023842 authored by chenpangpang

feat: add LatentSync

parent 822b66ca
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
import cv2
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..utils.audio import melspectrogram
from decord import AudioReader, VideoReader, cpu
class UNetDataset(Dataset):
def __init__(self, train_data_dir: str, config):
if config.data.train_fileslist != "":
with open(config.data.train_fileslist) as file:
self.video_paths = [line.rstrip() for line in file]
elif train_data_dir != "":
self.video_paths = []
for file in os.listdir(train_data_dir):
if file.endswith(".mp4"):
self.video_paths.append(os.path.join(train_data_dir, file))
else:
raise ValueError("data_dir and fileslist cannot be both empty")
self.resolution = config.data.resolution
self.num_frames = config.data.num_frames
if self.num_frames == 16:
self.mel_window_length = 52
elif self.num_frames == 5:
self.mel_window_length = 16
else:
raise NotImplementedError("Only support 16 and 5 frames now")
self.audio_sample_rate = config.data.audio_sample_rate
self.video_fps = config.data.video_fps
self.mask = config.data.mask
self.mask_image = load_fixed_mask(self.resolution)
self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
def __len__(self):
return len(self.video_paths)
def read_audio(self, video_path: str):
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
return torch.from_numpy(original_mel)
def crop_audio_window(self, original_mel, start_index):
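        # The mel spectrogram has ~80 frames per second, so a video frame index is
        # converted to a mel frame index as 80 * frame_idx / video_fps.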
start_idx = int(80.0 * (start_index / float(self.video_fps)))
end_idx = start_idx + self.mel_window_length
return original_mel[:, start_idx:end_idx].unsqueeze(0)
def get_frames(self, video_reader: VideoReader):
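        # Sample a window of consecutive ground-truth frames, plus a non-overlapping
        # "wrong" window that is used as the reference frames.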
total_num_frames = len(video_reader)
start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
while True:
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
continue
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
break
frames = video_reader.get_batch(frames_index).asnumpy()
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
return frames, wrong_frames, start_idx
def worker_init_fn(self, worker_id):
# Initialize the face mesh object in each worker process,
# because the face mesh object cannot be called in subprocesses
self.worker_id = worker_id
setattr(
self,
f"image_processor_{worker_id}",
ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
)
def __getitem__(self, idx):
image_processor = getattr(self, f"image_processor_{self.worker_id}")
while True:
try:
idx = random.randint(0, len(self) - 1)
# Get video file path
video_path = self.video_paths[idx]
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
if len(vr) < 3 * self.num_frames:
continue
continuous_frames, ref_frames, start_idx = self.get_frames(vr)
if self.load_audio_data:
mel_cache_path = os.path.join(
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
)
if os.path.isfile(mel_cache_path):
try:
original_mel = torch.load(mel_cache_path)
except Exception as e:
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
os.remove(mel_cache_path)
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
else:
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
mel = self.crop_audio_window(original_mel, start_idx)
if mel.shape[-1] != self.mel_window_length:
continue
else:
mel = []
gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
continuous_frames, affine_transform=False
)
if self.mask == "fix_mask":
ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
else:
ref = image_processor.process_images(ref_frames)
vr.seek(0) # avoid memory leak
break
            except Exception as e:  # Handle exceptions such as face not being detected
print(f"{type(e).__name__} - {e} - {video_path}")
if "vr" in locals():
vr.seek(0) # avoid memory leak
sample = dict(
gt=gt,
masked_gt=masked_gt,
ref=ref,
mel=mel,
mask=mask,
video_path=video_path,
start_idx=start_idx,
)
return sample
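# --- Illustrative usage (not part of the original commit) ---
# A minimal sketch of how this dataset could be wired to a DataLoader. The config below
# is a hypothetical SimpleNamespace stand-in; in the real project the config is loaded
# from a YAML file, so the field names here only mirror the attributes accessed above.
# `worker_init_fn` must be passed so that each worker creates its own ImageProcessor
# and sets `self.worker_id` before `__getitem__` runs.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    config = SimpleNamespace(
        data=SimpleNamespace(
            train_fileslist="",
            resolution=256,
            num_frames=16,
            audio_sample_rate=16000,
            video_fps=25,
            mask="fix_mask",
            audio_mel_cache_dir="audio_mel_cache",
        ),
        model=SimpleNamespace(add_audio_layer=True),
        run=SimpleNamespace(use_syncnet=True),
    )
    dataset = UNetDataset(train_data_dir="data/train", config=config)  # hypothetical path
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=8,
        worker_init_fn=dataset.worker_init_fn,  # per-worker ImageProcessor setup
    )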
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
from einops import rearrange, repeat
from .utils import zero_module
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
use_motion_module: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
add_audio_layer=False,
audio_condition_method="cross_attn",
custom_audio_layer: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
if not custom_audio_layer:
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
custom_audio_layer=custom_audio_layer,
audio_condition_method=audio_condition_method,
)
for d in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
AudioTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
if custom_audio_layer:
self.proj_out = zero_module(self.proj_out)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# Input
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
# No need to do this for audio input, because different audio samples are independent
# encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
video_length=video_length,
)
# Output
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
use_motion_module: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
add_audio_layer=False,
custom_audio_layer=False,
audio_condition_method="cross_attn",
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
self.use_motion_module = use_motion_module
self.add_audio_layer = add_audio_layer
# SC-Attn
assert unet_use_cross_frame_attention is not None
if unet_use_cross_frame_attention:
raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
else:
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
self.audio_cross_attn = AudioCrossAttn(
dim=dim,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
num_embeds_ada_norm=num_embeds_ada_norm,
use_ada_layer_norm=self.use_ada_layer_norm,
zero_proj_out=False,
)
else:
self.audio_cross_attn = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
self.attn_temp = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
if self.audio_cross_attn is not None:
self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
use_memory_efficient_attention_xformers
)
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(
self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
# if self.only_cross_attention:
# hidden_states = (
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
# )
# else:
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
# pdb.set_trace()
if self.unet_use_cross_frame_attention:
hidden_states = (
self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
+ hidden_states
)
else:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.audio_cross_attn is not None and encoder_hidden_states is not None:
hidden_states = self.audio_cross_attn(
hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
class AudioTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
use_motion_module: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
add_audio_layer=False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
self.use_motion_module = use_motion_module
self.add_audio_layer = add_audio_layer
# SC-Attn
assert unet_use_cross_frame_attention is not None
if unet_use_cross_frame_attention:
raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
else:
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
self.audio_cross_attn = AudioCrossAttn(
dim=dim,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
num_embeds_ada_norm=num_embeds_ada_norm,
use_ada_layer_norm=self.use_ada_layer_norm,
zero_proj_out=False,
)
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
if self.audio_cross_attn is not None:
self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
use_memory_efficient_attention_xformers
)
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(
self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
# pdb.set_trace()
if self.unet_use_cross_frame_attention:
hidden_states = (
self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
+ hidden_states
)
else:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.audio_cross_attn is not None and encoder_hidden_states is not None:
hidden_states = self.audio_cross_attn(
hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class AudioCrossAttn(nn.Module):
def __init__(
self,
dim,
cross_attention_dim,
num_attention_heads,
attention_head_dim,
dropout,
attention_bias,
upcast_attention,
num_embeds_ada_norm,
use_ada_layer_norm,
zero_proj_out=False,
):
super().__init__()
self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
self.attn = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
if zero_proj_out:
self.proj_out = zero_module(nn.Linear(dim, dim))
self.zero_proj_out = zero_proj_out
self.use_ada_layer_norm = use_ada_layer_norm
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
previous_hidden_states = hidden_states
hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)
if encoder_hidden_states.dim() == 4:
encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")
hidden_states = self.attn(
hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
if self.zero_proj_out:
hidden_states = self.proj_out(hidden_states)
return hidden_states + previous_hidden_states
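# --- Shape-flow sketch (illustrative, not part of the original commit) ---
# Transformer3DModel.forward treats the frame axis as extra batch: a 5-D video latent
# (b, c, f, h, w) is flattened to ((b f), c, h, w), projected to spatial tokens of shape
# ((b f), h*w, inner_dim), run through the blocks, and folded back. The toy tensors
# below only demonstrate that round trip; the sizes are arbitrary.
if __name__ == "__main__":
    x = torch.randn(2, 4, 16, 32, 32)  # (b, c, f, h, w)
    video_length = x.shape[2]
    x_2d = rearrange(x, "b c f h w -> (b f) c h w")  # frames folded into the batch
    b, c, h, w = x_2d.shape
    tokens = x_2d.permute(0, 2, 3, 1).reshape(b, h * w, c)  # one token per spatial position
    x_2d_back = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)
    x_back = rearrange(x_2d_back, "(b f) c h w -> b c f h w", f=video_length)
    assert torch.equal(x, x_back)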
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
# Note: the motion module is not used in the final version of LatentSync.
# When we started the project, we built on the AnimateDiff codebase and tried the motion module,
# but the results were poor, so we decided to leave the code here for possible future use.
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward
from einops import rearrange, repeat
import math
from .utils import zero_module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(
in_channels=in_channels,
**motion_module_kwargs,
)
else:
        raise ValueError(f"Unknown motion_module_type: {motion_module_type}")
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads=8,
num_transformer_block=2,
attention_block_types=("Temporal_Self", "Temporal_Self"),
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
temporal_attention_dim_div=1,
zero_initialize=True,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = (
attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
)
+ hidden_states
)
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.0, max_len=24):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
class VersatileAttention(CrossAttention):
def __init__(
self,
attention_mode=None,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
self.pos_encoder = (
PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
if (temporal_position_encoding and attention_mode == "Temporal")
else None
)
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
if self.attention_mode == "Temporal":
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = (
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
if encoder_hidden_states is not None
else encoder_hidden_states
)
else:
raise NotImplementedError
# encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
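# --- Temporal grouping sketch (illustrative, not part of the original commit) ---
# VersatileAttention attends across frames rather than across spatial positions: tokens
# arranged as ((b f), d, c) are regrouped to ((b d), f, c) so that each spatial location
# sees the sequence of frames, optionally with the sinusoidal PositionalEncoding above.
# The sizes below are arbitrary toy values.
if __name__ == "__main__":
    b, f, d, c = 2, 16, 64, 320
    tokens = torch.randn(b * f, d, c)  # ((b f), d, c)
    temporal = rearrange(tokens, "(b f) d c -> (b d) f c", f=f)
    temporal = PositionalEncoding(c, max_len=f)(temporal)  # add frame-index encoding
    restored = rearrange(temporal, "(b d) f c -> (b f) d c", d=d)
    assert restored.shape == tokens.shape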
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class InflatedGroupNorm(nn.GroupNorm):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# if self.use_conv:
# if self.name == "conv":
# hidden_states = self.conv(hidden_states)
# else:
# hidden_states = self.Conv2d_0(hidden_states)
hidden_states = self.conv(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
        assert use_inflated_groupnorm is not None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
time_emb_proj_out_channels = out_channels
# if self.time_embedding_norm == "default":
# time_emb_proj_out_channels = out_channels
# elif self.time_embedding_norm == "scale_shift":
# time_emb_proj_out_channels = out_channels * 2
# else:
# raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
if self.time_embedding_norm == "scale_shift":
self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
else:
self.double_len_linear = None
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
if temb.dim() == 2:
# input (1, 1280)
temb = self.time_emb_proj(self.nonlinearity(temb))
temb = temb[:, :, None, None, None] # unsqueeze
else:
# input (1, 1280, 16)
temb = temb.permute(0, 2, 1)
temb = self.time_emb_proj(self.nonlinearity(temb))
if self.double_len_linear is not None:
temb = self.double_len_linear(self.nonlinearity(temb))
temb = temb.permute(0, 2, 1)
temb = temb[:, :, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
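# --- InflatedConv3d sketch (illustrative, not part of the original commit) ---
# The "inflated" layers above reuse plain 2D convolution / group-norm weights on video
# tensors by folding the frame axis into the batch, so pretrained 2D UNet weights can be
# loaded unchanged. The toy example below only checks the expected output shape.
if __name__ == "__main__":
    conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
    video = torch.randn(2, 4, 16, 32, 32)  # (b, c, f, h, w)
    out = conv(video)  # the 2D kernel is applied to each frame independently
    assert out.shape == (2, 8, 16, 32, 32)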
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
from ..utils.util import cosine_loss
class SyncNet(nn.Module):
def __init__(self, config):
super().__init__()
self.audio_encoder = DownEncoder2D(
in_channels=config["audio_encoder"]["in_channels"],
block_out_channels=config["audio_encoder"]["block_out_channels"],
downsample_factors=config["audio_encoder"]["downsample_factors"],
dropout=config["audio_encoder"]["dropout"],
attn_blocks=config["audio_encoder"]["attn_blocks"],
)
self.visual_encoder = DownEncoder2D(
in_channels=config["visual_encoder"]["in_channels"],
block_out_channels=config["visual_encoder"]["block_out_channels"],
downsample_factors=config["visual_encoder"]["downsample_factors"],
dropout=config["visual_encoder"]["dropout"],
attn_blocks=config["visual_encoder"]["attn_blocks"],
)
self.eval()
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
norm_num_groups: int = 32,
eps: float = 1e-6,
act_fn: str = "silu",
downsample_factor=2,
):
super().__init__()
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "silu":
self.act_fn = nn.SiLU()
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
if isinstance(downsample_factor, list):
downsample_factor = tuple(downsample_factor)
if downsample_factor == 1:
self.downsample_conv = None
else:
self.downsample_conv = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
)
self.pad = (0, 1, 0, 1)
if isinstance(downsample_factor, tuple):
if downsample_factor[0] == 1:
self.pad = (0, 1, 1, 1) # The padding order is from back to front
elif downsample_factor[1] == 1:
self.pad = (1, 1, 0, 1)
def forward(self, input_tensor):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
hidden_states += input_tensor
if self.downsample_conv is not None:
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
hidden_states = self.downsample_conv(hidden_states)
return hidden_states
class AttentionBlock2D(nn.Module):
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
super().__init__()
if not is_xformers_available():
raise ModuleNotFoundError(
"You have to install xformers to enable memory efficient attetion", name="xformers"
)
# inner_dim = dim_head * heads
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
self.norm2 = nn.LayerNorm(query_dim)
self.norm3 = nn.LayerNorm(query_dim)
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
self.attn._use_memory_efficient_attention_xformers = True
def forward(self, hidden_states):
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.conv_in(hidden_states)
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
norm_hidden_states = self.norm2(hidden_states)
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DownEncoder2D(nn.Module):
def __init__(
self,
in_channels=4 * 16,
block_out_channels=[64, 128, 256, 256],
downsample_factors=[2, 2, 2, 2],
layers_per_block=2,
norm_num_groups=32,
attn_blocks=[1, 1, 1, 1],
dropout: float = 0.0,
act_fn="silu",
):
super().__init__()
self.layers_per_block = layers_per_block
# in
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# down
self.down_blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, block_out_channel in enumerate(block_out_channels):
input_channels = output_channels
output_channels = block_out_channel
# is_final_block = i == len(block_out_channels) - 1
down_block = ResnetBlock2D(
in_channels=input_channels,
out_channels=output_channels,
downsample_factor=downsample_factors[i],
norm_num_groups=norm_num_groups,
dropout=dropout,
act_fn=act_fn,
)
self.down_blocks.append(down_block)
if attn_blocks[i] == 1:
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
self.down_blocks.append(attention_block)
# out
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.act_fn_out = nn.ReLU()
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
# down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.act_fn_out(hidden_states)
return hidden_states
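# --- Sync-score sketch (illustrative, not part of the original commit) ---
# SyncNet.forward returns L2-normalized vision and audio embeddings, so their dot product
# is a cosine similarity in [-1, 1]. The `cosine_loss` imported above is not shown in this
# diff; the snippet below is only a generic stand-in for the usual SyncNet-style objective
# (binary cross-entropy on the cosine similarity against an in-sync / out-of-sync label),
# not the project's actual implementation.
if __name__ == "__main__":
    vision_embeds = F.normalize(torch.randn(8, 512), p=2, dim=1)
    audio_embeds = F.normalize(torch.randn(8, 512), p=2, dim=1)
    labels = torch.randint(0, 2, (8,)).float()  # 1 = in sync, 0 = out of sync
    cos_sim = (vision_embeds * audio_embeds).sum(dim=1)  # cosine similarity per pair
    loss = F.binary_cross_entropy((cos_sim.clamp(-1.0, 1.0) + 1.0) / 2.0, labels)
    print(f"toy sync loss: {loss.item():.4f}")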
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.
from torch import nn
from torch.nn import functional as F
class SyncNetWav2Lip(nn.Module):
def __init__(self, act_fn="leaky"):
super().__init__()
# input image sequences: (15, 128, 256)
self.visual_encoder = nn.Sequential(
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
)
# input audio sequences: (1, 80, 16)
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
)
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "tanh":
self.act_fn = nn.Tanh()
elif act_fn == "silu":
self.act_fn = nn.SiLU()
elif act_fn == "leaky":
self.act_fn = nn.LeakyReLU(0.2, inplace=True)
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act_fn(out)
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import copy
import torch
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm
from ..utils.util import zero_rank_log
from einops import rearrange
from .utils import zero_module
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNet3DConditionOutput(BaseOutput):
sample: torch.FloatTensor
class UNet3DConditionModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
mid_block_type: str = "UNetMidBlock3DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
use_inflated_groupnorm=False,
# Additional
use_motion_module=False,
motion_module_resolutions=(1, 2, 4, 8),
motion_module_mid_block=False,
motion_module_decoder_only=False,
motion_module_type=None,
motion_module_kwargs={},
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
add_audio_layer=False,
audio_condition_method: str = "cross_attn",
custom_audio_layer=False,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
self.use_motion_module = use_motion_module
self.add_audio_layer = add_audio_layer
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
res = 2**i
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module
and (res in motion_module_resolutions)
and (not motion_module_decoder_only),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=custom_audio_layer,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == "UNetMidBlock3DCrossAttn":
self.mid_block = UNetMidBlock3DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=custom_audio_layer,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the videos
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
res = 2 ** (3 - i)
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and (res in motion_module_resolutions),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=custom_audio_layer,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if use_inflated_groupnorm:
self.conv_norm_out = InflatedGroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
else:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
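# conv_out is zero-initialized (see zero_module later in this commit), so the freshly built model
# starts by predicting an all-zero residual; a common trick for stabilizing training when new
# layers are added on top of pretrained 2D weights.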
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor into slices and computes attention
in several steps. This saves some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_slicable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)
num_slicable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any child that exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
# support controlnet
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
Returns:
[`UNet3DConditionOutput`] or `tuple`:
[`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
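# e.g. with 3 upsampling blocks the factor is 8, so inputs whose height/width are not divisible
# by 8 trigger the explicit `upsample_size` forwarding below.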
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
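# a binary mask (1 = attend, 0 = mask) is converted into an additive bias: masked positions
# receive a large negative value so they are suppressed by the attention softmax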
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# time
timesteps = timestep
if not torch.is_tensor(timesteps):
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# pre-process
sample = self.conv_in(sample)
# down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
)
down_block_res_samples += res_samples
# support controlnet
down_block_res_samples = list(down_block_res_samples)
if down_block_additional_residuals is not None:
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
if down_block_additional_residual.dim() == 4:  # add a frame dim so it broadcasts
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
# mid
sample = self.mid_block(
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
# support controlnet
if mid_block_additional_residual is not None:
if mid_block_additional_residual.dim() == 4:  # add a frame dim so it broadcasts
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
sample = sample + mid_block_additional_residual
# up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
encoder_hidden_states=encoder_hidden_states,
)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return UNet3DConditionOutput(sample=sample)
def load_state_dict(self, state_dict, strict=True):
# If the loaded checkpoint's in_channels or out_channels differ from the config, drop those conv layers so they are re-initialized
temp_state_dict = copy.deepcopy(state_dict)
if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
del temp_state_dict["conv_in.weight"]
del temp_state_dict["conv_in.bias"]
if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
del temp_state_dict["conv_out.weight"]
del temp_state_dict["conv_out.bias"]
# If the loaded checkpoint's cross_attention_dim differs from the config, also drop the audio cross-attention K/V projections
keys_to_remove = []
for key in temp_state_dict:
if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
keys_to_remove.append(key)
for key in keys_to_remove:
del temp_state_dict[key]
return super().load_state_dict(state_dict=temp_state_dict, strict=strict)
@classmethod
def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
unet = cls.from_config(model_config).to(device)
if ckpt_path != "":
zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
if "global_step" in ckpt:
zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
resume_global_step = ckpt["global_step"]
else:
resume_global_step = 0
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
unet.load_state_dict(state_dict, strict=False)
else:
resume_global_step = 0
return unet, resume_global_step
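# A minimal usage sketch (the config layout and checkpoint path below are hypothetical; the real
# values come from the project's config files):
#
#   model_config = {...}  # the UNet section of the training/inference config (hypothetical)
#   unet, global_step = UNet3DConditionModel.from_pretrained(
#       model_config, ckpt_path="checkpoints/unet.pt", device="cuda"  # hypothetical checkpoint path
#   )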
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
import torch
from torch import nn
from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module
def get_down_block(
down_block_type,
num_layers,
in_channels,
out_channels,
temb_channels,
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
audio_condition_method="cross_attn",
custom_audio_layer=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
elif down_block_type == "CrossAttnDownBlock3D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
return CrossAttnDownBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=custom_audio_layer,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type,
num_layers,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
audio_condition_method="cross_attn",
custom_audio_layer=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
return CrossAttnUpBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=custom_audio_layer,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock3DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
audio_condition_method="cross_attn",
custom_audio_layer: bool = False,
):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
]
attentions = []
audio_attentions = []
motion_modules = []
for _ in range(num_layers):
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
)
)
audio_attentions.append(
Transformer3DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=True,
)
if custom_audio_layer
else None
)
motion_modules.append(
get_motion_module(
in_channels=in_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
self.attentions = nn.ModuleList(attentions)
self.audio_attentions = nn.ModuleList(audio_attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, audio_attn, resnet, motion_module in zip(
self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
):
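# each mid-block layer runs the main attention block, then the optional dedicated audio
# cross-attention, then the optional temporal motion module, and finally a resnet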
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = (
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if audio_attn is not None
else hidden_states
)
hidden_states = (
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if motion_module is not None
else hidden_states
)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class CrossAttnDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
audio_condition_method="cross_attn",
custom_audio_layer: bool = False,
):
super().__init__()
resnets = []
attentions = []
audio_attentions = []
motion_modules = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
)
)
audio_attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=True,
)
if custom_audio_layer
else None
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.attentions = nn.ModuleList(attentions)
self.audio_attentions = nn.ModuleList(audio_attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample3D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
output_states = ()
for resnet, attn, audio_attn, motion_module in zip(
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
)[0]
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
encoder_hidden_states,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = (
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if audio_attn is not None
else hidden_states
)
# add motion module
hidden_states = (
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if motion_module is not None
else hidden_states
)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class DownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
):
super().__init__()
resnets = []
motion_modules = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample3D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, motion_module in zip(self.resnets, self.motion_modules):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
encoder_hidden_states,
)
else:
hidden_states = resnet(hidden_states, temb)
# add motion module
hidden_states = (
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if motion_module is not None
else hidden_states
)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class CrossAttnUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
audio_condition_method="cross_attn",
custom_audio_layer=False,
):
super().__init__()
resnets = []
attentions = []
audio_attentions = []
motion_modules = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
)
)
audio_attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
use_motion_module=use_motion_module,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
add_audio_layer=add_audio_layer,
audio_condition_method=audio_condition_method,
custom_audio_layer=True,
)
if custom_audio_layer
else None
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.attentions = nn.ModuleList(attentions)
self.audio_attentions = nn.ModuleList(audio_attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
):
for resnet, attn, audio_attn, motion_module in zip(
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
)[0]
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
encoder_hidden_states,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = (
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if audio_attn is not None
else hidden_states
)
# add motion module
hidden_states = (
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if motion_module is not None
else hidden_states
)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class UpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
):
super().__init__()
resnets = []
motion_modules = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
upsample_size=None,
encoder_hidden_states=None,
):
for resnet, motion_module in zip(self.resnets, self.motion_modules):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
encoder_hidden_states,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = (
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if motion_module is not None
else hidden_states
)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
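# For example, zero_module(nn.Conv2d(4, 4, 3)) returns a conv whose weight and bias are all zeros,
# so the wrapped layer initially contributes nothing to whatever residual path it is added to.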
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
import inspect
import os
import shutil
from typing import Callable, List, Optional, Union
import subprocess
import numpy as np
import torch
import torchvision
from diffusers.utils import is_accelerate_available
from packaging import version
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import deprecate, logging
from einops import rearrange
from ..models.unet import UNet3DConditionModel
from ..utils.image_processor import ImageProcessor
from ..utils.util import read_video, read_audio, write_video
from ..whisper.audio2feature import Audio2Feature
import tqdm
import soundfile as sf
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LipsyncPipeline(DiffusionPipeline):
_optional_components = []
def __init__(
self,
vae: AutoencoderKL,
audio_encoder: Audio2Feature,
unet: UNet3DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
audio_encoder=audio_encoder,
unet=unet,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.set_progress_bar_config(desc="Steps")
def enable_vae_slicing(self):
self.vae.enable_slicing()
def disable_vae_slicing(self):
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.vae]:  # this pipeline registers no text encoder
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def decode_latents(self, latents):
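# undo the scaling/shift applied when the latents were encoded, then decode each frame
# independently with the 2D VAE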
latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
latents = rearrange(latents, "b c f h w -> (b f) c h w")
decoded_latents = self.vae.decode(latents).sample
return decoded_latents
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, height, width, callback_steps):
assert height == width, "Height and width must be equal"
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_frames, num_channels_latents, height, width, dtype, device, generator):
shape = (
batch_size,
num_channels_latents,
1,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
rand_device = "cpu" if device.type == "mps" else device
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = latents.repeat(1, 1, num_frames, 1, 1)
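# the same noise map is repeated across the frame dimension, so every frame of a clip starts
# from identical noise (intended to keep the denoised frames temporally consistent)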
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_mask_latents(
self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
masked_image = masked_image.to(device=device, dtype=dtype)
# encode the mask image into latents space so we can concatenate it to the latents
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# aligning device to prevent device errors when concatenating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
mask = mask.to(device=device, dtype=dtype)
# assume batch size = 1
mask = rearrange(mask, "f c h w -> 1 c f h w")
masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
return mask, masked_image_latents
def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
images = images.to(device=device, dtype=dtype)
image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
return image_latents
def set_progress_bar_config(self, **kwargs):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
self._progress_bar_config.update(kwargs)
@staticmethod
def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
# Paste the surrounding pixels back, because we only want to change the mouth region
pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
masks = masks.to(device=device, dtype=weight_dtype)
combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
return combined_pixel_values
@staticmethod
def pixel_values_to_images(pixel_values: torch.Tensor):
pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
images = (pixel_values * 255).to(torch.uint8)
images = images.cpu().numpy()
return images
def affine_transform_video(self, video_path):
video_frames = read_video(video_path, use_decord=False)
faces = []
boxes = []
affine_matrices = []
print(f"Affine transforming {len(video_frames)} faces...")
for frame in tqdm.tqdm(video_frames):
face, box, affine_matrix = self.image_processor.affine_transform(frame)
faces.append(face)
boxes.append(box)
affine_matrices.append(affine_matrix)
faces = torch.stack(faces)
return faces, video_frames, boxes, affine_matrices
def restore_video(self, faces, video_frames, boxes, affine_matrices):
video_frames = video_frames[: faces.shape[0]]
out_frames = []
for index, face in enumerate(faces):
x1, y1, x2, y2 = boxes[index]
height = int(y2 - y1)
width = int(x2 - x1)
face = torchvision.transforms.functional.resize(face, size=(height, width), antialias=True)
face = rearrange(face, "c h w -> h w c")
face = (face / 2 + 0.5).clamp(0, 1)
face = (face * 255).to(torch.uint8).cpu().numpy()
out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
out_frames.append(out_frame)
return np.stack(out_frames, axis=0)
@torch.no_grad()
def __call__(
self,
video_path: str,
audio_path: str,
video_out_path: str,
video_mask_path: str = None,
num_frames: int = 16,
video_fps: int = 25,
audio_sample_rate: int = 16000,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 20,
guidance_scale: float = 1.5,
weight_dtype: Optional[torch.dtype] = torch.float16,
eta: float = 0.0,
mask: str = "fix_mask",
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
is_train = self.unet.training
self.unet.eval()
# 0. Define call parameters
batch_size = 1
device = self._execution_device
self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
audio_samples = read_audio(audio_path)
# 1. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 2. Check inputs
self.check_inputs(height, width, callback_steps)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 4. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
self.video_fps = video_fps
if self.unet.add_audio_layer:
whisper_feature = self.audio_encoder.audio2feat(audio_path)
whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
num_inferences = min(len(video_frames), len(whisper_chunks)) // num_frames
else:
num_inferences = len(video_frames) // num_frames
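# the video is processed in chunks of num_frames; trailing frames that do not fill a complete
# chunk are dropped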
synced_video_frames = []
masked_video_frames = []
num_channels_latents = self.vae.config.latent_channels
# Prepare latent variables
all_latents = self.prepare_latents(
batch_size,
num_frames * num_inferences,
num_channels_latents,
height,
width,
weight_dtype,
device,
generator,
)
for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
if self.unet.add_audio_layer:
audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
if do_classifier_free_guidance:
empty_audio_embeds = torch.zeros_like(audio_embeds)
audio_embeds = torch.cat([empty_audio_embeds, audio_embeds])
else:
audio_embeds = None
inference_video_frames = video_frames[i * num_frames : (i + 1) * num_frames]
latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
inference_video_frames, affine_transform=False
)
# 7. Prepare mask latent variables
mask_latents, masked_image_latents = self.prepare_mask_latents(
masks,
masked_pixel_values,
height,
width,
weight_dtype,
device,
generator,
do_classifier_free_guidance,
)
# 8. Prepare image latents
image_latents = self.prepare_image_latents(
pixel_values,
device,
weight_dtype,
generator,
do_classifier_free_guidance,
)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for j, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat(
[latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and j % callback_steps == 0:
callback(j, t, latents)
# Recover the pixel values
decoded_latents = self.decode_latents(latents)
decoded_latents = self.paste_surrounding_pixels_back(
decoded_latents, pixel_values, 1 - masks, device, weight_dtype
)
synced_video_frames.append(decoded_latents)
masked_video_frames.append(masked_pixel_values)
synced_video_frames = self.restore_video(
torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
)
masked_video_frames = self.restore_video(
torch.cat(masked_video_frames), original_video_frames, boxes, affine_matrices
)
audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
if is_train:
self.unet.train()
temp_dir = "temp"
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
os.makedirs(temp_dir, exist_ok=True)
write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)
# write_video(video_mask_path, masked_video_frames, fps=25)
sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
command = f"ffmpeg -y -loglevel error -nostdin -i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} -c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
subprocess.run(command, shell=True)
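# A minimal inference sketch (checkpoints and file paths below are hypothetical; it only assumes the
# four components registered above have been loaded with compatible configs):
#
#   pipeline = LipsyncPipeline(vae=vae, audio_encoder=audio_encoder, unet=unet, scheduler=scheduler)
#   pipeline.to("cuda")
#   pipeline(
#       video_path="assets/demo.mp4",       # hypothetical input video
#       audio_path="assets/demo.wav",       # hypothetical driving audio
#       video_out_path="results/out.mp4",
#       num_frames=16,
#       height=256, width=256,
#       num_inference_steps=20,
#       guidance_scale=1.5,
#       weight_dtype=torch.float16,
#   )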
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from einops import rearrange
from .third_party.VideoMAEv2.utils import load_videomae_model
class TREPALoss:
def __init__(
self,
device="cuda",
ckpt_path="/mnt/bn/maliva-gen-ai-v2/chunyu.li/checkpoints/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
):
self.model = load_videomae_model(device, ckpt_path).eval().to(dtype=torch.float16)
self.model.requires_grad_(False)
self.bce_loss = nn.BCELoss()
def __call__(self, videos_fake, videos_real, loss_type="mse"):
batch_size = videos_fake.shape[0]
num_frames = videos_fake.shape[2]
videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bilinear")
videos_real = F.interpolate(videos_real, size=(224, 224), mode="bilinear")
videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
# The input pixel range is [-1, 1], but the model expects pixels in [0, 1]
videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
feats_fake = self.model.forward_features(videos_fake)
feats_real = self.model.forward_features(videos_real)
feats_fake = F.normalize(feats_fake, p=2, dim=1)
feats_real = F.normalize(feats_real, p=2, dim=1)
return F.mse_loss(feats_fake, feats_real)
if __name__ == "__main__":
# input shape: (b, c, f, h, w)
videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
trepa_loss = TREPALoss(device="cuda")
loss = trepa_loss(videos_fake, videos_real)
print(loss)
import os
import torch
import requests
from tqdm import tqdm
from torchvision import transforms
from .videomaev2_finetune import vit_giant_patch14_224
def to_normalized_float_tensor(vid):
return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
# NOTE: these functions generally expect mini-batches; here we keep them
# non-minibatched so that they are applied as if the input were 4D (i.e. an image).
# This way, the transformation is only applied in the spatial domain.
def resize(vid, size, interpolation='bilinear'):
# NOTE: using bilinear interpolation because we don't work on minibatches
# at this level
scale = None
if isinstance(size, int):
scale = float(size) / min(vid.shape[-2:])
size = None
return torch.nn.functional.interpolate(
vid,
size=size,
scale_factor=scale,
mode=interpolation,
align_corners=False)
class ToFloatTensorInZeroOne(object):
def __call__(self, vid):
return to_normalized_float_tensor(vid)
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return resize(vid, self.size)
def preprocess_videomae(videos):
transform = transforms.Compose(
[ToFloatTensorInZeroOne(),
Resize((224, 224))])
return torch.stack([transform(f) for f in torch.from_numpy(videos)])
def load_videomae_model(device, ckpt_path=None):
if ckpt_path is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
ckpt_path = os.path.join(current_dir, 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')
if not os.path.exists(ckpt_path):
# download the ckpt to the path
ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
response = requests.get(ckpt_url, stream=True, allow_redirects=True)
total_size = int(response.headers.get("content-length", 0))
block_size = 1024
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
with open(ckpt_path, "wb") as fw:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
fw.write(data)
model = vit_giant_patch14_224(
img_size=224,
pretrained=False,
num_classes=174,
all_frames=16,
tubelet_size=2,
drop_path_rate=0.3,
use_mean_pooling=True)
ckpt = torch.load(ckpt_path, map_location='cpu')
for model_key in ['model', 'module']:
if model_key in ckpt:
ckpt = ckpt[model_key]
break
model.load_state_dict(ckpt)
return model.to(device)
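# Hedged usage sketch (assumes a CUDA device; the checkpoint is downloaded on first
# call if no local copy exists):
#   model = load_videomae_model("cuda").eval()
#   with torch.no_grad():
#       feats = model.forward_features(torch.randn(1, 3, 16, 224, 224, device="cuda"))
#   # feats has shape (1, 1408) for the ViT-giant backbone with mean pooling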
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import math
import warnings
import numpy as np
import collections.abc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from itertools import repeat
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""
Adapted from timm codebase
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
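# Illustrative note (sketch): with drop_prob=0.2 and training=True, roughly 20% of the
# samples in a batch have their residual branch zeroed while the survivors are scaled
# by 1 / keep_prob, so the expected value of the output matches the input, e.g.
#   y = drop_path(torch.ones(4, 3), drop_prob=0.2, training=True)
#   # each row of y is either all zeros or all 1 / 0.8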
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 400,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.9,
"interpolation": "bicubic",
"mean": (0.5, 0.5, 0.5),
"std": (0.5, 0.5, 0.5),
**kwargs,
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commented out to match the original BERT implementation
x = self.fc2(x)
x = self.drop(x)
return x
class CosAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
# self.scale = qk_scale or head_dim**-0.5
# DO NOT RENAME [self.scale] (for no weight decay)
if qk_scale is None:
self.scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
else:
self.scale = qk_scale
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
attn = attn * logit_scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_head_dim=None,
cos_attn=False,
):
super().__init__()
self.norm1 = norm_layer(dim)
if cos_attn:
self.attn = CosAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
)
else:
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_spatial_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
num_patches = num_spatial_patches * (num_frames // tubelet_size)
self.img_size = img_size
self.tubelet_size = tubelet_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv3d(
in_channels=in_chans,
out_channels=embed_dim,
kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
stride=(self.tubelet_size, patch_size[0], patch_size[1]),
)
def forward(self, x, **kwargs):
B, C, T, H, W = x.shape
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# b, c, l -> b, l, c
# [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
x = self.proj(x).flatten(2).transpose(1, 2)
return x
# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
"""Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
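# Note (sketch): the returned table has shape (1, n_position, d_hid) and is used as a
# fixed (non-learnable) positional embedding, e.g.
#   pe = get_sinusoid_encoding_table(2048, 1408)  # -> torch.Size([1, 2048, 1408])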
class VisionTransformer(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
head_drop_rate=0.0,
norm_layer=nn.LayerNorm,
init_values=0.0,
use_learnable_pos_emb=False,
init_scale=0.0,
all_frames=16,
tubelet_size=2,
use_mean_pooling=True,
with_cp=False,
cos_attn=False,
):
super().__init__()
self.num_classes = num_classes
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.tubelet_size = tubelet_size
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
num_frames=all_frames,
tubelet_size=tubelet_size,
)
num_patches = self.patch_embed.num_patches
self.with_cp = with_cp
if use_learnable_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
else:
# fixed sine-cosine positional embeddings
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
cos_attn=cos_attn,
)
for i in range(depth)
]
)
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
self.head_dropout = nn.Dropout(head_drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if use_learnable_pos_emb:
trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
self.num_frames = all_frames
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def interpolate_pos_encoding(self, t):
T = 8  # temporal positions in the pretrained pos_embed (16 frames / tubelet_size 2)
t0 = t // self.tubelet_size
if T == t0:
return self.pos_embed
dim = self.pos_embed.shape[-1]
patch_pos_embed = self.pos_embed.permute(0, 2, 1).reshape(1, dim, 8, 16, 16)
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
t0 = t0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(t0 / T, 1, 1),
mode="trilinear",
)
assert int(t0) == patch_pos_embed.shape[-3]
patch_pos_embed = patch_pos_embed.reshape(1, dim, -1).permute(0, 2, 1)
return patch_pos_embed
def forward_features(self, x):
# [1, 3, 16, 224, 224]
B = x.size(0)
T = x.size(2)
# [1, 2048, 1408]
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.interpolate_pos_encoding(T).expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
x = self.pos_drop(x)
for blk in self.blocks:
if self.with_cp:
x = cp.checkpoint(blk, x)
else:
x = blk(x)
# return self.fc_norm(x)
if self.fc_norm is not None:
return self.fc_norm(x.mean(1))
else:
return self.norm(x[:, 0])
def forward(self, x):
x = self.forward_features(x)
x = self.head_dropout(x)
x = self.head(x)
return x
def vit_giant_patch14_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=14,
embed_dim=1408,
depth=40,
num_heads=16,
mlp_ratio=48 / 11,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
model.default_cfg = _cfg()
return model
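# Hedged usage sketch: this factory matches the checkpoint loaded by load_videomae_model;
# with the default use_mean_pooling=True, forward_features() returns one pooled feature
# vector per clip:
#   model = vit_giant_patch14_224(num_classes=174, all_frames=16, tubelet_size=2)
#   feats = model.forward_features(torch.randn(1, 3, 16, 224, 224))  # -> (1, 1408)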
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from .videomaev2_finetune import (
Block,
PatchEmbed,
_cfg,
get_sinusoid_encoding_table,
)
from .videomaev2_finetune import trunc_normal_ as __call_trunc_normal_
def trunc_normal_(tensor, mean=0., std=1.):
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
class PretrainVisionTransformerEncoder(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=0,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=None,
tubelet_size=2,
use_learnable_pos_emb=False,
with_cp=False,
all_frames=16,
cos_attn=False):
super().__init__()
self.num_classes = num_classes
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
num_frames=all_frames,
tubelet_size=tubelet_size)
num_patches = self.patch_embed.num_patches
self.with_cp = with_cp
if use_learnable_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim))
else:
# sine-cosine positional embeddings
self.pos_embed = get_sinusoid_encoding_table(
num_patches, embed_dim)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
cos_attn=cos_attn) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if use_learnable_pos_emb:
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x, mask):
x = self.patch_embed(x)
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
B, _, C = x.shape
x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
for blk in self.blocks:
if self.with_cp:
x_vis = cp.checkpoint(blk, x_vis)
else:
x_vis = blk(x_vis)
x_vis = self.norm(x_vis)
return x_vis
def forward(self, x, mask):
x = self.forward_features(x, mask)
x = self.head(x)
return x
class PretrainVisionTransformerDecoder(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
patch_size=16,
num_classes=768,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=None,
num_patches=196,
tubelet_size=2,
with_cp=False,
cos_attn=False):
super().__init__()
self.num_classes = num_classes
assert num_classes == 3 * tubelet_size * patch_size**2
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.patch_size = patch_size
self.with_cp = with_cp
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
cos_attn=cos_attn) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, return_token_num):
for blk in self.blocks:
if self.with_cp:
x = cp.checkpoint(blk, x)
else:
x = blk(x)
if return_token_num > 0:
# only return the predicted pixels for the mask tokens
x = self.head(self.norm(x[:, -return_token_num:]))
else:
# [B, N, 3*16^2]
x = self.head(self.norm(x))
return x
class PretrainVisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
encoder_in_chans=3,
encoder_num_classes=0,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
decoder_num_classes=1536, # decoder_num_classes=768
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=8,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=0.,
use_learnable_pos_emb=False,
tubelet_size=2,
num_classes=0, # avoid the error from create_fn in timm
in_chans=0, # avoid the error from create_fn in timm
with_cp=False,
all_frames=16,
cos_attn=False,
):
super().__init__()
self.encoder = PretrainVisionTransformerEncoder(
img_size=img_size,
patch_size=patch_size,
in_chans=encoder_in_chans,
num_classes=encoder_num_classes,
embed_dim=encoder_embed_dim,
depth=encoder_depth,
num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
use_learnable_pos_emb=use_learnable_pos_emb,
with_cp=with_cp,
all_frames=all_frames,
cos_attn=cos_attn)
self.decoder = PretrainVisionTransformerDecoder(
patch_size=patch_size,
num_patches=self.encoder.patch_embed.num_patches,
num_classes=decoder_num_classes,
embed_dim=decoder_embed_dim,
depth=decoder_depth,
num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
with_cp=with_cp,
cos_attn=cos_attn)
self.encoder_to_decoder = nn.Linear(
encoder_embed_dim, decoder_embed_dim, bias=False)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.pos_embed = get_sinusoid_encoding_table(
self.encoder.patch_embed.num_patches, decoder_embed_dim)
trunc_normal_(self.mask_token, std=.02)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'mask_token'}
def forward(self, x, mask, decode_mask=None):
decode_vis = mask if decode_mask is None else ~decode_mask
x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
B, N_vis, C = x_vis.shape
# we do not unshuffle the visible tokens back into their original order,
# but shuffle the pos embedding accordingly.
expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(
x.device).clone().detach()
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
pos_emd_mask = expand_pos_embed[decode_vis].reshape(B, -1, C)
# [B, N, C_d]
x_full = torch.cat(
[x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
# NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
x = self.decoder(x_full, pos_emd_mask.shape[1])
return x
def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=384,
encoder_depth=12,
encoder_num_heads=6,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=192,
decoder_num_heads=3,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=384,
decoder_num_heads=6,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=512,
decoder_num_heads=8,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=512,
decoder_num_heads=8,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_giant_patch14_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=14,
encoder_embed_dim=1408,
encoder_depth=40,
encoder_num_heads=16,
encoder_num_classes=0,
decoder_num_classes=1176, # 14 * 14 * 3 * 2,
decoder_embed_dim=512,
decoder_num_heads=8,
mlp_ratio=48 / 11,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
import os
import math
import os.path as osp
import random
import pickle
import warnings
import glob
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributed as dist
from torchvision.datasets.video_utils import VideoClips
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
batch_size=16, num_workers=8):
data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames, batch_size, num_workers)
loader = data._dataloader()
return loader
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def get_parent_dir(path):
return osp.basename(osp.dirname(path))
def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
# video: THWC, {0, ..., 255}
assert in_channels == 3
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
t, c, h, w = video.shape
# temporal crop
if sequence_length is not None:
assert sequence_length <= t
video = video[:sequence_length]
# skip frames
if sample_every_n_frames > 1:
video = video[::sample_every_n_frames]
# scale shorter side to resolution
scale = resolution / min(h, w)
if h < w:
target_size = (resolution, math.ceil(w * scale))
else:
target_size = (math.ceil(h * scale), resolution)
video = F.interpolate(video, size=target_size, mode='bilinear',
align_corners=False, antialias=True)
# center crop
t, c, h, w = video.shape
w_start = (w - resolution) // 2
h_start = (h - resolution) // 2
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
video = video.permute(1, 0, 2, 3).contiguous() # CTHW
return {'video': video}
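# Hedged usage sketch: `video` is a THWC uint8 tensor; the output is a CTHW float tensor
# center-cropped to a square of side `resolution`, e.g.
#   clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)
#   out = preprocess(clip, resolution=128)['video']  # -> torch.Size([3, 16, 128, 128])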
def preprocess_image(image):
# the image is already scaled to [0, 1] by the caller; just convert it to a tensor
img = torch.from_numpy(image)
return img
class VideoData(data.Dataset):
""" Class to create dataloaders for video datasets
Args:
data_path: Path to the folder with video frames or videos.
image_folder: If True, the data is stored as images in folders.
resolution: Resolution of the returned videos.
sequence_length: Length of extracted video sequences.
sample_every_n_frames: Sample every n frames from the video.
batch_size: Batch size.
num_workers: Number of workers for the dataloader.
shuffle: If True, shuffle the data.
"""
def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
super().__init__()
self.data_path = data_path
self.image_folder = image_folder
self.resolution = resolution
self.sequence_length = sequence_length
self.sample_every_n_frames = sample_every_n_frames
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
def _dataset(self):
'''
Initializes and returns the dataset.
'''
if self.image_folder:
Dataset = FrameDataset
dataset = Dataset(self.data_path, self.sequence_length,
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
else:
Dataset = VideoDataset
dataset = Dataset(self.data_path, self.sequence_length,
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
return dataset
def _dataloader(self):
'''
Initializes and returns the dataloader.
'''
dataset = self._dataset()
if dist.is_initialized():
sampler = data.distributed.DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
)
else:
sampler = None
dataloader = data.DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
sampler=sampler,
shuffle=sampler is None and self.shuffle is True
)
return dataloader
class VideoDataset(data.Dataset):
"""
Generic dataset for videos files stored in folders.
Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
Returns BCTHW videos in the range [0, 1].
Args:
data_folder: Path to the folder with corresponding videos stored.
sequence_length: Length of extracted video sequences.
resolution: Resolution of the returned videos.
sample_every_n_frames: Sample every n frames from the video.
"""
def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
super().__init__()
self.sequence_length = sequence_length
self.resolution = resolution
self.sample_every_n_frames = sample_every_n_frames
folder = data_folder
files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
for ext in VID_EXTENSIONS], [])
warnings.filterwarnings('ignore')
cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
if not osp.exists(cache_file):
clips = VideoClips(files, sequence_length, num_workers=4)
try:
pickle.dump(clips.metadata, open(cache_file, 'wb'))
except Exception:
print(f"Failed to save metadata to {cache_file}")
else:
metadata = pickle.load(open(cache_file, 'rb'))
clips = VideoClips(files, sequence_length,
_precomputed_metadata=metadata)
self._clips = clips
# instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
self._clips.get_clip_location = self.get_random_clip_from_video
def get_random_clip_from_video(self, idx: int) -> tuple:
'''
Sample a random clip starting index from the video.
Args:
idx: Index of the video.
'''
# Note that some videos may not contain enough frames; we skip those videos here.
while self._clips.clips[idx].shape[0] <= 0:
idx += 1
n_clip = self._clips.clips[idx].shape[0]
clip_id = random.randint(0, n_clip - 1)
return idx, clip_id
def __len__(self):
return self._clips.num_videos()
def __getitem__(self, idx):
resolution = self.resolution
while True:
try:
video, _, _, idx = self._clips.get_clip(idx)
except Exception as e:
print(idx, e)
idx = (idx + 1) % self._clips.num_clips()
continue
break
return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
class FrameDataset(data.Dataset):
"""
Generic dataset for videos stored as images. Loading iterates over all the folders and subfolders
in the provided directory. Each leaf folder is assumed to contain frames from a single video.
Args:
data_folder: path to the folder with video frames. The folder
should contain folders with frames from each video.
sequence_length: length of extracted video sequences
resolution: resolution of the returned videos
sample_every_n_frames: sample every n frames from the video
"""
def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
self.resolution = resolution
self.sequence_length = sequence_length
self.sample_every_n_frames = sample_every_n_frames
self.data_all = self.load_video_frames(data_folder)
self.video_num = len(self.data_all)
def __getitem__(self, index):
batch_data = self.getTensor(index)
return_list = {'video': batch_data}
return return_list
def load_video_frames(self, dataroot: str) -> list:
'''
Loads all the video frames under the dataroot and returns a list of all the video frames.
Args:
dataroot: The root directory containing the video frames.
Returns:
A list of all the video frames.
'''
data_all = []
frame_list = os.walk(dataroot)
for _, meta in enumerate(frame_list):
root = meta[0]
try:
frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
except Exception:
print(meta[0], meta[2])
continue  # skip folders whose frame names cannot be parsed
if len(frames) < max(0, self.sequence_length * self.sample_every_n_frames):
continue
frames = [
os.path.join(root, item) for item in frames
if is_image_file(item)
]
if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
data_all.append(frames)
return data_all
def getTensor(self, index: int) -> torch.Tensor:
'''
Returns a tensor of the video frames at the given index.
Args:
index: The index of the video frames to return.
Returns:
A BCTHW tensor in the range `[0, 1]` of the video frames at the given index.
'''
video = self.data_all[index]
video_len = len(video)
# load the entire video when sequence_length = -1; in that case sample_every_n_frames has to be 1
if self.sequence_length == -1:
assert self.sample_every_n_frames == 1
start_idx = 0
end_idx = video_len
else:
n_frames_interval = self.sequence_length * self.sample_every_n_frames
start_idx = random.randint(0, video_len - n_frames_interval)
end_idx = start_idx + n_frames_interval
img = Image.open(video[0])
h, w = img.height, img.width
if h > w:
half = (h - w) // 2
cropsize = (0, half, w, half + w) # left, upper, right, lower
elif w > h:
half = (w - h) // 2
cropsize = (half, 0, half + h, h)
images = []
for i in range(start_idx, end_idx,
self.sample_every_n_frames):
path = video[i]
img = Image.open(path)
if h != w:
img = img.crop(cropsize)
img = img.resize(
(self.resolution, self.resolution),
Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
img = np.asarray(img, dtype=np.float32)
img /= 255.
img_tensor = preprocess_image(img).unsqueeze(0)
images.append(img_tensor)
video_clip = torch.cat(images).permute(3, 0, 1, 2)
return video_clip
def __len__(self):
return self.video_num
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
import os
import random
import torch
import pickle
import numpy as np
from typing import List, Tuple
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
class FeatureStats:
'''
Class to store statistics of features, including all features and mean/covariance.
Args:
capture_all: Whether to store all the features.
capture_mean_cov: Whether to store mean and covariance.
max_items: Maximum number of items to store.
'''
def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
self.capture_all = capture_all
self.capture_mean_cov = capture_mean_cov
self.max_items = max_items
self.num_items = 0
self.num_features = None
self.all_features = None
self.raw_mean = None
self.raw_cov = None
def set_num_features(self, num_features: int):
'''
Set the number of feature dimensions.
Args:
num_features: Number of feature dimensions.
'''
if self.num_features is not None:
assert num_features == self.num_features
else:
self.num_features = num_features
self.all_features = []
self.raw_mean = np.zeros([num_features], dtype=np.float64)
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
def is_full(self) -> bool:
'''
Check if the maximum number of samples is reached.
Returns:
True if the storage is full, False otherwise.
'''
return (self.max_items is not None) and (self.num_items >= self.max_items)
def append(self, x: np.ndarray):
'''
Add the newly computed features to the list. Update the mean and covariance.
Args:
x: New features to record.
'''
x = np.asarray(x, dtype=np.float32)
assert x.ndim == 2
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
if self.num_items >= self.max_items:
return
x = x[:self.max_items - self.num_items]
self.set_num_features(x.shape[1])
self.num_items += x.shape[0]
if self.capture_all:
self.all_features.append(x)
if self.capture_mean_cov:
x64 = x.astype(np.float64)
self.raw_mean += x64.sum(axis=0)
self.raw_cov += x64.T @ x64
def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
'''
Add the newly computed PyTorch features to the list. Update the mean and covariance.
Args:
x: New features to record.
rank: Rank of the current GPU.
num_gpus: Total number of GPUs.
'''
assert isinstance(x, torch.Tensor) and x.ndim == 2
assert 0 <= rank < num_gpus
if num_gpus > 1:
ys = []
for src in range(num_gpus):
y = x.clone()
torch.distributed.broadcast(y, src=src)
ys.append(y)
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
self.append(x.cpu().numpy())
def get_all(self) -> np.ndarray:
'''
Get all the stored features as NumPy Array.
Returns:
Concatenation of the stored features.
'''
assert self.capture_all
return np.concatenate(self.all_features, axis=0)
def get_all_torch(self) -> torch.Tensor:
'''
Get all the stored features as PyTorch Tensor.
Returns:
Concatenation of the stored features.
'''
return torch.from_numpy(self.get_all())
def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
'''
Get the mean and covariance of the stored features.
Returns:
Mean and covariance of the stored features.
'''
assert self.capture_mean_cov
mean = self.raw_mean / self.num_items
cov = self.raw_cov / self.num_items
cov = cov - np.outer(mean, mean)
return mean, cov
def save(self, pkl_file: str):
'''
Save the features and statistics to a pickle file.
Args:
pkl_file: Path to the pickle file.
'''
with open(pkl_file, 'wb') as f:
pickle.dump(self.__dict__, f)
@staticmethod
def load(pkl_file: str) -> 'FeatureStats':
'''
Load the features and statistics from a pickle file.
Args:
pkl_file: Path to the pickle file.
'''
with open(pkl_file, 'rb') as f:
s = pickle.load(f)
obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
obj.__dict__.update(s)
print('Loaded %d features from %s' % (obj.num_items, pkl_file))
return obj
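# Hedged usage sketch: accumulate per-batch feature matrices and read back the running
# mean/covariance (e.g. for FVD-style statistics):
#   stats = FeatureStats(capture_mean_cov=True, max_items=10_000)
#   stats.append(np.random.randn(64, 768))
#   mu, cov = stats.get_mean_cov()  # mu: (768,), cov: (768, 768)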
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
import numpy as np
import cv2
def transformation_from_points(points1, points0, smooth=True, p_bias=None):
points2 = np.array(points0)
points2 = points2.astype(np.float64)
points1 = points1.astype(np.float64)
c1 = np.mean(points1, axis=0)
c2 = np.mean(points2, axis=0)
points1 -= c1
points2 -= c2
s1 = np.std(points1)
s2 = np.std(points2)
points1 /= s1
points2 /= s2
U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2))
R = (np.matmul(U, Vt)).T
sR = (s2 / s1) * R
T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1))
M = np.concatenate((sR, T), axis=1)
if smooth:
bias = points2[2] - points1[2]
if p_bias is None:
p_bias = bias
else:
bias = p_bias * 0.2 + bias * 0.8
p_bias = bias
M[:, 2] = M[:, 2] + bias
return M, p_bias
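# Hedged usage sketch: the function returns a 2x3 similarity transform (rotation, scale
# and translation estimated via SVD) mapping `points1` onto `points0`; it is typically
# fed to cv2.warpAffine, as in AlignRestore.align_warp_face below. Variable names here
# are illustrative only:
#   M, p_bias = transformation_from_points(landmarks_3x2, template_3x2, smooth=True, p_bias=None)
#   warped = cv2.warpAffine(img, M, (out_w, out_h))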
class AlignRestore(object):
def __init__(self, align_points=3):
if align_points == 3:
self.upscale_factor = 1
self.crop_ratio = (2.8, 2.8)
self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
self.face_template = self.face_template * 2.8
# self.face_size = (int(100 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
self.p_bias = None
def process(self, img, lmk_align=None, smooth=True, align_points=3):
aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth)
restored_img = self.restore_img(img, aligned_face, affine_matrix)
cv2.imwrite("restored.jpg", restored_img)
cv2.imwrite("aligned.jpg", aligned_face)
return aligned_face, restored_img
def align_warp_face(self, img, lmks3, smooth=True, border_mode="constant"):
affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias)
if border_mode == "constant":
border_mode = cv2.BORDER_CONSTANT
elif border_mode == "reflect101":
border_mode = cv2.BORDER_REFLECT101
elif border_mode == "reflect":
border_mode = cv2.BORDER_REFLECT
cropped_face = cv2.warpAffine(
img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127]
)
return cropped_face, affine_matrix
def align_warp_face2(self, img, landmark, border_mode="constant"):
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0]
if border_mode == "constant":
border_mode = cv2.BORDER_CONSTANT
elif border_mode == "reflect101":
border_mode = cv2.BORDER_REFLECT101
elif border_mode == "reflect":
border_mode = cv2.BORDER_REFLECT
cropped_face = cv2.warpAffine(
img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)
)
return cropped_face, affine_matrix
def restore_img(self, input_img, face, affine_matrix):
h, w, _ = input_img.shape
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
inverse_affine = cv2.invertAffineTransform(affine_matrix)
inverse_affine *= self.upscale_factor
if self.upscale_factor > 1:
extra_offset = 0.5 * self.upscale_factor
else:
extra_offset = 0
inverse_affine[:, 2] += extra_offset
inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up))
mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
inv_mask_erosion = cv2.erode(
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)
)
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion)
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
inv_soft_mask = inv_soft_mask[:, :, None]
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
if np.max(upsample_img) > 256:
upsample_img = upsample_img.astype(np.uint16)
else:
upsample_img = upsample_img.astype(np.uint8)
return upsample_img
class laplacianSmooth:
def __init__(self, smoothAlpha=0.3):
self.smoothAlpha = smoothAlpha
self.pts_last = None
def smooth(self, pts_cur):
if self.pts_last is None:
self.pts_last = pts_cur.copy()
return pts_cur.copy()
x1 = min(pts_cur[:, 0])
x2 = max(pts_cur[:, 0])
y1 = min(pts_cur[:, 1])
y2 = max(pts_cur[:, 1])
width = x2 - x1
pts_update = []
for i in range(len(pts_cur)):
x_new, y_new = pts_cur[i]
x_old, y_old = self.pts_last[i]
tmp = (x_new - x_old) ** 2 + (y_new - y_old) ** 2
w = np.exp(-tmp / (width * self.smoothAlpha))
x = x_old * w + x_new * (1 - w)
y = y_old * w + y_new * (1 - w)
pts_update.append([x, y])
pts_update = np.array(pts_update)
self.pts_last = pts_update.copy()
return pts_update