Commit 5c023842 authored by chenpangpang's avatar chenpangpang
Browse files

feat: 增加LatentSync

parent 822b66ca
Pipeline #2211 canceled with stages
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
import cv2
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..utils.audio import melspectrogram
from decord import AudioReader, VideoReader, cpu
class UNetDataset(Dataset):
def __init__(self, train_data_dir: str, config):
if config.data.train_fileslist != "":
with open(config.data.train_fileslist) as file:
self.video_paths = [line.rstrip() for line in file]
elif train_data_dir != "":
self.video_paths = []
for file in os.listdir(train_data_dir):
if file.endswith(".mp4"):
self.video_paths.append(os.path.join(train_data_dir, file))
else:
raise ValueError("data_dir and fileslist cannot be both empty")
self.resolution = config.data.resolution
self.num_frames = config.data.num_frames
if self.num_frames == 16:
self.mel_window_length = 52
elif self.num_frames == 5:
self.mel_window_length = 16
else:
raise NotImplementedError("Only support 16 and 5 frames now")
self.audio_sample_rate = config.data.audio_sample_rate
self.video_fps = config.data.video_fps
self.mask = config.data.mask
self.mask_image = load_fixed_mask(self.resolution)
self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
def __len__(self):
return len(self.video_paths)
def read_audio(self, video_path: str):
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
return torch.from_numpy(original_mel)
def crop_audio_window(self, original_mel, start_index):
start_idx = int(80.0 * (start_index / float(self.video_fps)))
end_idx = start_idx + self.mel_window_length
return original_mel[:, start_idx:end_idx].unsqueeze(0)
def get_frames(self, video_reader: VideoReader):
total_num_frames = len(video_reader)
start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
while True:
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
continue
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
break
frames = video_reader.get_batch(frames_index).asnumpy()
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
return frames, wrong_frames, start_idx
def worker_init_fn(self, worker_id):
# Initialize the face mesh object in each worker process,
# because the face mesh object cannot be called in subprocesses
self.worker_id = worker_id
setattr(
self,
f"image_processor_{worker_id}",
ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
)
def __getitem__(self, idx):
image_processor = getattr(self, f"image_processor_{self.worker_id}")
while True:
try:
idx = random.randint(0, len(self) - 1)
# Get video file path
video_path = self.video_paths[idx]
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
if len(vr) < 3 * self.num_frames:
continue
continuous_frames, ref_frames, start_idx = self.get_frames(vr)
if self.load_audio_data:
mel_cache_path = os.path.join(
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
)
if os.path.isfile(mel_cache_path):
try:
original_mel = torch.load(mel_cache_path)
except Exception as e:
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
os.remove(mel_cache_path)
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
else:
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
mel = self.crop_audio_window(original_mel, start_idx)
if mel.shape[-1] != self.mel_window_length:
continue
else:
mel = []
gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
continuous_frames, affine_transform=False
)
if self.mask == "fix_mask":
ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
else:
ref = image_processor.process_images(ref_frames)
vr.seek(0) # avoid memory leak
break
except Exception as e: # Handle the exception of face not detcted
print(f"{type(e).__name__} - {e} - {video_path}")
if "vr" in locals():
vr.seek(0) # avoid memory leak
sample = dict(
gt=gt,
masked_gt=masked_gt,
ref=ref,
mel=mel,
mask=mask,
video_path=video_path,
start_idx=start_idx,
)
return sample
This diff is collapsed.
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
# Actually we don't use the motion module in the final version of LatentSync
# When we started the project, we used the codebase of AnimateDiff and tried motion module
# But the results are poor, and we decied to leave the code here for possible future usage
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward
from einops import rearrange, repeat
import math
from .utils import zero_module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(
in_channels=in_channels,
**motion_module_kwargs,
)
else:
raise ValueError
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads=8,
num_transformer_block=2,
attention_block_types=("Temporal_Self", "Temporal_Self"),
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
temporal_attention_dim_div=1,
zero_initialize=True,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = (
attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
)
+ hidden_states
)
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.0, max_len=24):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
class VersatileAttention(CrossAttention):
def __init__(
self,
attention_mode=None,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
self.pos_encoder = (
PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
if (temporal_position_encoding and attention_mode == "Temporal")
else None
)
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
if self.attention_mode == "Temporal":
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = (
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
if encoder_hidden_states is not None
else encoder_hidden_states
)
else:
raise NotImplementedError
# encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class InflatedGroupNorm(nn.GroupNorm):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# if self.use_conv:
# if self.name == "conv":
# hidden_states = self.conv(hidden_states)
# else:
# hidden_states = self.Conv2d_0(hidden_states)
hidden_states = self.conv(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
assert use_inflated_groupnorm != None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
time_emb_proj_out_channels = out_channels
# if self.time_embedding_norm == "default":
# time_emb_proj_out_channels = out_channels
# elif self.time_embedding_norm == "scale_shift":
# time_emb_proj_out_channels = out_channels * 2
# else:
# raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
if self.time_embedding_norm == "scale_shift":
self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
else:
self.double_len_linear = None
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
if temb.dim() == 2:
# input (1, 1280)
temb = self.time_emb_proj(self.nonlinearity(temb))
temb = temb[:, :, None, None, None] # unsqueeze
else:
# input (1, 1280, 16)
temb = temb.permute(0, 2, 1)
temb = self.time_emb_proj(self.nonlinearity(temb))
if self.double_len_linear is not None:
temb = self.double_len_linear(self.nonlinearity(temb))
temb = temb.permute(0, 2, 1)
temb = temb[:, :, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from einops import rearrange
from torch.nn import functional as F
from ..utils.util import cosine_loss
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
class SyncNet(nn.Module):
def __init__(self, config):
super().__init__()
self.audio_encoder = DownEncoder2D(
in_channels=config["audio_encoder"]["in_channels"],
block_out_channels=config["audio_encoder"]["block_out_channels"],
downsample_factors=config["audio_encoder"]["downsample_factors"],
dropout=config["audio_encoder"]["dropout"],
attn_blocks=config["audio_encoder"]["attn_blocks"],
)
self.visual_encoder = DownEncoder2D(
in_channels=config["visual_encoder"]["in_channels"],
block_out_channels=config["visual_encoder"]["block_out_channels"],
downsample_factors=config["visual_encoder"]["downsample_factors"],
dropout=config["visual_encoder"]["dropout"],
attn_blocks=config["visual_encoder"]["attn_blocks"],
)
self.eval()
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
norm_num_groups: int = 32,
eps: float = 1e-6,
act_fn: str = "silu",
downsample_factor=2,
):
super().__init__()
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "silu":
self.act_fn = nn.SiLU()
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
if isinstance(downsample_factor, list):
downsample_factor = tuple(downsample_factor)
if downsample_factor == 1:
self.downsample_conv = None
else:
self.downsample_conv = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
)
self.pad = (0, 1, 0, 1)
if isinstance(downsample_factor, tuple):
if downsample_factor[0] == 1:
self.pad = (0, 1, 1, 1) # The padding order is from back to front
elif downsample_factor[1] == 1:
self.pad = (1, 1, 0, 1)
def forward(self, input_tensor):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
hidden_states += input_tensor
if self.downsample_conv is not None:
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
hidden_states = self.downsample_conv(hidden_states)
return hidden_states
class AttentionBlock2D(nn.Module):
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
super().__init__()
if not is_xformers_available():
raise ModuleNotFoundError(
"You have to install xformers to enable memory efficient attetion", name="xformers"
)
# inner_dim = dim_head * heads
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
self.norm2 = nn.LayerNorm(query_dim)
self.norm3 = nn.LayerNorm(query_dim)
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
self.attn._use_memory_efficient_attention_xformers = True
def forward(self, hidden_states):
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.conv_in(hidden_states)
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
norm_hidden_states = self.norm2(hidden_states)
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DownEncoder2D(nn.Module):
def __init__(
self,
in_channels=4 * 16,
block_out_channels=[64, 128, 256, 256],
downsample_factors=[2, 2, 2, 2],
layers_per_block=2,
norm_num_groups=32,
attn_blocks=[1, 1, 1, 1],
dropout: float = 0.0,
act_fn="silu",
):
super().__init__()
self.layers_per_block = layers_per_block
# in
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# down
self.down_blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, block_out_channel in enumerate(block_out_channels):
input_channels = output_channels
output_channels = block_out_channel
# is_final_block = i == len(block_out_channels) - 1
down_block = ResnetBlock2D(
in_channels=input_channels,
out_channels=output_channels,
downsample_factor=downsample_factors[i],
norm_num_groups=norm_num_groups,
dropout=dropout,
act_fn=act_fn,
)
self.down_blocks.append(down_block)
if attn_blocks[i] == 1:
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
self.down_blocks.append(attention_block)
# out
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.act_fn_out = nn.ReLU()
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
# down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.act_fn_out(hidden_states)
return hidden_states
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.
from torch import nn
from torch.nn import functional as F
class SyncNetWav2Lip(nn.Module):
def __init__(self, act_fn="leaky"):
super().__init__()
# input image sequences: (15, 128, 256)
self.visual_encoder = nn.Sequential(
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
)
# input audio sequences: (1, 80, 16)
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
)
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "tanh":
self.act_fn = nn.Tanh()
elif act_fn == "silu":
self.act_fn = nn.SiLU()
elif act_fn == "leaky":
self.act_fn = nn.LeakyReLU(0.2, inplace=True)
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act_fn(out)
This diff is collapsed.
This diff is collapsed.
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
This diff is collapsed.
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from einops import rearrange
from .third_party.VideoMAEv2.utils import load_videomae_model
class TREPALoss:
def __init__(
self,
device="cuda",
ckpt_path="/mnt/bn/maliva-gen-ai-v2/chunyu.li/checkpoints/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
):
self.model = load_videomae_model(device, ckpt_path).eval().to(dtype=torch.float16)
self.model.requires_grad_(False)
self.bce_loss = nn.BCELoss()
def __call__(self, videos_fake, videos_real, loss_type="mse"):
batch_size = videos_fake.shape[0]
num_frames = videos_fake.shape[2]
videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bilinear")
videos_real = F.interpolate(videos_real, size=(224, 224), mode="bilinear")
videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
# Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1]
videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
feats_fake = self.model.forward_features(videos_fake)
feats_real = self.model.forward_features(videos_real)
feats_fake = F.normalize(feats_fake, p=2, dim=1)
feats_real = F.normalize(feats_real, p=2, dim=1)
return F.mse_loss(feats_fake, feats_real)
if __name__ == "__main__":
# input shape: (b, c, f, h, w)
videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
trepa_loss = TREPALoss(device="cuda")
loss = trepa_loss(videos_fake, videos_real)
print(loss)
import os
import torch
import requests
from tqdm import tqdm
from torchvision import transforms
from .videomaev2_finetune import vit_giant_patch14_224
def to_normalized_float_tensor(vid):
return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
# NOTE: for those functions, which generally expect mini-batches, we keep them
# as non-minibatch so that they are applied as if they were 4d (thus image).
# this way, we only apply the transformation in the spatial domain
def resize(vid, size, interpolation='bilinear'):
# NOTE: using bilinear interpolation because we don't work on minibatches
# at this level
scale = None
if isinstance(size, int):
scale = float(size) / min(vid.shape[-2:])
size = None
return torch.nn.functional.interpolate(
vid,
size=size,
scale_factor=scale,
mode=interpolation,
align_corners=False)
class ToFloatTensorInZeroOne(object):
def __call__(self, vid):
return to_normalized_float_tensor(vid)
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return resize(vid, self.size)
def preprocess_videomae(videos):
transform = transforms.Compose(
[ToFloatTensorInZeroOne(),
Resize((224, 224))])
return torch.stack([transform(f) for f in torch.from_numpy(videos)])
def load_videomae_model(device, ckpt_path=None):
if ckpt_path is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
ckpt_path = os.path.join(current_dir, 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')
if not os.path.exists(ckpt_path):
# download the ckpt to the path
ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
response = requests.get(ckpt_url, stream=True, allow_redirects=True)
total_size = int(response.headers.get("content-length", 0))
block_size = 1024
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
with open(ckpt_path, "wb") as fw:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
fw.write(data)
model = vit_giant_patch14_224(
img_size=224,
pretrained=False,
num_classes=174,
all_frames=16,
tubelet_size=2,
drop_path_rate=0.3,
use_mean_pooling=True)
ckpt = torch.load(ckpt_path, map_location='cpu')
for model_key in ['model', 'module']:
if model_key in ckpt:
ckpt = ckpt[model_key]
break
model.load_state_dict(ckpt)
return model.to(device)
\ No newline at end of file
This diff is collapsed.
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
import os
import random
import torch
import pickle
import numpy as np
from typing import List, Tuple
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
class FeatureStats:
'''
Class to store statistics of features, including all features and mean/covariance.
Args:
capture_all: Whether to store all the features.
capture_mean_cov: Whether to store mean and covariance.
max_items: Maximum number of items to store.
'''
def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
'''
'''
self.capture_all = capture_all
self.capture_mean_cov = capture_mean_cov
self.max_items = max_items
self.num_items = 0
self.num_features = None
self.all_features = None
self.raw_mean = None
self.raw_cov = None
def set_num_features(self, num_features: int):
'''
Set the number of features diminsions.
Args:
num_features: Number of features diminsions.
'''
if self.num_features is not None:
assert num_features == self.num_features
else:
self.num_features = num_features
self.all_features = []
self.raw_mean = np.zeros([num_features], dtype=np.float64)
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
def is_full(self) -> bool:
'''
Check if the maximum number of samples is reached.
Returns:
True if the storage is full, False otherwise.
'''
return (self.max_items is not None) and (self.num_items >= self.max_items)
def append(self, x: np.ndarray):
'''
Add the newly computed features to the list. Update the mean and covariance.
Args:
x: New features to record.
'''
x = np.asarray(x, dtype=np.float32)
assert x.ndim == 2
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
if self.num_items >= self.max_items:
return
x = x[:self.max_items - self.num_items]
self.set_num_features(x.shape[1])
self.num_items += x.shape[0]
if self.capture_all:
self.all_features.append(x)
if self.capture_mean_cov:
x64 = x.astype(np.float64)
self.raw_mean += x64.sum(axis=0)
self.raw_cov += x64.T @ x64
def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
'''
Add the newly computed PyTorch features to the list. Update the mean and covariance.
Args:
x: New features to record.
rank: Rank of the current GPU.
num_gpus: Total number of GPUs.
'''
assert isinstance(x, torch.Tensor) and x.ndim == 2
assert 0 <= rank < num_gpus
if num_gpus > 1:
ys = []
for src in range(num_gpus):
y = x.clone()
torch.distributed.broadcast(y, src=src)
ys.append(y)
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
self.append(x.cpu().numpy())
def get_all(self) -> np.ndarray:
'''
Get all the stored features as NumPy Array.
Returns:
Concatenation of the stored features.
'''
assert self.capture_all
return np.concatenate(self.all_features, axis=0)
def get_all_torch(self) -> torch.Tensor:
'''
Get all the stored features as PyTorch Tensor.
Returns:
Concatenation of the stored features.
'''
return torch.from_numpy(self.get_all())
def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
'''
Get the mean and covariance of the stored features.
Returns:
Mean and covariance of the stored features.
'''
assert self.capture_mean_cov
mean = self.raw_mean / self.num_items
cov = self.raw_cov / self.num_items
cov = cov - np.outer(mean, mean)
return mean, cov
def save(self, pkl_file: str):
'''
Save the features and statistics to a pickle file.
Args:
pkl_file: Path to the pickle file.
'''
with open(pkl_file, 'wb') as f:
pickle.dump(self.__dict__, f)
@staticmethod
def load(pkl_file: str) -> 'FeatureStats':
'''
Load the features and statistics from a pickle file.
Args:
pkl_file: Path to the pickle file.
'''
with open(pkl_file, 'rb') as f:
s = pickle.load(f)
obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
obj.__dict__.update(s)
print('Loaded %d features from %s' % (obj.num_items, pkl_file))
return obj
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment