Commit 5c023842 authored by chenpangpang

feat: add LatentSync

parent 822b66ca
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed
from latentsync.whisper.audio2feature import Audio2Feature
def main(config, args):
print(f"Input video path: {args.video_path}")
print(f"Input audio path: {args.audio_path}")
print(f"Loaded checkpoint path: {args.inference_ckpt_path}")
scheduler = DDIMScheduler.from_pretrained("configs")
if config.model.cross_attention_dim == 768:
whisper_model_path = "checkpoints/whisper/small.pt"
elif config.model.cross_attention_dim == 384:
whisper_model_path = "checkpoints/whisper/tiny.pt"
else:
raise NotImplementedError("cross_attention_dim must be 768 or 384")
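# Whisper tiny produces 384-dim audio features and Whisper small produces 768-dim, so the checkpoint is chosen to match the UNet's cross_attention_dim.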
audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
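# Override the VAE config: Stable Diffusion latents are scaled by 0.18215, and sd-vae-ft-mse uses no latent shift.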
vae.config.scaling_factor = 0.18215
vae.config.shift_factor = 0
unet, _ = UNet3DConditionModel.from_pretrained(
OmegaConf.to_container(config.model),
args.inference_ckpt_path, # load checkpoint
device="cpu",
)
unet = unet.to(dtype=torch.float16)
# set xformers
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
pipeline = LipsyncPipeline(
vae=vae,
audio_encoder=audio_encoder,
unet=unet,
scheduler=scheduler,
).to("cuda")
if args.seed != -1:
set_seed(args.seed)
else:
torch.seed()
print(f"Initial seed: {torch.initial_seed()}")
pipeline(
video_path=args.video_path,
audio_path=args.audio_path,
video_out_path=args.video_out_path,
video_mask_path=args.video_out_path.replace(".mp4", "_mask.mp4"),
num_frames=config.data.num_frames,
num_inference_steps=config.run.inference_steps,
guidance_scale=args.guidance_scale,
weight_dtype=torch.float16,
width=config.data.resolution,
height=config.data.resolution,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
parser.add_argument("--inference_ckpt_path", type=str, required=True)
parser.add_argument("--video_path", type=str, required=True)
parser.add_argument("--audio_path", type=str, required=True)
parser.add_argument("--video_out_path", type=str, required=True)
parser.add_argument("--guidance_scale", type=float, default=1.0)
parser.add_argument("--seed", type=int, default=1247)
args = parser.parse_args()
config = OmegaConf.load(args.unet_config_path)
main(config, args)
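# Example invocation (illustrative only; the module path and the checkpoint/asset paths
# below are assumptions, not part of this commit):
#   python -m scripts.inference \
#       --unet_config_path configs/unet.yaml \
#       --inference_ckpt_path checkpoints/latentsync_unet.pt \
#       --video_path assets/demo.mp4 \
#       --audio_path assets/demo.wav \
#       --video_out_path video_out.mp4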
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm.auto import tqdm
import os, argparse, datetime, math
import logging
from omegaconf import OmegaConf
import shutil
from latentsync.data.syncnet_dataset import SyncNetDataset
from latentsync.models.syncnet import SyncNet
from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip
from latentsync.utils.util import gather_loss, plot_loss_chart
from accelerate.utils import set_seed
import torch
from diffusers import AutoencoderKL
from diffusers.utils.logging import get_logger
from einops import rearrange
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from latentsync.utils.util import init_dist, cosine_loss
logger = get_logger(__name__)
def main(config):
# Initialize distributed training
local_rank = init_dist()
global_rank = dist.get_rank()
num_processes = dist.get_world_size()
is_main_process = global_rank == 0
seed = config.run.seed + global_rank
set_seed(seed)
# Logging folder
folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
output_dir = os.path.join(config.data.train_output_dir, folder_name)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Handle the output folder creation
if is_main_process:
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
shutil.copy(config.config_path, output_dir)
device = torch.device(local_rank)
if config.data.latent_space:
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.requires_grad_(False)
vae.to(device)
else:
vae = None
# Dataset and Dataloader setup
train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
train_distributed_sampler = DistributedSampler(
train_dataset,
num_replicas=num_processes,
rank=global_rank,
shuffle=True,
seed=config.run.seed,
)
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.data.batch_size,
shuffle=False,
sampler=train_distributed_sampler,
num_workers=config.data.num_workers,
pin_memory=False,
drop_last=True,
worker_init_fn=train_dataset.worker_init_fn,
)
num_samples_limit = 640
val_batch_size = min(
num_samples_limit // config.data.num_frames, config.data.batch_size
) # limit batch size to avoid CUDA OOM
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=val_batch_size,
shuffle=False,
num_workers=config.data.num_workers,
pin_memory=False,
drop_last=False,
worker_init_fn=val_dataset.worker_init_fn,
)
# Model
syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
# syncnet = SyncNetWav2Lip().to(device)
optimizer = torch.optim.AdamW(
list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
)
if config.ckpt.resume_ckpt_path != "":
if is_main_process:
logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device)
syncnet.load_state_dict(ckpt["state_dict"])
global_step = ckpt["global_step"]
train_step_list = ckpt["train_step_list"]
train_loss_list = ckpt["train_loss_list"]
val_step_list = ckpt["val_step_list"]
val_loss_list = ckpt["val_loss_list"]
else:
global_step = 0
train_step_list = []
train_loss_list = []
val_step_list = []
val_loss_list = []
# DDP wrapper
syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
# validation_steps = int(config.ckpt.save_ckpt_steps // 5)
# validation_steps = 100
if is_main_process:
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {config.data.batch_size * num_processes}")
logger.info(f" Total optimization steps = {config.run.max_train_steps}")
first_epoch = global_step // num_update_steps_per_epoch
num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)
# Only show the progress bar once on each machine.
progress_bar = tqdm(
range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
)
# Support mixed-precision training
scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None
for epoch in range(first_epoch, num_train_epochs):
train_dataloader.sampler.set_epoch(epoch)
syncnet.train()
for step, batch in enumerate(train_dataloader):
### >>>> Training >>>> ###
frames = batch["frames"].to(device, dtype=torch.float16)
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
y = batch["y"].to(device, dtype=torch.float32)
if config.data.latent_space:
max_batch_size = (
num_samples_limit // config.data.num_frames
)  # split the input frames into chunks to limit CUDA memory usage
if frames.shape[0] > max_batch_size:
assert (
frames.shape[0] % max_batch_size == 0
), f"batch size {frames.shape[0]} should be divisible by max_batch_size {max_batch_size}"
frames_part_results = []
for i in range(0, frames.shape[0], max_batch_size):
frames_part = frames[i : i + max_batch_size]
frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
with torch.no_grad():
frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215
frames_part_results.append(frames_part)
frames = torch.cat(frames_part_results, dim=0)
else:
frames = rearrange(frames, "b f c h w -> (b f) c h w")
with torch.no_grad():
frames = vae.encode(frames).latent_dist.sample() * 0.18215
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
else:
frames = rearrange(frames, "b f c h w -> b (f c) h w")
if config.data.lower_half:
height = frames.shape[2]
frames = frames[:, :, height // 2 :, :]
# audio_embeds = wav2vec_encoder(audio_samples).last_hidden_state
# Mixed-precision training
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
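# cosine_loss (imported from latentsync.utils.util) is expected to compare the cosine similarity of the vision and audio embeddings against the sync label y (1 = in sync, 0 = out of sync).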
optimizer.zero_grad()
# Backpropagate
if config.run.mixed_precision_training:
scaler.scale(loss).backward()
""" >>> gradient clipping >>> """
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
""" <<< gradient clipping <<< """
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
""" >>> gradient clipping >>> """
torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
""" <<< gradient clipping <<< """
optimizer.step()
progress_bar.update(1)
global_step += 1
global_average_loss = gather_loss(loss, device)
train_step_list.append(global_step)
train_loss_list.append(global_average_loss)
if is_main_process and global_step % config.run.validation_steps == 0:
logger.info(f"Validation at step {global_step}")
val_loss = validation(
val_dataloader,
device,
syncnet,
cosine_loss,
config.data.latent_space,
config.data.lower_half,
vae,
num_val_batches,
)
val_step_list.append(global_step)
val_loss_list.append(val_loss)
logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")
if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
torch.save(
{
"state_dict": syncnet.module.state_dict(), # to unwrap DDP
"global_step": global_step,
"train_step_list": train_step_list,
"train_loss_list": train_loss_list,
"val_step_list": val_step_list,
"val_loss_list": val_loss_list,
},
checkpoint_save_path,
)
logger.info(f"Saved checkpoint to {checkpoint_save_path}")
plot_loss_chart(
os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
("Train loss", train_step_list, train_loss_list),
("Val loss", val_step_list, val_loss_list),
)
progress_bar.set_postfix({"step_loss": global_average_loss})
if global_step >= config.run.max_train_steps:
break
progress_bar.close()
dist.destroy_process_group()
@torch.no_grad()
def validation(val_dataloader, device, syncnet, cosine_loss, latent_space, lower_half, vae, num_val_batches):
syncnet.eval()
losses = []
val_step = 0
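# Cycle over the validation dataloader (which may hold fewer than num_val_batches batches) until enough batches have been scored, then return the mean loss.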
while True:
for step, batch in enumerate(val_dataloader):
### >>>> Validation >>>> ###
frames = batch["frames"].to(device, dtype=torch.float16)
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
y = batch["y"].to(device, dtype=torch.float32)
if latent_space:
num_frames = frames.shape[1]
frames = rearrange(frames, "b f c h w -> (b f) c h w")
frames = vae.encode(frames).latent_dist.sample() * 0.18215
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames)
else:
frames = rearrange(frames, "b f c h w -> b (f c) h w")
if lower_half:
height = frames.shape[2]
frames = frames[:, :, height // 2 :, :]
with torch.autocast(device_type="cuda", dtype=torch.float16):
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
losses.append(loss.item())
val_step += 1
if val_step > num_val_batches:
syncnet.train()
if len(losses) == 0:
raise RuntimeError("No validation data")
return sum(losses) / len(losses)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to train the expert lip-sync discriminator")
parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_vae.yaml")
args = parser.parse_args()
# Load a configuration file
config = OmegaConf.load(args.config_path)
config.config_path = args.config_path
main(config)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import argparse
import shutil
import datetime
import logging
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed
from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.syncnet import SyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
init_dist,
cosine_loss,
reversed_forward,
)
from latentsync.utils.util import plot_loss_chart, gather_loss
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips
logger = get_logger(__name__)
def main(config):
# Initialize distributed training
local_rank = init_dist()
global_rank = dist.get_rank()
num_processes = dist.get_world_size()
is_main_process = global_rank == 0
seed = config.run.seed + global_rank
set_seed(seed)
# Logging folder
folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
output_dir = os.path.join(config.data.train_output_dir, folder_name)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Handle the output folder creation
if is_main_process:
diffusers.utils.logging.set_verbosity_info()
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
shutil.copy(config.unet_config_path, output_dir)
shutil.copy(config.data.syncnet_config_path, output_dir)
device = torch.device(local_rank)
noise_scheduler = DDIMScheduler.from_pretrained("configs")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.config.scaling_factor = 0.18215
vae.config.shift_factor = 0
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
vae.requires_grad_(False)
vae.to(device)
syncnet_eval_model = SyncNetEval(device=device)
syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")
syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
if config.model.cross_attention_dim == 768:
whisper_model_path = "checkpoints/whisper/small.pt"
elif config.model.cross_attention_dim == 384:
whisper_model_path = "checkpoints/whisper/tiny.pt"
else:
raise NotImplementedError("cross_attention_dim must be 768 or 384")
audio_encoder = Audio2Feature(
model_path=whisper_model_path,
device=device,
audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
num_frames=config.data.num_frames,
)
unet, resume_global_step = UNet3DConditionModel.from_pretrained(
OmegaConf.to_container(config.model),
config.ckpt.resume_ckpt_path, # load checkpoint
device=device,
)
if config.model.add_audio_layer and config.run.use_syncnet:
syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
if syncnet_config.ckpt.inference_ckpt_path == "":
raise ValueError("SyncNet path is not provided")
syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
syncnet.requires_grad_(False)
unet.requires_grad_(True)
trainable_params = list(unet.parameters())
if config.optimizer.scale_lr:
config.optimizer.lr = config.optimizer.lr * num_processes
optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)
if is_main_process:
logger.info(f"trainable params number: {len(trainable_params)}")
logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
# Enable xformers
if config.run.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Enable gradient checkpointing
if config.run.enable_gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Get the training dataset
train_dataset = UNetDataset(config.data.train_data_dir, config)
distributed_sampler = DistributedSampler(
train_dataset,
num_replicas=num_processes,
rank=global_rank,
shuffle=True,
seed=config.run.seed,
)
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.data.batch_size,
shuffle=False,
sampler=distributed_sampler,
num_workers=config.data.num_workers,
pin_memory=False,
drop_last=True,
worker_init_fn=train_dataset.worker_init_fn,
)
# Get the training iteration
if config.run.max_train_steps == -1:
assert config.run.max_train_epochs != -1
config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)
# Scheduler
lr_scheduler = get_scheduler(
config.optimizer.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=config.optimizer.lr_warmup_steps,
num_training_steps=config.run.max_train_steps,
)
if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
lpips_loss_func = lpips.LPIPS(net="vgg").to(device)
if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
trepa_loss_func = TREPALoss(device=device)
# Validation pipeline
pipeline = LipsyncPipeline(
vae=vae,
audio_encoder=audio_encoder,
unet=unet,
scheduler=noise_scheduler,
).to(device)
pipeline.set_progress_bar_config(disable=True)
# DDP wrapper
unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = config.data.batch_size * num_processes
if is_main_process:
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Total optimization steps = {config.run.max_train_steps}")
global_step = resume_global_step
first_epoch = resume_global_step // num_update_steps_per_epoch
# Only show the progress bar once on each machine.
progress_bar = tqdm(
range(0, config.run.max_train_steps),
initial=resume_global_step,
desc="Steps",
disable=not is_main_process,
)
train_step_list = []
sync_loss_list = []
recon_loss_list = []
val_step_list = []
sync_conf_list = []
# Support mixed-precision training
scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None
for epoch in range(first_epoch, num_train_epochs):
train_dataloader.sampler.set_epoch(epoch)
unet.train()
for step, batch in enumerate(train_dataloader):
### >>>> Training >>>> ###
if config.model.add_audio_layer:
if batch["mel"] != []:
mel = batch["mel"].to(device, dtype=torch.float16)
audio_embeds_list = []
try:
for idx in range(len(batch["video_path"])):
video_path = batch["video_path"][idx]
start_idx = batch["start_idx"][idx]
with torch.no_grad():
audio_feat = audio_encoder.audio2feat(video_path)
audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
audio_embeds_list.append(audio_embeds)
except Exception as e:
logger.info(f"{type(e).__name__} - {e} - {video_path}")
continue
audio_embeds = torch.stack(audio_embeds_list) # (B, 16, 50, 384)
audio_embeds = audio_embeds.to(device, dtype=torch.float16)
else:
audio_embeds = None
# Convert videos to latent space
gt_images = batch["gt"].to(device, dtype=torch.float16)
gt_masked_images = batch["masked_gt"].to(device, dtype=torch.float16)
mask = batch["mask"].to(device, dtype=torch.float16)
ref_images = batch["ref"].to(device, dtype=torch.float16)
gt_images = rearrange(gt_images, "b f c h w -> (b f) c h w")
gt_masked_images = rearrange(gt_masked_images, "b f c h w -> (b f) c h w")
mask = rearrange(mask, "b f c h w -> (b f) c h w")
ref_images = rearrange(ref_images, "b f c h w -> (b f) c h w")
with torch.no_grad():
gt_latents = vae.encode(gt_images).latent_dist.sample()
gt_masked_images = vae.encode(gt_masked_images).latent_dist.sample()
ref_images = vae.encode(ref_images).latent_dist.sample()
mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)
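# The mask is downsampled by the VAE scale factor (8x for the SD VAE) so it matches the latent spatial resolution.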
gt_latents = (
rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
) * vae.config.scaling_factor
gt_masked_images = (
rearrange(gt_masked_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
- vae.config.shift_factor
) * vae.config.scaling_factor
ref_images = (
rearrange(ref_images, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
) * vae.config.scaling_factor
mask = rearrange(mask, "(b f) c h w -> b c f h w", f=config.data.num_frames)
# Sample noise that we'll add to the latents
if config.run.use_mixed_noise:
# Refer to the paper: https://arxiv.org/abs/2305.10474
noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)
noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
noise = noise_ind + noise_shared
else:
noise = torch.randn_like(gt_latents)
noise = noise[:, :, 0:1].repeat(
1, 1, config.data.num_frames, 1, 1
) # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716
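# Note on the mixed-noise branch above: the shared and independent components have variances alpha^2 / (1 + alpha^2) and 1 / (1 + alpha^2), so their sum still has unit variance regardless of mixed_noise_alpha.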
bsz = gt_latents.shape[0]
# Sample a random timestep for each video
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_tensor = noise_scheduler.add_noise(gt_latents, noise, timesteps)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
raise NotImplementedError
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
unet_input = torch.cat([noisy_tensor, mask, gt_masked_images, ref_images], dim=1)
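# The UNet conditions on channel-concatenated inputs: noisy latents (4) + mask (1, assuming a single-channel mask) + masked-frame latents (4) + reference latents (4) = 13 input channels.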
# Predict the noise and compute loss
# Mixed-precision training
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample
if config.run.recon_loss_weight != 0:
recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
else:
recon_loss = 0
pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor)
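# reversed_forward (from latentsync.utils.util) is expected to invert the scheduler's forward step: it estimates the clean latents from the predicted noise, the timestep and the noisy latents, enabling the pixel-space losses below.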
if config.run.pixel_space_supervise:
pred_images = vae.decode(
rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
+ vae.config.shift_factor
).sample
if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
else:
lpips_loss = 0
if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
else:
trepa_loss = 0
if config.model.add_audio_layer and config.run.use_syncnet:
if config.run.pixel_space_supervise:
syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
else:
syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")
if syncnet_config.data.lower_half:
height = syncnet_input.shape[2]
syncnet_input = syncnet_input[:, :, height // 2 :, :]
ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
sync_loss_list.append(gather_loss(sync_loss, device))
else:
sync_loss = 0
loss = (
recon_loss * config.run.recon_loss_weight
+ sync_loss * config.run.sync_loss_weight
+ lpips_loss * config.run.perceptual_loss_weight
+ trepa_loss * config.run.trepa_loss_weight
)
train_step_list.append(global_step)
if config.run.recon_loss_weight != 0:
recon_loss_list.append(gather_loss(recon_loss, device))
optimizer.zero_grad()
# Backpropagate
if config.run.mixed_precision_training:
scaler.scale(loss).backward()
""" >>> gradient clipping >>> """
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
""" <<< gradient clipping <<< """
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
""" >>> gradient clipping >>> """
torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
""" <<< gradient clipping <<< """
optimizer.step()
# Check the grad of attn blocks for debugging
# print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].audio_cross_attn.attn.to_q.weight.grad)
lr_scheduler.step()
progress_bar.update(1)
global_step += 1
### <<<< Training <<<< ###
# Save checkpoint and conduct validation
if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
if config.run.recon_loss_weight != 0:
plot_loss_chart(
os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
("Reconstruction loss", train_step_list, recon_loss_list),
)
if config.model.add_audio_layer:
if sync_loss_list != []:
plot_loss_chart(
os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
("Sync loss", train_step_list, sync_loss_list),
)
model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
state_dict = {
"global_step": global_step,
"state_dict": unet.module.state_dict(), # to unwrap DDP
}
try:
torch.save(state_dict, model_save_path)
logger.info(f"Saved checkpoint to {model_save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
# Validation
logger.info("Running validation... ")
validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
validation_video_mask_path = os.path.join(output_dir, "val_videos/val_video_mask.mp4")
with torch.autocast(device_type="cuda", dtype=torch.float16):
pipeline(
config.data.val_video_path,
config.data.val_audio_path,
validation_video_out_path,
validation_video_mask_path,
num_frames=config.data.num_frames,
num_inference_steps=config.run.inference_steps,
guidance_scale=config.run.guidance_scale,
weight_dtype=torch.float16,
width=config.data.resolution,
height=config.data.resolution,
mask=config.data.mask,
)
logger.info(f"Saved validation video output to {validation_video_out_path}")
val_step_list.append(global_step)
if config.model.add_audio_layer:
try:
_, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
except Exception as e:
logger.info(e)
conf = 0
sync_conf_list.append(conf)
plot_loss_chart(
os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"),
("Sync confidence", val_step_list, sync_conf_list),
)
logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= config.run.max_train_steps:
break
progress_bar.close()
dist.destroy_process_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Config file path
parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
args = parser.parse_args()
config = OmegaConf.load(args.unet_config_path)
config.unet_config_path = args.unet_config_path
main(config)
#!/bin/bash
# Create a new conda environment
conda create -y -n latentsync python=3.10.13
conda activate latentsync
# Install ffmpeg
conda install -y -c conda-forge ffmpeg
# Python dependencies
pip install -r requirements.txt
# OpenCV dependencies
sudo apt -y install libgl1
# Download all the checkpoints from HuggingFace
huggingface-cli download chunyu-li/LatentSync --local-dir checkpoints --exclude "*.git*" "README.md"
# Soft links for the auxiliary models
mkdir -p ~/.cache/torch/hub/checkpoints
ln -s $(pwd)/checkpoints/auxiliary/2DFAN4-cd938726ad.zip ~/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip
ln -s $(pwd)/checkpoints/auxiliary/s3fd-619a316812.pth ~/.cache/torch/hub/checkpoints/s3fd-619a316812.pth
ln -s $(pwd)/checkpoints/auxiliary/vgg16-397923af.pth ~/.cache/torch/hub/checkpoints/vgg16-397923af.pth
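# These links let the auxiliary models (face detection/alignment and the VGG16 backbone used by LPIPS) load their weights from the local checkpoints instead of re-downloading them via torch hub.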
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
from latentsync.utils.util import count_video_time, gather_video_paths_recursively
from tqdm import tqdm
def plot_histogram(data, fig_path):
# Create histogram
plt.hist(data, bins=30, edgecolor="black")
# Add titles and labels
plt.title("Histogram of Video Durations")
plt.xlabel("Video duration")
plt.ylabel("Frequency")
# Save plot as an image file
plt.savefig(fig_path) # Save as PNG file. You can also use 'histogram.jpg', 'histogram.pdf', etc.
def main(input_dir, fig_path):
video_paths = gather_video_paths_recursively(input_dir)
video_times = []
for video_path in tqdm(video_paths):
video_times.append(count_video_time(video_path))
plot_histogram(video_times, fig_path)
if __name__ == "__main__":
input_dir = "validation"
fig_path = "histogram.png"
main(input_dir, fig_path)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from tqdm import tqdm
"""
To use this script, first install yt-dlp:
pip install yt-dlp==2024.5.27
"""
def download_video(video_url, video_path):
get_video_channel_command = f"yt-dlp --print channel {video_url}"
result = subprocess.run(get_video_channel_command, shell=True, capture_output=True, text=True)
channel = result.stdout.strip()
if channel in unwanted_channels:
return
download_video_command = f"yt-dlp -f bestvideo+bestaudio --skip-unavailable-fragments --merge-output-format mp4 '{video_url}' --output '{video_path}' --external-downloader aria2c --external-downloader-args '-x 16 -k 1M'"
try:
subprocess.run(download_video_command, shell=True)
except KeyboardInterrupt:
print("Stopped")
exit()
except Exception as e:
print(f"Error downloading video {video_url}: {e}")
def download_videos(num_workers, video_urls, video_paths):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
executor.map(download_video, video_urls, video_paths)
def read_video_urls(csv_file_path: str, language_column, video_url_column):
video_urls = []
print("Reading video urls...")
df = pd.read_csv(csv_file_path, sep=",")
for row in tqdm(df.itertuples(), total=len(df)):
language = getattr(row, language_column)
video_url = getattr(row, video_url_column)
if "clip" in video_url:
continue
video_urls.append((language, video_url))
return video_urls
def extract_vid(video_url):
if "watch?v=" in video_url: # ignore_security_alert_wait_for_fix RCE
return video_url.split("watch?v=")[1][:11]
elif "shorts/" in video_url:
return video_url.split("shorts/")[1][:11]
elif "youtu.be/" in video_url:
return video_url.split("youtu.be/")[1][:11]
elif "&v=" in video_url:
return video_url.split("&v=")[1][:11]
else:
print(f"Invalid video url: {video_url}")
return None
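# Illustrative example: extract_vid("https://www.youtube.com/watch?v=dQw4w9WgXcQ") returns the 11-character video id "dQw4w9WgXcQ".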
def main(csv_file_path, language_column, video_url_column, output_dir, num_workers):
os.makedirs(output_dir, exist_ok=True)
all_video_urls = read_video_urls(csv_file_path, language_column, video_url_column)
video_paths = []
video_urls = []
print("Extracting vid...")
for language, video_url in tqdm(all_video_urls):
vid = extract_vid(video_url)
if vid is None:
continue
video_path = os.path.join(output_dir, language.lower(), f"vid_{vid}.mp4")
if os.path.isfile(video_path):
continue
os.makedirs(os.path.dirname(video_path), exist_ok=True)
video_paths.append(video_path)
video_urls.append(video_url)
if len(video_paths) == 0:
print("All videos have been downloaded")
exit()
else:
print(f"Downloading {len(video_paths)} videos")
download_videos(num_workers, video_urls, video_paths)
if __name__ == "__main__":
csv_file_path = "dcc.csv"
language_column = "video_language"
video_url_column = "video_link"
output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual/raw"
num_workers = 50
unwanted_channels = ["TEDx Talks", "DaePyeong Mukbang", "Joeman"]
main(csv_file_path, language_column, video_url_column, output_dir, num_workers)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from tqdm import tqdm
paths = []
def gather_paths(input_dir, output_dir):
os.makedirs(output_dir, exist_ok=True)
for video in sorted(os.listdir(input_dir)):
if video.endswith(".mp4"):
video_input = os.path.join(input_dir, video)
video_output = os.path.join(output_dir, video)
if os.path.isfile(video_output):
continue
paths.append([video_input, output_dir])
elif os.path.isdir(os.path.join(input_dir, video)):
gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))
def main(input_dir, output_dir):
print(f"Recursively gathering video paths of {input_dir} ...")
gather_paths(input_dir, output_dir)
for video_input, output_dir in tqdm(paths):
shutil.move(video_input, output_dir)
if __name__ == "__main__":
input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual_dcc"
output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual"
main(input_dir, output_dir)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
import torch.multiprocessing as mp
import time
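# Keeps every visible GPU busy by repeatedly multiplying random matrices (a simple GPU occupancy / placeholder workload).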
def check_mem(cuda_device):
devices_info = (
os.popen('"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader')
.read()
.strip()
.split("\n")
)
total, used = devices_info[int(cuda_device)].split(",")
return total, used
def loop(cuda_device):
cuda_i = torch.device(f"cuda:{cuda_device}")
total, used = check_mem(cuda_device)
total = int(total)
used = int(used)
max_mem = int(total * 0.9)
block_mem = max_mem - used
while True:
x = torch.rand(20, 512, 512, dtype=torch.float, device=cuda_i)
y = torch.rand(20, 512, 512, dtype=torch.float, device=cuda_i)
time.sleep(0.001)
x = torch.matmul(x, y)
def main():
if torch.cuda.is_available():
num_processes = torch.cuda.device_count()
processes = list()
for i in range(num_processes):
p = mp.Process(target=loop, args=(i,))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")
main()
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
def remove_outdated_files(input_dir, begin_date, end_date):
# Remove files from a specific time period
for subdir in os.listdir(input_dir):
if subdir >= begin_date and subdir <= end_date:
subdir_path = os.path.join(input_dir, subdir)
command = f"rm -rf {subdir_path}"
subprocess.run(command, shell=True)
print(f"Deleted: {subdir_path}")
if __name__ == "__main__":
input_dir = "/mnt/bn/video-datasets/output/syncnet"
begin_date = "train-2024_06_19-16:25:44"
end_date = "train-2024_08_03-07:39:58"
remove_outdated_files(input_dir, begin_date, end_date)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm import tqdm
from latentsync.utils.util import gather_video_paths_recursively
def write_fileslist(fileslist_path):
with open(fileslist_path, "w") as _:
pass
def append_fileslist(fileslist_path, video_paths):
with open(fileslist_path, "a") as f:
for video_path in tqdm(video_paths):
f.write(f"{video_path}\n")
def process_input_dir(fileslist_path, input_dir):
print(f"Processing input dir: {input_dir}")
video_paths = gather_video_paths_recursively(input_dir)
append_fileslist(fileslist_path, video_paths)
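# The resulting fileslist is a plain text file with one video path per line, presumably the train_fileslist referenced by the data configs.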
if __name__ == "__main__":
fileslist_path = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt"
write_fileslist(fileslist_path)
process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/train")
process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/high_visual_quality/train")
process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/high_visual_quality/train")
process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual/high_visual_quality")
process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/celebv_text/high_visual_quality/train")
process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/youtube/high_visual_quality")
#!/bin/bash
torchrun --nnodes=1 --nproc_per_node=1 --master_port=25678 -m scripts.train_syncnet \
--config_path "configs/syncnet/syncnet_16_pixel.yaml"
#!/bin/bash
torchrun --nnodes=1 --nproc_per_node=1 --master_port=25678 -m scripts.train_unet \
--unet_config_path "configs/unet/first_stage.yaml"