Commit 5c023842 authored by chenpangpang

feat: add LatentSync

parent 822b66ca
# Face detector
This face detector is adapted from `https://github.com/cs-giung/face-detection-pytorch`.
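
Below is a minimal usage sketch of the `S3FD` wrapper defined in this directory. It assumes the detector weights have been downloaded to `checkpoints/auxiliary/sfd_face.pth` (the path hard-coded in the class) and that the package is importable as `eval.detectors`, as in the SyncNet detection script; the image filename is a placeholder.

```python
# Hedged usage sketch: detect faces in one image with the S3FD wrapper.
import cv2
from eval.detectors import S3FD

detector = S3FD(device="cuda")                         # loads checkpoints/auxiliary/sfd_face.pth
image = cv2.imread("frame.jpg")                        # placeholder path; OpenCV reads BGR
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)     # the eval pipeline feeds RGB frames
bboxes = detector.detect_faces(image_rgb, conf_th=0.9, scales=[0.25])
for x1, y1, x2, y2, score in bboxes:                   # (N, 5) rows: [x1, y1, x2, y2, score]
    print(f"face ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}) conf {score:.2f}")
```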
from .s3fd import S3FD
import time
import numpy as np
import cv2
import torch
from torchvision import transforms
from .nets import S3FDNet
from .box_utils import nms_
PATH_WEIGHT = 'checkpoints/auxiliary/sfd_face.pth'
img_mean = np.array([104., 117., 123.])[:, np.newaxis, np.newaxis].astype('float32')
class S3FD():
def __init__(self, device='cuda'):
tstamp = time.time()
self.device = device
print('[S3FD] loading with', self.device)
self.net = S3FDNet(device=self.device).to(self.device)
state_dict = torch.load(PATH_WEIGHT, map_location=self.device)
self.net.load_state_dict(state_dict)
self.net.eval()
print('[S3FD] finished loading (%.4f sec)' % (time.time() - tstamp))
def detect_faces(self, image, conf_th=0.8, scales=[1]):
w, h = image.shape[1], image.shape[0]
bboxes = np.empty(shape=(0, 5))
with torch.no_grad():
for s in scales:
scaled_img = cv2.resize(image, dsize=(0, 0), fx=s, fy=s, interpolation=cv2.INTER_LINEAR)
scaled_img = np.swapaxes(scaled_img, 1, 2)
scaled_img = np.swapaxes(scaled_img, 1, 0)
scaled_img = scaled_img[[2, 1, 0], :, :]
scaled_img = scaled_img.astype('float32')
scaled_img -= img_mean
# the second flip restores the original channel order: the image is flipped only
# so that the BGR-ordered mean above lines up per channel, then flipped back
scaled_img = scaled_img[[2, 1, 0], :, :]
x = torch.from_numpy(scaled_img).unsqueeze(0).to(self.device)
y = self.net(x)
detections = y.data
scale = torch.Tensor([w, h, w, h])
for i in range(detections.size(1)):
j = 0
while detections[0, i, j, 0] > conf_th:
score = detections[0, i, j, 0]
pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
bbox = (pt[0], pt[1], pt[2], pt[3], score)
bboxes = np.vstack((bboxes, bbox))
j += 1
keep = nms_(bboxes, 0.1)
bboxes = bboxes[keep]
return bboxes
import numpy as np
from itertools import product as product
import torch
from torch.autograd import Function
import warnings
def nms_(dets, thresh):
"""
Courtesy of Ross Girshick
[https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py]
"""
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(int(i))
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return np.array(keep).astype(np.int32)
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def nms(boxes, scores, overlap=0.5, top_k=200):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
top_k: (int) The maximum number of box preds to consider.
Return:
The indices of the kept boxes with respect to num_priors.
"""
keep = scores.new(scores.size(0)).zero_().long()
if boxes.numel() == 0:
return keep, 0
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = torch.mul(x2 - x1, y2 - y1)
v, idx = scores.sort(0) # sort in ascending order
# I = I[v >= 0.01]
idx = idx[-top_k:] # indices of the top-k largest vals
xx1 = boxes.new()
yy1 = boxes.new()
xx2 = boxes.new()
yy2 = boxes.new()
w = boxes.new()
h = boxes.new()
# keep = torch.Tensor()
count = 0
while idx.numel() > 0:
i = idx[-1] # index of current largest val
# keep.append(i)
keep[count] = i
count += 1
if idx.size(0) == 1:
break
idx = idx[:-1] # remove kept element from view
# load bboxes of next highest vals
with warnings.catch_warnings():
# Ignore UserWarning within this block
warnings.simplefilter("ignore", category=UserWarning)
torch.index_select(x1, 0, idx, out=xx1)
torch.index_select(y1, 0, idx, out=yy1)
torch.index_select(x2, 0, idx, out=xx2)
torch.index_select(y2, 0, idx, out=yy2)
# store element-wise max with next highest score
xx1 = torch.clamp(xx1, min=x1[i])
yy1 = torch.clamp(yy1, min=y1[i])
xx2 = torch.clamp(xx2, max=x2[i])
yy2 = torch.clamp(yy2, max=y2[i])
w.resize_as_(xx2)
h.resize_as_(yy2)
w = xx2 - xx1
h = yy2 - yy1
# check sizes of xx1 and xx2.. after each iteration
w = torch.clamp(w, min=0.0)
h = torch.clamp(h, min=0.0)
inter = w * h
# IoU = i / (area(a) + area(b) - i)
rem_areas = torch.index_select(area, 0, idx) # load remaining areas
union = (rem_areas - inter) + area[i]
IoU = inter / union # store result in iou
# keep only elements with an IoU <= overlap
idx = idx[IoU.le(overlap)]
return keep, count
class Detect(object):
def __init__(self, num_classes=2,
top_k=750, nms_thresh=0.3, conf_thresh=0.05,
variance=[0.1, 0.2], nms_top_k=5000):
self.num_classes = num_classes
self.top_k = top_k
self.nms_thresh = nms_thresh
self.conf_thresh = conf_thresh
self.variance = variance
self.nms_top_k = nms_top_k
def forward(self, loc_data, conf_data, prior_data):
num = loc_data.size(0)
num_priors = prior_data.size(0)
conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)
batch_priors = prior_data.view(-1, num_priors, 4).expand(num, num_priors, 4)
batch_priors = batch_priors.contiguous().view(-1, 4)
decoded_boxes = decode(loc_data.view(-1, 4), batch_priors, self.variance)
decoded_boxes = decoded_boxes.view(num, num_priors, 4)
output = torch.zeros(num, self.num_classes, self.top_k, 5)
for i in range(num):
boxes = decoded_boxes[i].clone()
conf_scores = conf_preds[i].clone()
for cl in range(1, self.num_classes):
c_mask = conf_scores[cl].gt(self.conf_thresh)
scores = conf_scores[cl][c_mask]
if scores.dim() == 0:
continue
l_mask = c_mask.unsqueeze(1).expand_as(boxes)
boxes_ = boxes[l_mask].view(-1, 4)
ids, count = nms(boxes_, scores, self.nms_thresh, self.nms_top_k)
count = count if count < self.top_k else self.top_k
output[i, cl, :count] = torch.cat((scores[ids[:count]].unsqueeze(1), boxes_[ids[:count]]), 1)
return output
class PriorBox(object):
def __init__(self, input_size, feature_maps,
variance=[0.1, 0.2],
min_sizes=[16, 32, 64, 128, 256, 512],
steps=[4, 8, 16, 32, 64, 128],
clip=False):
super(PriorBox, self).__init__()
self.imh = input_size[0]
self.imw = input_size[1]
self.feature_maps = feature_maps
self.variance = variance
self.min_sizes = min_sizes
self.steps = steps
self.clip = clip
def forward(self):
mean = []
for k, fmap in enumerate(self.feature_maps):
feath = fmap[0]
featw = fmap[1]
for i, j in product(range(feath), range(featw)):
f_kw = self.imw / self.steps[k]
f_kh = self.imh / self.steps[k]
cx = (j + 0.5) / f_kw
cy = (i + 0.5) / f_kh
s_kw = self.min_sizes[k] / self.imw
s_kh = self.min_sizes[k] / self.imh
mean += [cx, cy, s_kw, s_kh]
output = torch.FloatTensor(mean).view(-1, 4)
if self.clip:
output.clamp_(max=1, min=0)
return output
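
For clarity, here is a small, hedged sketch of how `decode` and `nms` above fit together; the priors and scores are made up purely to illustrate the shapes, and it assumes the functions above are in scope.

```python
# Sketch: decode center-offset priors into corner boxes, then suppress duplicates.
import torch

num_priors = 4
priors = torch.tensor([[0.5, 0.5, 0.2, 0.2]]).repeat(num_priors, 1)  # (cx, cy, w, h)
loc = torch.zeros(num_priors, 4)                                     # zero offsets -> boxes equal the priors
variances = [0.1, 0.2]

boxes = decode(loc, priors, variances)        # corner form, here all [0.4, 0.4, 0.6, 0.6]
scores = torch.tensor([0.9, 0.8, 0.7, 0.1])
keep, count = nms(boxes, scores, overlap=0.5, top_k=200)
print(boxes[keep[:count]], scores[keep[:count]])  # identical boxes collapse to the top-scoring one
```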
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from .box_utils import Detect, PriorBox
class L2Norm(nn.Module):
def __init__(self, n_channels, scale):
super(L2Norm, self).__init__()
self.n_channels = n_channels
self.gamma = scale or None
self.eps = 1e-10
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
self.reset_parameters()
def reset_parameters(self):
init.constant_(self.weight, self.gamma)
def forward(self, x):
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
x = torch.div(x, norm)
out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
return out
class S3FDNet(nn.Module):
def __init__(self, device='cuda'):
super(S3FDNet, self).__init__()
self.device = device
self.vgg = nn.ModuleList([
nn.Conv2d(3, 64, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2, ceil_mode=True),
nn.Conv2d(256, 512, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(512, 512, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(512, 1024, 3, 1, padding=6, dilation=6),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 1, 1),
nn.ReLU(inplace=True),
])
self.L2Norm3_3 = L2Norm(256, 10)
self.L2Norm4_3 = L2Norm(512, 8)
self.L2Norm5_3 = L2Norm(512, 5)
self.extras = nn.ModuleList([
nn.Conv2d(1024, 256, 1, 1),
nn.Conv2d(256, 512, 3, 2, padding=1),
nn.Conv2d(512, 128, 1, 1),
nn.Conv2d(128, 256, 3, 2, padding=1),
])
self.loc = nn.ModuleList([
nn.Conv2d(256, 4, 3, 1, padding=1),
nn.Conv2d(512, 4, 3, 1, padding=1),
nn.Conv2d(512, 4, 3, 1, padding=1),
nn.Conv2d(1024, 4, 3, 1, padding=1),
nn.Conv2d(512, 4, 3, 1, padding=1),
nn.Conv2d(256, 4, 3, 1, padding=1),
])
self.conf = nn.ModuleList([
nn.Conv2d(256, 4, 3, 1, padding=1),
nn.Conv2d(512, 2, 3, 1, padding=1),
nn.Conv2d(512, 2, 3, 1, padding=1),
nn.Conv2d(1024, 2, 3, 1, padding=1),
nn.Conv2d(512, 2, 3, 1, padding=1),
nn.Conv2d(256, 2, 3, 1, padding=1),
])
self.softmax = nn.Softmax(dim=-1)
self.detect = Detect()
def forward(self, x):
size = x.size()[2:]
sources = list()
loc = list()
conf = list()
for k in range(16):
x = self.vgg[k](x)
s = self.L2Norm3_3(x)
sources.append(s)
for k in range(16, 23):
x = self.vgg[k](x)
s = self.L2Norm4_3(x)
sources.append(s)
for k in range(23, 30):
x = self.vgg[k](x)
s = self.L2Norm5_3(x)
sources.append(s)
for k in range(30, len(self.vgg)):
x = self.vgg[k](x)
sources.append(x)
# apply extra layers and cache source layer outputs
for k, v in enumerate(self.extras):
x = F.relu(v(x), inplace=True)
if k % 2 == 1:
sources.append(x)
# apply multibox head to source layers
loc_x = self.loc[0](sources[0])
conf_x = self.conf[0](sources[0])
max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
loc.append(loc_x.permute(0, 2, 3, 1).contiguous())
conf.append(conf_x.permute(0, 2, 3, 1).contiguous())
for i in range(1, len(sources)):
x = sources[i]
conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous())
loc.append(self.loc[i](x).permute(0, 2, 3, 1).contiguous())
features_maps = []
for i in range(len(loc)):
feat = []
feat += [loc[i].size(1), loc[i].size(2)]
features_maps += [feat]
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
with torch.no_grad():
self.priorbox = PriorBox(size, features_maps)
self.priors = self.priorbox.forward()
output = self.detect.forward(
loc.view(loc.size(0), -1, 4),
self.softmax(conf.view(conf.size(0), -1, 2)),
self.priors.type(type(x.data)).to(self.device)
)
return output
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import matplotlib.pyplot as plt
class Chart:
def __init__(self):
self.loss_list = []
def add_ckpt(self, ckpt_path, line_name):
ckpt = torch.load(ckpt_path, map_location="cpu")
train_step_list = ckpt["train_step_list"]
train_loss_list = ckpt["train_loss_list"]
val_step_list = ckpt["val_step_list"]
val_loss_list = ckpt["val_loss_list"]
val_step_list = [val_step_list[0]] + val_step_list[4::5]
val_loss_list = [val_loss_list[0]] + val_loss_list[4::5]
self.loss_list.append((line_name, train_step_list, train_loss_list, val_step_list, val_loss_list))
def draw(self, save_path, plot_val=True):
# Global settings
plt.rcParams["font.size"] = 14
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans", "Lucida Grande"]
plt.rcParams["font.serif"] = ["Times New Roman", "DejaVu Serif"]
# Creating the plot
plt.figure(figsize=(7.766, 4.8)) # Golden ratio
for loss in self.loss_list:
if plot_val:
(line,) = plt.plot(loss[1], loss[2], label=loss[0], linewidth=0.5, alpha=0.5)
line_color = line.get_color()
plt.plot(loss[3], loss[4], linewidth=1.5, color=line_color)
else:
plt.plot(loss[1], loss[2], label=loss[0], linewidth=1)
plt.xlabel("Step")
plt.ylabel("Loss")
legend = plt.legend()
# legend = plt.legend(loc='upper right', bbox_to_anchor=(1, 0.82))
# Adjust the linewidth of legend
for line in legend.get_lines():
line.set_linewidth(2)
plt.savefig(save_path, transparent=True)
plt.close()
if __name__ == "__main__":
chart = Chart()
# chart.add_ckpt("output/syncnet/train-2024_10_25-18:14:43/checkpoints/checkpoint-10000.pt", "w/ self-attn")
# chart.add_ckpt("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "w/o self-attn")
chart.add_ckpt("output/syncnet/train-2024_10_24-21:03:11/checkpoints/checkpoint-10000.pt", "Dim 512")
chart.add_ckpt("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "Dim 2048")
chart.add_ckpt("output/syncnet/train-2024_10_24-22:37:04/checkpoints/checkpoint-10000.pt", "Dim 4096")
chart.add_ckpt("output/syncnet/train-2024_10_25-02:30:17/checkpoints/checkpoint-10000.pt", "Dim 6144")
chart.draw("ablation.pdf", plot_val=True)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mediapipe as mp
import cv2
from decord import VideoReader
from einops import rearrange
import os
import numpy as np
import torch
import tqdm
from eval.fvd import compute_our_fvd
class FVD:
def __init__(self, resolution=(224, 224)):
self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
self.resolution = resolution
def detect_face(self, image):
height, width = image.shape[:2]
# Process the image and detect faces.
results = self.face_detector.process(image)
if not results.detections: # Face not detected
raise Exception("Face not detected")
detection = results.detections[0] # Only use the first face in the image
bounding_box = detection.location_data.relative_bounding_box
xmin = int(bounding_box.xmin * width)
ymin = int(bounding_box.ymin * height)
face_width = int(bounding_box.width * width)
face_height = int(bounding_box.height * height)
# Crop the image to the bounding box.
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = min(width, xmin + face_width)
ymax = min(height, ymin + face_height)
image = image[ymin:ymax, xmin:xmax]
return image
def detect_video(self, video_path, real: bool = True):
vr = VideoReader(video_path)
video_frames = vr[20:36].asnumpy() # Use 16 consecutive frames (frames 20-35)
vr.seek(0) # avoid memory leak
faces = []
for frame in video_frames:
face = self.detect_face(frame)
face = cv2.resize(face, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_AREA)
faces.append(face)
if len(faces) != 16:
return None
faces = np.stack(faces, axis=0) # (f, h, w, c)
faces = torch.from_numpy(faces)
return faces
def eval_fvd(real_videos_dir, fake_videos_dir):
fvd = FVD()
real_features_list = []
fake_features_list = []
for file in tqdm.tqdm(os.listdir(fake_videos_dir)):
if file.endswith(".mp4"):
real_video_path = os.path.join(real_videos_dir, file.replace("_out.mp4", ".mp4"))
fake_video_path = os.path.join(fake_videos_dir, file)
real_features = fvd.detect_video(real_video_path, real=True)
fake_features = fvd.detect_video(fake_video_path, real=False)
if real_features is None or fake_features is None:
continue
real_features_list.append(real_features)
fake_features_list.append(fake_features)
real_features = torch.stack(real_features_list) / 255.0
fake_features = torch.stack(fake_features_list) / 255.0
print(compute_our_fvd(real_features, fake_features, device="cpu"))
if __name__ == "__main__":
real_videos_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/cross"
fake_videos_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/latentsync_cross"
eval_fvd(real_videos_dir, fake_videos_dir)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import tqdm
from statistics import fmean
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from latentsync.utils.util import red_text
import torch
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
syncnet_detector(video_path=video_path, min_track=50)
crop_videos = os.listdir(os.path.join(detect_results_dir, "crop"))
if crop_videos == []:
raise Exception(red_text(f"Face not detected in {video_path}"))
av_offset_list = []
conf_list = []
for video in crop_videos:
av_offset, _, conf = syncnet.evaluate(
video_path=os.path.join(detect_results_dir, "crop", video), temp_dir=temp_dir
)
av_offset_list.append(av_offset)
conf_list.append(conf)
av_offset = int(fmean(av_offset_list))
conf = fmean(conf_list)
print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")
return av_offset, conf
def main():
parser = argparse.ArgumentParser(description="SyncNet")
parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
parser.add_argument("--video_path", type=str, default=None, help="")
parser.add_argument("--videos_dir", type=str, default="/root/processed")
parser.add_argument("--temp_dir", type=str, default="temp", help="")
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
syncnet = SyncNetEval(device=device)
syncnet.loadParameters(args.initial_model)
syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
if args.video_path is not None:
syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
else:
sync_conf_list = []
video_names = sorted([f for f in os.listdir(args.videos_dir) if f.endswith(".mp4")])
for video_name in tqdm.tqdm(video_names):
try:
_, conf = syncnet_eval(
syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
)
sync_conf_list.append(conf)
except Exception as e:
print(e)
print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")
if __name__ == "__main__":
main()
#!/bin/bash
python -m eval.eval_sync_conf --video_path "RD_Radio1_000_006_out.mp4"
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from einops import rearrange
from latentsync.models.syncnet import SyncNet
from latentsync.data.syncnet_dataset import SyncNetDataset
from diffusers import AutoencoderKL
from omegaconf import OmegaConf
from accelerate.utils import set_seed
def main(config):
set_seed(config.run.seed)
device = "cuda" if torch.cuda.is_available() else "cpu"
if config.data.latent_space:
vae = AutoencoderKL.from_pretrained(
"runwayml/stable-diffusion-inpainting", subfolder="vae", revision="fp16", torch_dtype=torch.float16
)
vae.requires_grad_(False)
vae.to(device)
# Dataset and Dataloader setup
dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
test_dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=config.data.batch_size,
shuffle=False,
num_workers=config.data.num_workers,
drop_last=False,
worker_init_fn=dataset.worker_init_fn,
)
# Model
syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
print(f"Load checkpoint from: {config.ckpt.inference_ckpt_path}")
checkpoint = torch.load(config.ckpt.inference_ckpt_path, map_location=device)
syncnet.load_state_dict(checkpoint["state_dict"])
syncnet.to(dtype=torch.float16)
syncnet.requires_grad_(False)
syncnet.eval()
global_step = 0
num_val_batches = config.data.num_val_samples // config.data.batch_size
progress_bar = tqdm(range(0, num_val_batches), initial=0, desc="Testing accuracy")
num_correct_preds = 0
num_total_preds = 0
while True:
for step, batch in enumerate(test_dataloader):
### >>>> Test >>>> ###
frames = batch["frames"].to(device, dtype=torch.float16)
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
y = batch["y"].to(device, dtype=torch.float16).squeeze(1)
if config.data.latent_space:
frames = rearrange(frames, "b f c h w -> (b f) c h w")
with torch.no_grad():
frames = vae.encode(frames).latent_dist.sample() * 0.18215
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
else:
frames = rearrange(frames, "b f c h w -> b (f c) h w")
if config.data.lower_half:
height = frames.shape[2]
frames = frames[:, :, height // 2 :, :]
with torch.no_grad():
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
preds = (sims > 0.5).to(dtype=torch.float16)
num_correct_preds += (preds == y).sum().item()
num_total_preds += len(sims)
progress_bar.update(1)
global_step += 1
if global_step >= num_val_batches:
progress_bar.close()
print(f"Accuracy score: {num_correct_preds / num_total_preds*100:.2f}%")
return
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to test the accuracy of expert lip-sync discriminator")
parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_latent.yaml")
args = parser.parse_args()
# Load a configuration file
config = OmegaConf.load(args.config_path)
main(config)
#!/bin/bash
python -m eval.eval_syncnet_acc --config_path "configs/syncnet/syncnet_16_pixel.yaml"
# Adapted from https://github.com/universome/fvd-comparison/blob/master/our_fvd.py
from typing import Tuple
import scipy
import numpy as np
import torch
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
mu_gen, sigma_gen = compute_stats(feats_fake)
mu_real, sigma_real = compute_stats(feats_real)
m = np.square(mu_gen - mu_real).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
return float(fid)
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
mu = feats.mean(axis=0) # [d]
sigma = np.cov(feats, rowvar=False) # [d, d]
return mu, sigma
@torch.no_grad()
def compute_our_fvd(videos_fake: np.ndarray, videos_real: np.ndarray, device: str = "cuda") -> float:
i3d_path = "checkpoints/auxiliary/i3d_torchscript.pt"
i3d_kwargs = dict(
rescale=False, resize=False, return_features=True
) # Return raw features before the softmax layer.
with open(i3d_path, "rb") as f:
i3d_model = torch.jit.load(f).eval().to(device)
videos_fake = videos_fake.permute(0, 4, 1, 2, 3).to(device)
videos_real = videos_real.permute(0, 4, 1, 2, 3).to(device)
feats_fake = i3d_model(videos_fake, **i3d_kwargs).cpu().numpy()
feats_real = i3d_model(videos_real, **i3d_kwargs).cpu().numpy()
return compute_fvd(feats_fake, feats_real)
def main():
# input shape: (b, f, h, w, c)
videos_fake = torch.rand(10, 16, 224, 224, 3)
videos_real = torch.rand(10, 16, 224, 224, 3)
our_fvd_result = compute_our_fvd(videos_fake, videos_real)
print(f"[FVD scores] Ours: {our_fvd_result}")
if __name__ == "__main__":
main()
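
For reference, `compute_fvd` above is the Fréchet distance between Gaussian fits of the two feature sets, ||mu_fake - mu_real||^2 + Tr(Sigma_fake + Sigma_real - 2 (Sigma_fake Sigma_real)^{1/2}). A quick, hedged sanity check of the statistics part alone (no I3D model needed; the feature sizes are made up):

```python
# Sanity check: identical feature sets give FVD ~ 0, shifted Gaussians give more.
import numpy as np

rng = np.random.default_rng(0)
feats_a = rng.normal(size=(512, 64))           # 512 samples of 64-dim features (illustrative)
feats_b = rng.normal(loc=1.0, size=(512, 64))

print(compute_fvd(feats_a, feats_a))   # ~0, up to numerical error in the matrix square root
print(compute_fvd(feats_a, feats_b))   # noticeably larger
```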
# Adapted from https://github.com/SSL92/hyperIQA/blob/master/models.py
import torch as torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
import math
import torch.utils.model_zoo as model_zoo
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
class HyperNet(nn.Module):
"""
Hyper network for learning perceptual rules.
Args:
lda_out_channels: local distortion aware module output size.
hyper_in_channels: input feature channels for hyper network.
target_in_size: input vector size for target network.
target_fc(i)_size: fully connected layer size of the target network.
feature_size: input feature map width/height for hyper network.
Note:
For the sizes to match, the input args must satisfy that 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'.
"""
def __init__(self, lda_out_channels, hyper_in_channels, target_in_size, target_fc1_size, target_fc2_size, target_fc3_size, target_fc4_size, feature_size):
super(HyperNet, self).__init__()
self.hyperInChn = hyper_in_channels
self.target_in_size = target_in_size
self.f1 = target_fc1_size
self.f2 = target_fc2_size
self.f3 = target_fc3_size
self.f4 = target_fc4_size
self.feature_size = feature_size
self.res = resnet50_backbone(lda_out_channels, target_in_size, pretrained=True)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
# Conv layers for resnet output features
self.conv1 = nn.Sequential(
nn.Conv2d(2048, 1024, 1, padding=(0, 0)),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 512, 1, padding=(0, 0)),
nn.ReLU(inplace=True),
nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)),
nn.ReLU(inplace=True)
)
# Hyper network part, conv for generating target fc weights, fc for generating target fc biases
self.fc1w_conv = nn.Conv2d(self.hyperInChn, int(self.target_in_size * self.f1 / feature_size ** 2), 3, padding=(1, 1))
self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1)
self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size ** 2), 3, padding=(1, 1))
self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2)
self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size ** 2), 3, padding=(1, 1))
self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3)
self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size ** 2), 3, padding=(1, 1))
self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4)
self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4)
self.fc5b_fc = nn.Linear(self.hyperInChn, 1)
# initialize
for i, m_name in enumerate(self._modules):
if i > 2:
nn.init.kaiming_normal_(self._modules[m_name].weight.data)
def forward(self, img):
feature_size = self.feature_size
res_out = self.res(img)
# input vector for target net
target_in_vec = res_out['target_in_vec'].reshape(-1, self.target_in_size, 1, 1)
# input features for hyper net
hyper_in_feat = self.conv1(res_out['hyper_in_feat']).reshape(-1, self.hyperInChn, feature_size, feature_size)
# generating target net weights & biases
target_fc1w = self.fc1w_conv(hyper_in_feat).reshape(-1, self.f1, self.target_in_size, 1, 1)
target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f1)
target_fc2w = self.fc2w_conv(hyper_in_feat).reshape(-1, self.f2, self.f1, 1, 1)
target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f2)
target_fc3w = self.fc3w_conv(hyper_in_feat).reshape(-1, self.f3, self.f2, 1, 1)
target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f3)
target_fc4w = self.fc4w_conv(hyper_in_feat).reshape(-1, self.f4, self.f3, 1, 1)
target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, self.f4)
target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, 1, self.f4, 1, 1)
target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).reshape(-1, 1)
out = {}
out['target_in_vec'] = target_in_vec
out['target_fc1w'] = target_fc1w
out['target_fc1b'] = target_fc1b
out['target_fc2w'] = target_fc2w
out['target_fc2b'] = target_fc2b
out['target_fc3w'] = target_fc3w
out['target_fc3b'] = target_fc3b
out['target_fc4w'] = target_fc4w
out['target_fc4b'] = target_fc4b
out['target_fc5w'] = target_fc5w
out['target_fc5b'] = target_fc5b
return out
class TargetNet(nn.Module):
"""
Target network for quality prediction.
"""
def __init__(self, paras):
super(TargetNet, self).__init__()
self.l1 = nn.Sequential(
TargetFC(paras['target_fc1w'], paras['target_fc1b']),
nn.Sigmoid(),
)
self.l2 = nn.Sequential(
TargetFC(paras['target_fc2w'], paras['target_fc2b']),
nn.Sigmoid(),
)
self.l3 = nn.Sequential(
TargetFC(paras['target_fc3w'], paras['target_fc3b']),
nn.Sigmoid(),
)
self.l4 = nn.Sequential(
TargetFC(paras['target_fc4w'], paras['target_fc4b']),
nn.Sigmoid(),
TargetFC(paras['target_fc5w'], paras['target_fc5b']),
)
def forward(self, x):
q = self.l1(x)
# q = F.dropout(q)
q = self.l2(q)
q = self.l3(q)
q = self.l4(q).squeeze()
return q
class TargetFC(nn.Module):
"""
Fully connected operations for the target net.
Note:
Weights & biases differ across images in a batch,
so grouped convolution is used to apply each image's individual weights & biases.
"""
def __init__(self, weight, bias):
super(TargetFC, self).__init__()
self.weight = weight
self.bias = bias
def forward(self, input_):
input_re = input_.reshape(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3])
weight_re = self.weight.reshape(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], self.weight.shape[3], self.weight.shape[4])
bias_re = self.bias.reshape(self.bias.shape[0] * self.bias.shape[1])
out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0])
return out.reshape(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3])
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNetBackbone(nn.Module):
def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000):
super(ResNetBackbone, self).__init__()
self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# local distortion aware module
self.lda1_pool = nn.Sequential(
nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
nn.AvgPool2d(7, stride=7),
)
self.lda1_fc = nn.Linear(16 * 64, lda_out_channels)
self.lda2_pool = nn.Sequential(
nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
nn.AvgPool2d(7, stride=7),
)
self.lda2_fc = nn.Linear(32 * 16, lda_out_channels)
self.lda3_pool = nn.Sequential(
nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
nn.AvgPool2d(7, stride=7),
)
self.lda3_fc = nn.Linear(64 * 4, lda_out_channels)
self.lda4_pool = nn.AvgPool2d(7, stride=7)
self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
# initialize
nn.init.kaiming_normal_(self.lda1_pool._modules['0'].weight.data)
nn.init.kaiming_normal_(self.lda2_pool._modules['0'].weight.data)
nn.init.kaiming_normal_(self.lda3_pool._modules['0'].weight.data)
nn.init.kaiming_normal_(self.lda1_fc.weight.data)
nn.init.kaiming_normal_(self.lda2_fc.weight.data)
nn.init.kaiming_normal_(self.lda3_fc.weight.data)
nn.init.kaiming_normal_(self.lda4_fc.weight.data)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
# the same effect as the lda operation in the paper, but saves much more memory
lda_1 = self.lda1_fc(self.lda1_pool(x).reshape(x.size(0), -1))
x = self.layer2(x)
lda_2 = self.lda2_fc(self.lda2_pool(x).reshape(x.size(0), -1))
x = self.layer3(x)
lda_3 = self.lda3_fc(self.lda3_pool(x).reshape(x.size(0), -1))
x = self.layer4(x)
lda_4 = self.lda4_fc(self.lda4_pool(x).reshape(x.size(0), -1))
vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1)
out = {}
out['hyper_in_feat'] = x
out['target_in_vec'] = vec
return out
def resnet50_backbone(lda_out_channels, in_chn, pretrained=False, **kwargs):
"""Constructs a ResNet-50 model_hyper.
Args:
pretrained (bool): If True, returns a model_hyper pre-trained on ImageNet
"""
model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
save_model = model_zoo.load_url(model_urls['resnet50'])
model_dict = model.state_dict()
state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
model_dict.update(state_dict)
model.load_state_dict(model_dict)
else:
model.apply(weights_init_xavier)
return model
def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
# if isinstance(m, nn.Conv2d):
if classname.find('Conv') != -1:
init.kaiming_normal_(m.weight.data)
elif classname.find('Linear') != -1:
init.kaiming_normal_(m.weight.data)
elif classname.find('BatchNorm2d') != -1:
init.uniform_(m.weight.data, 1.0, 0.02)
init.constant_(m.bias.data, 0.0)
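
To show how the pieces above are intended to connect, here is a hedged end-to-end sketch: `HyperNet` predicts per-image fc weights and biases, and `TargetNet` consumes them to output a quality score. The channel sizes below satisfy the divisibility constraint in the `HyperNet` docstring (with a 224x224 input, resnet50's layer4 yields a 7x7 map, so feature_size=7), but they are assumptions of this sketch, not values fixed by this file.

```python
# Sketch: run HyperNet on one image, build TargetNet from its output, get a score.
import torch

model_hyper = HyperNet(
    lda_out_channels=16, hyper_in_channels=112, target_in_size=224,
    target_fc1_size=112, target_fc2_size=56, target_fc3_size=28,
    target_fc4_size=14, feature_size=7,
)  # note: the resnet50 backbone is built with pretrained=True, so ImageNet weights are downloaded
model_hyper.eval()

img = torch.rand(1, 3, 224, 224)                  # a normalized 224x224 crop in real use
with torch.no_grad():
    paras = model_hyper(img)                      # dict of target-net weights/biases + input vector
    model_target = TargetNet(paras)
    score = model_target(paras['target_in_vec'])  # scalar quality prediction
print(score)
```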
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from tqdm import tqdm
def inference_video_from_dir(input_dir, output_dir, unet_config_path, ckpt_path):
os.makedirs(output_dir, exist_ok=True)
video_names = sorted([f for f in os.listdir(input_dir) if f.endswith(".mp4")])
for video_name in tqdm(video_names):
video_path = os.path.join(input_dir, video_name)
audio_path = os.path.join(input_dir, video_name.replace(".mp4", "_audio.wav"))
video_out_path = os.path.join(output_dir, video_name.replace(".mp4", "_out.mp4"))
inference_command = f"python inference.py --unet_config_path {unet_config_path} --video_path {video_path} --audio_path {audio_path} --video_out_path {video_out_path} --inference_ckpt_path {ckpt_path} --seed 1247"
subprocess.run(inference_command, shell=True)
if __name__ == "__main__":
input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/segmented/cross"
output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/segmented/latentsync_cross"
unet_config_path = "configs/unet/unet_latent_16_diffusion.yaml"
ckpt_path = "output/unet/train-2024_10_08-16:23:43/checkpoints/checkpoint-1920000.pt"
inference_video_from_dir(input_dir, output_dir, unet_config_path, ckpt_path)
from .syncnet_eval import SyncNetEval
# https://github.com/joonson/syncnet_python/blob/master/SyncNetModel.py
import torch
import torch.nn as nn
def save(model, filename):
with open(filename, "wb") as f:
torch.save(model, f)
print("%s saved." % filename)
def load(filename):
net = torch.load(filename)
return net
class S(nn.Module):
def __init__(self, num_layers_in_fc_layers=1024):
super(S, self).__init__()
self.__nFeatures__ = 24
self.__nChs__ = 32
self.__midChs__ = 32
self.netcnnaud = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(1, 1), stride=(1, 1)),
nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(192),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 2)),
nn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
nn.BatchNorm2d(384),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
nn.Conv2d(256, 512, kernel_size=(5, 4), padding=(0, 0)),
nn.BatchNorm2d(512),
nn.ReLU(),
)
self.netfcaud = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, num_layers_in_fc_layers),
)
self.netfclip = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, num_layers_in_fc_layers),
)
self.netcnnlip = nn.Sequential(
nn.Conv3d(3, 96, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=0),
nn.BatchNorm3d(96),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)),
nn.Conv3d(96, 256, kernel_size=(1, 5, 5), stride=(1, 2, 2), padding=(0, 1, 1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.Conv3d(256, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True),
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2)),
nn.Conv3d(256, 512, kernel_size=(1, 6, 6), padding=0),
nn.BatchNorm3d(512),
nn.ReLU(inplace=True),
)
def forward_aud(self, x):
mid = self.netcnnaud(x)
# N x ch x 24 x M
mid = mid.view((mid.size()[0], -1))
# N x (ch x 24)
out = self.netfcaud(mid)
return out
def forward_lip(self, x):
mid = self.netcnnlip(x)
mid = mid.view((mid.size()[0], -1))
# N x (ch x 24)
out = self.netfclip(mid)
return out
def forward_lipfeat(self, x):
mid = self.netcnnlip(x)
out = mid.view((mid.size()[0], -1))
# N x (ch x 24)
return out
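
As a hedged shape check for the `S` model above: the lip stream consumes a 5-frame 224x224 clip and the audio stream a 20-step window of 13 MFCC coefficients (0.2 s of audio, matching 5 video frames at 25 fps), which mirrors how `evaluate()` batches the inputs further down.

```python
# Shape check with dummy tensors (eval mode so BatchNorm uses running stats).
import torch

model = S(num_layers_in_fc_layers=1024).eval()
lip_clip = torch.rand(2, 3, 5, 224, 224)   # (N, C, T=5 frames, H, W)
mfcc_win = torch.rand(2, 1, 13, 20)        # (N, 1, n_mfcc=13, 20 audio steps)
with torch.no_grad():
    lip_embed = model.forward_lip(lip_clip)   # (2, 1024)
    aud_embed = model.forward_aud(mfcc_win)   # (2, 1024)
print(lip_embed.shape, aud_embed.shape)
```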
# Adapted from https://github.com/joonson/syncnet_python/blob/master/SyncNetInstance.py
import torch
import numpy
import time, pdb, argparse, subprocess, os, math, glob
import cv2
import python_speech_features
from scipy import signal
from scipy.io import wavfile
from .syncnet import S
from shutil import rmtree
# ==================== Get OFFSET ====================
# Video: 25 FPS, Audio: 16000 Hz
def calc_pdist(feat1, feat2, vshift=10):
win_size = vshift * 2 + 1
feat2p = torch.nn.functional.pad(feat2, (0, 0, vshift, vshift))
dists = []
for i in range(0, len(feat1)):
dists.append(
torch.nn.functional.pairwise_distance(feat1[[i], :].repeat(win_size, 1), feat2p[i : i + win_size, :])
)
return dists
# ==================== MAIN DEF ====================
class SyncNetEval(torch.nn.Module):
def __init__(self, dropout=0, num_layers_in_fc_layers=1024, device="cpu"):
super().__init__()
self.__S__ = S(num_layers_in_fc_layers=num_layers_in_fc_layers).to(device)
self.device = device
def evaluate(self, video_path, temp_dir="temp", batch_size=20, vshift=15):
self.__S__.eval()
# ========== ==========
# Convert files
# ========== ==========
if os.path.exists(temp_dir):
rmtree(temp_dir)
os.makedirs(temp_dir)
# temp_video_path = os.path.join(temp_dir, "temp.mp4")
# command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -vf scale='224:224' {temp_video_path}"
# subprocess.call(command, shell=True)
command = (
f"ffmpeg -loglevel error -nostdin -y -i {video_path} -f image2 {os.path.join(temp_dir, '%06d.jpg')}"
)
subprocess.call(command, shell=True, stdout=None)
command = f"ffmpeg -loglevel error -nostdin -y -i {video_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(temp_dir, 'audio.wav')}"
subprocess.call(command, shell=True, stdout=None)
# ========== ==========
# Load video
# ========== ==========
images = []
flist = glob.glob(os.path.join(temp_dir, "*.jpg"))
flist.sort()
for fname in flist:
img_input = cv2.imread(fname)
img_input = cv2.resize(img_input, (224, 224)) # HARD CODED, CHANGE BEFORE RELEASE
images.append(img_input)
im = numpy.stack(images, axis=3)
im = numpy.expand_dims(im, axis=0)
im = numpy.transpose(im, (0, 3, 4, 1, 2))
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
# ========== ==========
# Load audio
# ========== ==========
sample_rate, audio = wavfile.read(os.path.join(temp_dir, "audio.wav"))
mfcc = zip(*python_speech_features.mfcc(audio, sample_rate))
mfcc = numpy.stack([numpy.array(i) for i in mfcc])
cc = numpy.expand_dims(numpy.expand_dims(mfcc, axis=0), axis=0)
cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
# ========== ==========
# Check audio and video input length
# ========== ==========
# if (float(len(audio)) / 16000) != (float(len(images)) / 25):
# print(
# "WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."
# % (float(len(audio)) / 16000, float(len(images)) / 25)
# )
min_length = min(len(images), math.floor(len(audio) / 640))
# ========== ==========
# Generate video and audio feats
# ========== ==========
lastframe = min_length - 5
im_feat = []
cc_feat = []
tS = time.time()
for i in range(0, lastframe, batch_size):
im_batch = [imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + batch_size))]
im_in = torch.cat(im_batch, 0)
im_out = self.__S__.forward_lip(im_in.to(self.device))
im_feat.append(im_out.data.cpu())
cc_batch = [
cct[:, :, :, vframe * 4 : vframe * 4 + 20] for vframe in range(i, min(lastframe, i + batch_size))
]
cc_in = torch.cat(cc_batch, 0)
cc_out = self.__S__.forward_aud(cc_in.to(self.device))
cc_feat.append(cc_out.data.cpu())
im_feat = torch.cat(im_feat, 0)
cc_feat = torch.cat(cc_feat, 0)
# ========== ==========
# Compute offset
# ========== ==========
dists = calc_pdist(im_feat, cc_feat, vshift=vshift)
mean_dists = torch.mean(torch.stack(dists, 1), 1)
min_dist, minidx = torch.min(mean_dists, 0)
av_offset = vshift - minidx
conf = torch.median(mean_dists) - min_dist
fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
# fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
fconf = torch.median(mean_dists).numpy() - fdist
framewise_conf = signal.medfilt(fconf, kernel_size=9)
# numpy.set_printoptions(formatter={"float": "{: 0.3f}".format})
rmtree(temp_dir)
return av_offset.item(), min_dist.item(), conf.item()
def extract_feature(self, opt, videofile):
self.__S__.eval()
# ========== ==========
# Load video
# ========== ==========
cap = cv2.VideoCapture(videofile)
frame_num = 1
images = []
while frame_num:
frame_num += 1
ret, image = cap.read()
if ret == 0:
break
images.append(image)
im = numpy.stack(images, axis=3)
im = numpy.expand_dims(im, axis=0)
im = numpy.transpose(im, (0, 3, 4, 1, 2))
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
# ========== ==========
# Generate video feats
# ========== ==========
lastframe = len(images) - 4
im_feat = []
tS = time.time()
for i in range(0, lastframe, opt.batch_size):
im_batch = [
imtv[:, :, vframe : vframe + 5, :, :] for vframe in range(i, min(lastframe, i + opt.batch_size))
]
im_in = torch.cat(im_batch, 0)
im_out = self.__S__.forward_lipfeat(im_in.to(self.device))
im_feat.append(im_out.data.cpu())
im_feat = torch.cat(im_feat, 0)
# ========== ==========
# Compute offset
# ========== ==========
print("Compute time %.3f sec." % (time.time() - tS))
return im_feat
def loadParameters(self, path):
loaded_state = torch.load(path, map_location=lambda storage, loc: storage)
self_state = self.__S__.state_dict()
for name, param in loaded_state.items():
self_state[name].copy_(param)
# Adapted from https://github.com/joonson/syncnet_python/blob/master/run_pipeline.py
import os, pdb, subprocess, glob, cv2
import numpy as np
from shutil import rmtree
import torch
from scenedetect.video_manager import VideoManager
from scenedetect.scene_manager import SceneManager
from scenedetect.stats_manager import StatsManager
from scenedetect.detectors import ContentDetector
from scipy.interpolate import interp1d
from scipy.io import wavfile
from scipy import signal
from eval.detectors import S3FD
class SyncNetDetector:
def __init__(self, device, detect_results_dir="detect_results"):
self.s3f_detector = S3FD(device=device)
self.detect_results_dir = detect_results_dir
def __call__(self, video_path: str, min_track=50, scale=False):
crop_dir = os.path.join(self.detect_results_dir, "crop")
video_dir = os.path.join(self.detect_results_dir, "video")
frames_dir = os.path.join(self.detect_results_dir, "frames")
temp_dir = os.path.join(self.detect_results_dir, "temp")
# ========== DELETE EXISTING DIRECTORIES ==========
if os.path.exists(crop_dir):
rmtree(crop_dir)
if os.path.exists(video_dir):
rmtree(video_dir)
if os.path.exists(frames_dir):
rmtree(frames_dir)
if os.path.exists(temp_dir):
rmtree(temp_dir)
# ========== MAKE NEW DIRECTORIES ==========
os.makedirs(crop_dir)
os.makedirs(video_dir)
os.makedirs(frames_dir)
os.makedirs(temp_dir)
# ========== CONVERT VIDEO AND EXTRACT FRAMES ==========
if scale:
scaled_video_path = os.path.join(video_dir, "scaled.mp4")
command = f"ffmpeg -loglevel error -y -nostdin -i {video_path} -vf scale='224:224' {scaled_video_path}"
subprocess.run(command, shell=True)
video_path = scaled_video_path
command = f"ffmpeg -y -nostdin -loglevel error -i {video_path} -qscale:v 2 -async 1 -r 25 {os.path.join(video_dir, 'video.mp4')}"
subprocess.run(command, shell=True, stdout=None)
command = f"ffmpeg -y -nostdin -loglevel error -i {os.path.join(video_dir, 'video.mp4')} -qscale:v 2 -f image2 {os.path.join(frames_dir, '%06d.jpg')}"
subprocess.run(command, shell=True, stdout=None)
command = f"ffmpeg -y -nostdin -loglevel error -i {os.path.join(video_dir, 'video.mp4')} -ac 1 -vn -acodec pcm_s16le -ar 16000 {os.path.join(video_dir, 'audio.wav')}"
subprocess.run(command, shell=True, stdout=None)
faces = self.detect_face(frames_dir)
scene = self.scene_detect(video_dir)
# Face tracking
alltracks = []
for shot in scene:
if shot[1].frame_num - shot[0].frame_num >= min_track:
alltracks.extend(self.track_face(faces[shot[0].frame_num : shot[1].frame_num], min_track=min_track))
# Face crop
for ii, track in enumerate(alltracks):
self.crop_video(track, os.path.join(crop_dir, "%05d" % ii), frames_dir, 25, temp_dir, video_dir)
rmtree(temp_dir)
def scene_detect(self, video_dir):
video_manager = VideoManager([os.path.join(video_dir, "video.mp4")])
stats_manager = StatsManager()
scene_manager = SceneManager(stats_manager)
# Add ContentDetector algorithm (constructor takes detector options like threshold).
scene_manager.add_detector(ContentDetector())
base_timecode = video_manager.get_base_timecode()
video_manager.set_downscale_factor()
video_manager.start()
scene_manager.detect_scenes(frame_source=video_manager)
scene_list = scene_manager.get_scene_list(base_timecode)
if scene_list == []:
scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())]
return scene_list
def track_face(self, scenefaces, num_failed_det=25, min_track=50, min_face_size=100):
iouThres = 0.5 # Minimum IOU between consecutive face detections
tracks = []
while True:
track = []
for framefaces in scenefaces:
for face in framefaces:
if track == []:
track.append(face)
framefaces.remove(face)
elif face["frame"] - track[-1]["frame"] <= num_failed_det:
iou = bounding_box_iou(face["bbox"], track[-1]["bbox"])
if iou > iouThres:
track.append(face)
framefaces.remove(face)
continue
else:
break
if track == []:
break
elif len(track) > min_track:
framenum = np.array([f["frame"] for f in track])
bboxes = np.array([np.array(f["bbox"]) for f in track])
frame_i = np.arange(framenum[0], framenum[-1] + 1)
bboxes_i = []
for ij in range(0, 4):
interpfn = interp1d(framenum, bboxes[:, ij])
bboxes_i.append(interpfn(frame_i))
bboxes_i = np.stack(bboxes_i, axis=1)
if (
max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1]))
> min_face_size
):
tracks.append({"frame": frame_i, "bbox": bboxes_i})
return tracks
def detect_face(self, frames_dir, facedet_scale=0.25):
flist = glob.glob(os.path.join(frames_dir, "*.jpg"))
flist.sort()
dets = []
for fidx, fname in enumerate(flist):
image = cv2.imread(fname)
image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
bboxes = self.s3f_detector.detect_faces(image_np, conf_th=0.9, scales=[facedet_scale])
dets.append([])
for bbox in bboxes:
dets[-1].append({"frame": fidx, "bbox": (bbox[:-1]).tolist(), "conf": bbox[-1]})
return dets
def crop_video(self, track, cropfile, frames_dir, frame_rate, temp_dir, video_dir, crop_scale=0.4):
flist = glob.glob(os.path.join(frames_dir, "*.jpg"))
flist.sort()
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
vOut = cv2.VideoWriter(cropfile + "t.mp4", fourcc, frame_rate, (224, 224))
dets = {"x": [], "y": [], "s": []}
for det in track["bbox"]:
dets["s"].append(max((det[3] - det[1]), (det[2] - det[0])) / 2)
dets["y"].append((det[1] + det[3]) / 2) # crop center x
dets["x"].append((det[0] + det[2]) / 2) # crop center y
# Smooth detections
dets["s"] = signal.medfilt(dets["s"], kernel_size=13)
dets["x"] = signal.medfilt(dets["x"], kernel_size=13)
dets["y"] = signal.medfilt(dets["y"], kernel_size=13)
for fidx, frame in enumerate(track["frame"]):
cs = crop_scale
bs = dets["s"][fidx] # Detection box size
bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount
image = cv2.imread(flist[frame])
frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), "constant", constant_values=(110, 110))
my = dets["y"][fidx] + bsi # BBox center Y
mx = dets["x"][fidx] + bsi # BBox center X
face = frame[int(my - bs) : int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)) : int(mx + bs * (1 + cs))]
vOut.write(cv2.resize(face, (224, 224)))
audiotmp = os.path.join(temp_dir, "audio.wav")
audiostart = (track["frame"][0]) / frame_rate
audioend = (track["frame"][-1] + 1) / frame_rate
vOut.release()
# ========== CROP AUDIO FILE ==========
command = "ffmpeg -y -nostdin -loglevel error -i %s -ss %.3f -to %.3f %s" % (
os.path.join(video_dir, "audio.wav"),
audiostart,
audioend,
audiotmp,
)
output = subprocess.run(command, shell=True, stdout=None)
sample_rate, audio = wavfile.read(audiotmp)
# ========== COMBINE AUDIO AND VIDEO FILES ==========
command = "ffmpeg -y -nostdin -loglevel error -i %st.mp4 -i %s -c:v copy -c:a aac %s.mp4" % (
cropfile,
audiotmp,
cropfile,
)
output = subprocess.run(command, shell=True, stdout=None)
os.remove(cropfile + "t.mp4")
return {"track": track, "proc_track": dets}
def bounding_box_iou(boxA, boxB):
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
interArea = max(0, xB - xA) * max(0, yB - yA)
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
iou = interArea / float(boxAArea + boxBArea - interArea)
return iou
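
The tracker above keeps a detection in a track when its IoU with the previous box exceeds `iouThres = 0.5`. A quick check of `bounding_box_iou` with corner-format boxes:

```python
# [x1, y1, x2, y2] boxes; the two 10x10 boxes overlap in a 5x5 region.
box_a = [0, 0, 10, 10]
box_b = [5, 5, 15, 15]
print(bounding_box_iou(box_a, box_b))  # 25 / (100 + 100 - 25) ≈ 0.143
```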
#!/bin/bash
python -m scripts.inference \
--unet_config_path "configs/unet/second_stage.yaml" \
--inference_ckpt_path "checkpoints/latentsync_unet.pt" \
--guidance_scale 1.0 \
--video_path "assets/demo1_video.mp4" \
--audio_path "assets/demo1_audio.wav" \
--video_out_path "video_out.mp4"
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
from ..utils.util import gather_video_paths_recursively
from ..utils.image_processor import ImageProcessor
from ..utils.audio import melspectrogram
import math
from decord import AudioReader, VideoReader, cpu
class SyncNetDataset(Dataset):
def __init__(self, data_dir: str, fileslist: str, config):
if fileslist != "":
with open(fileslist) as file:
self.video_paths = [line.rstrip() for line in file]
elif data_dir != "":
self.video_paths = gather_video_paths_recursively(data_dir)
else:
raise ValueError("data_dir and fileslist cannot be both empty")
self.resolution = config.data.resolution
self.num_frames = config.data.num_frames
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
self.audio_sample_rate = config.data.audio_sample_rate
self.video_fps = config.data.video_fps
self.audio_samples_length = int(
config.data.audio_sample_rate // config.data.video_fps * config.data.num_frames
)
self.image_processor = ImageProcessor(resolution=config.data.resolution, mask="half")
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
def __len__(self):
return len(self.video_paths)
def read_audio(self, video_path: str):
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
return torch.from_numpy(original_mel)
def crop_audio_window(self, original_mel, start_index):
start_idx = int(80.0 * (start_index / float(self.video_fps)))
end_idx = start_idx + self.mel_window_length
return original_mel[:, start_idx:end_idx].unsqueeze(0)
def get_frames(self, video_reader: VideoReader):
total_num_frames = len(video_reader)
start_idx = random.randint(0, total_num_frames - self.num_frames)
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
while True:
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
# wrong_start_idx = random.randint(
# max(0, start_idx - 25), min(total_num_frames - self.num_frames, start_idx + 25)
# )
if wrong_start_idx == start_idx:
continue
# if wrong_start_idx >= start_idx - self.num_frames and wrong_start_idx <= start_idx + self.num_frames:
# continue
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
break
frames = video_reader.get_batch(frames_index).asnumpy()
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
return frames, wrong_frames, start_idx
def worker_init_fn(self, worker_id):
# Initialize the face mesh object in each worker process,
# because the face mesh object cannot be called in subprocesses
self.worker_id = worker_id
# setattr(self, f"image_processor_{worker_id}", ImageProcessor(self.resolution, self.mask))
def __getitem__(self, idx):
# image_processor = getattr(self, f"image_processor_{self.worker_id}")
while True:
try:
idx = random.randint(0, len(self) - 1)
# Get video file path
video_path = self.video_paths[idx]
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
if len(vr) < 2 * self.num_frames:
continue
frames, wrong_frames, start_idx = self.get_frames(vr)
mel_cache_path = os.path.join(
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
)
if os.path.isfile(mel_cache_path):
try:
original_mel = torch.load(mel_cache_path)
except Exception as e:
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
os.remove(mel_cache_path)
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
else:
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
mel = self.crop_audio_window(original_mel, start_idx)
if mel.shape[-1] != self.mel_window_length:
continue
if random.choice([True, False]):
y = torch.ones(1).float()
chosen_frames = frames
else:
y = torch.zeros(1).float()
chosen_frames = wrong_frames
chosen_frames = self.image_processor.process_images(chosen_frames)
# chosen_frames, _, _ = image_processor.prepare_masks_and_masked_images(
# chosen_frames, affine_transform=True
# )
vr.seek(0) # avoid memory leak
break
except Exception as e: # Handle the exception when a face is not detected
print(f"{type(e).__name__} - {e} - {video_path}")
if "vr" in locals():
vr.seek(0) # avoid memory leak
sample = dict(frames=chosen_frames, audio_samples=mel, y=y)
return sample
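
For reference, a hedged sketch of wiring `SyncNetDataset` into a `DataLoader`, mirroring the setup in the SyncNet accuracy script earlier. The config keys are the ones this dataset reads; the values and the fileslist path are illustrative assumptions.

```python
# Sketch: build the dataset from a fileslist and pass worker_init_fn so each
# worker records its id (used as the decord CPU context in read_audio/__getitem__).
import torch
from omegaconf import OmegaConf

config = OmegaConf.create({
    "data": {
        "resolution": 256, "num_frames": 16,
        "audio_sample_rate": 16000, "video_fps": 25,
        "audio_mel_cache_dir": "temp/mel_cache",
        "batch_size": 4, "num_workers": 2,
    }
})
dataset = SyncNetDataset(data_dir="", fileslist="fileslist.txt", config=config)  # fileslist path is a placeholder
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.data.batch_size,
    num_workers=config.data.num_workers,
    worker_init_fn=dataset.worker_init_fn,  # sets self.worker_id in each worker process
)
batch = next(iter(loader))
frames, mel, y = batch["frames"], batch["audio_samples"], batch["y"]
```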