chenpangpang / LatentSync · Commits

Commit 5c023842, authored Jan 14, 2025 by chenpangpang

feat: Add LatentSync

Parent: 822b66ca
Pipeline #2211: canceled with stages
Changes: 112 · Pipelines: 1
Showing 12 changed files with 1321 additions and 0 deletions (+1321 -0)
LatentSync/scripts/inference.py              +99  -0
LatentSync/scripts/train_syncnet.py          +336 -0
LatentSync/scripts/train_unet.py             +510 -0
LatentSync/setup_env.sh                      +23  -0
LatentSync/tools/count_videos_time.py        +45  -0
LatentSync/tools/download_youtube_videos.py  +113 -0
LatentSync/tools/move_files_recur.py         +48  -0
LatentSync/tools/occupy_gpu.py               +60  -0
LatentSync/tools/remove_outdated_files.py    +34  -0
LatentSync/tools/write_fileslist.py          +45  -0
LatentSync/train_syncnet.sh                  +4   -0
LatentSync/train_unet.sh                     +4   -0
LatentSync/scripts/inference.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed
from latentsync.whisper.audio2feature import Audio2Feature


def main(config, args):
    print(f"Input video path: {args.video_path}")
    print(f"Input audio path: {args.audio_path}")
    print(f"Loaded checkpoint path: {args.inference_ckpt_path}")

    scheduler = DDIMScheduler.from_pretrained("configs")

    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        args.inference_ckpt_path,  # load checkpoint
        device="cpu",
    )

    unet = unet.to(dtype=torch.float16)

    # set xformers
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()

    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to("cuda")

    if args.seed != -1:
        set_seed(args.seed)
    else:
        torch.seed()

    print(f"Initial seed: {torch.initial_seed()}")

    pipeline(
        video_path=args.video_path,
        audio_path=args.audio_path,
        video_out_path=args.video_out_path,
        video_mask_path=args.video_out_path.replace(".mp4", "_mask.mp4"),
        num_frames=config.data.num_frames,
        num_inference_steps=config.run.inference_steps,
        guidance_scale=args.guidance_scale,
        weight_dtype=torch.float16,
        width=config.data.resolution,
        height=config.data.resolution,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
    parser.add_argument("--inference_ckpt_path", type=str, required=True)
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--video_out_path", type=str, required=True)
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=1247)
    args = parser.parse_args()

    config = OmegaConf.load(args.unet_config_path)

    main(config, args)
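Note: a minimal sketch of driving this script's main() programmatically rather than through the CLI. It assumes the repository root is on PYTHONPATH so that scripts.inference is importable; the checkpoint and media paths below are placeholders, not files shipped in this commit.

from types import SimpleNamespace
from omegaconf import OmegaConf

from scripts.inference import main

config = OmegaConf.load("configs/unet.yaml")
args = SimpleNamespace(
    unet_config_path="configs/unet.yaml",
    inference_ckpt_path="checkpoints/latentsync_unet.pt",  # placeholder checkpoint path
    video_path="assets/demo_video.mp4",                    # placeholder input video
    audio_path="assets/demo_audio.wav",                    # placeholder driving audio
    video_out_path="video_out.mp4",
    guidance_scale=1.0,
    seed=1247,
)
main(config, args)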
LatentSync/scripts/train_syncnet.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm.auto import tqdm
import os, argparse, datetime, math
import logging
from omegaconf import OmegaConf
import shutil

from latentsync.data.syncnet_dataset import SyncNetDataset
from latentsync.models.syncnet import SyncNet
from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip
from latentsync.utils.util import gather_loss, plot_loss_chart
from accelerate.utils import set_seed

import torch
from diffusers import AutoencoderKL
from diffusers.utils.logging import get_logger
from einops import rearrange
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from latentsync.utils.util import init_dist, cosine_loss

logger = get_logger(__name__)


def main(config):
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.config_path, output_dir)

    device = torch.device(local_rank)

    if config.data.latent_space:
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
        vae.requires_grad_(False)
        vae.to(device)
    else:
        vae = None

    # Dataset and Dataloader setup
    train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
    val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)

    train_distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=train_distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    num_samples_limit = 640

    val_batch_size = min(
        num_samples_limit // config.data.num_frames, config.data.batch_size
    )  # limit batch size to avoid CUDA OOM

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=False,
        worker_init_fn=val_dataset.worker_init_fn,
    )

    # Model
    syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
    # syncnet = SyncNetWav2Lip().to(device)

    optimizer = torch.optim.AdamW(
        list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
    )

    if config.ckpt.resume_ckpt_path != "":
        if is_main_process:
            logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
        ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device)

        syncnet.load_state_dict(ckpt["state_dict"])
        global_step = ckpt["global_step"]
        train_step_list = ckpt["train_step_list"]
        train_loss_list = ckpt["train_loss_list"]
        val_step_list = ckpt["val_step_list"]
        val_loss_list = ckpt["val_loss_list"]
    else:
        global_step = 0
        train_step_list = []
        train_loss_list = []
        val_step_list = []
        val_loss_list = []

    # DDP wrapper
    syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    # validation_steps = int(config.ckpt.save_ckpt_steps // 5)
    # validation_steps = 100

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f"  Total train batch size (w. parallel & distributed) = {config.data.batch_size * num_processes}")
        logger.info(f"  Total optimization steps = {config.run.max_train_steps}")

    first_epoch = global_step // num_update_steps_per_epoch
    num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
    )

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        syncnet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            frames = batch["frames"].to(device, dtype=torch.float16)
            audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
            y = batch["y"].to(device, dtype=torch.float32)

            if config.data.latent_space:
                max_batch_size = (
                    num_samples_limit // config.data.num_frames
                )  # due to the limited cuda memory, we split the input frames into parts
                if frames.shape[0] > max_batch_size:
                    assert (
                        frames.shape[0] % max_batch_size == 0
                    ), f"max_batch_size {max_batch_size} should be divisible by batch_size {frames.shape[0]}"
                    frames_part_results = []
                    for i in range(0, frames.shape[0], max_batch_size):
                        frames_part = frames[i : i + max_batch_size]
                        frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
                        with torch.no_grad():
                            frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215
                        frames_part_results.append(frames_part)
                    frames = torch.cat(frames_part_results, dim=0)
                else:
                    frames = rearrange(frames, "b f c h w -> (b f) c h w")
                    with torch.no_grad():
                        frames = vae.encode(frames).latent_dist.sample() * 0.18215
                frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
            else:
                frames = rearrange(frames, "b f c h w -> b (f c) h w")

            if config.data.lower_half:
                height = frames.shape[2]
                frames = frames[:, :, height // 2 :, :]

            # audio_embeds = wav2vec_encoder(audio_samples).last_hidden_state

            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                vision_embeds, audio_embeds = syncnet(frames, audio_samples)

            loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            progress_bar.update(1)
            global_step += 1

            global_average_loss = gather_loss(loss, device)
            train_step_list.append(global_step)
            train_loss_list.append(global_average_loss)

            if is_main_process and global_step % config.run.validation_steps == 0:
                logger.info(f"Validation at step {global_step}")
                val_loss = validation(
                    val_dataloader,
                    device,
                    syncnet,
                    cosine_loss,
                    config.data.latent_space,
                    config.data.lower_half,
                    vae,
                    num_val_batches,
                )
                val_step_list.append(global_step)
                val_loss_list.append(val_loss)
                logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")

            if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
                checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                torch.save(
                    {
                        "state_dict": syncnet.module.state_dict(),  # to unwrap DDP
                        "global_step": global_step,
                        "train_step_list": train_step_list,
                        "train_loss_list": train_loss_list,
                        "val_step_list": val_step_list,
                        "val_loss_list": val_loss_list,
                    },
                    checkpoint_save_path,
                )
                logger.info(f"Saved checkpoint to {checkpoint_save_path}")
                plot_loss_chart(
                    os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
                    ("Train loss", train_step_list, train_loss_list),
                    ("Val loss", val_step_list, val_loss_list),
                )

            progress_bar.set_postfix({"step_loss": global_average_loss})
            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()


@torch.no_grad()
def validation(val_dataloader, device, syncnet, cosine_loss, latent_space, lower_half, vae, num_val_batches):
    syncnet.eval()

    losses = []
    val_step = 0
    while True:
        for step, batch in enumerate(val_dataloader):
            ### >>>> Validation >>>> ###

            frames = batch["frames"].to(device, dtype=torch.float16)
            audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
            y = batch["y"].to(device, dtype=torch.float32)

            if latent_space:
                num_frames = frames.shape[1]
                frames = rearrange(frames, "b f c h w -> (b f) c h w")
                frames = vae.encode(frames).latent_dist.sample() * 0.18215
                frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames)
            else:
                frames = rearrange(frames, "b f c h w -> b (f c) h w")

            if lower_half:
                height = frames.shape[2]
                frames = frames[:, :, height // 2 :, :]

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                vision_embeds, audio_embeds = syncnet(frames, audio_samples)

            loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()

            losses.append(loss.item())

            val_step += 1
            if val_step > num_val_batches:
                syncnet.train()
                if len(losses) == 0:
                    raise RuntimeError("No validation data")
                return sum(losses) / len(losses)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Code to train the expert lip-sync discriminator")
    parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_vae.yaml")
    args = parser.parse_args()

    # Load a configuration file
    config = OmegaConf.load(args.config_path)
    config.config_path = args.config_path

    main(config)
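Note: train_syncnet.py imports cosine_loss from latentsync/utils/util.py, which is not part of this diff. The sketch below is only a plausible reading, assuming a Wav2Lip-style formulation: the cosine similarity between the vision and audio embeddings is mapped to [0, 1] and scored against the binary in-sync label y with per-sample BCE (the training loop applies .mean() itself).

import torch
import torch.nn.functional as F

def cosine_loss(vision_embeds: torch.Tensor, audio_embeds: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Cosine similarity of the two embeddings, in [-1, 1].
    sims = F.cosine_similarity(vision_embeds, audio_embeds, dim=-1)
    # Map to [0, 1] so it can be read as the probability that audio and video are in sync.
    probs = (sims + 1) / 2
    # Per-sample binary cross-entropy against the 0/1 label; no reduction here,
    # since the callers in this file apply .mean() to the result.
    return F.binary_cross_entropy(probs.reshape(-1), y.reshape(-1), reduction="none")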
LatentSync/scripts/train_unet.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import argparse
import shutil
import datetime
import logging

from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed

from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.syncnet import SyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
    init_dist,
    cosine_loss,
    reversed_forward,
)
from latentsync.utils.util import plot_loss_chart, gather_loss
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips

logger = get_logger(__name__)


def main(config):
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()

        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)

    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,  # load checkpoint
        device=device,
    )

    if config.model.add_audio_layer and config.run.use_syncnet:
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
        syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)

    unet.requires_grad_(True)
    trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if config.run.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Get the training dataset
    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    # Get the training iteration
    if config.run.max_train_steps == -1:
        assert config.run.max_train_epochs != -1
        config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)

    # Scheduler
    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)

    if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
        trepa_loss_func = TREPALoss(device=device)

    # Validation pipeline
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    # DDP warpper
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Total optimization steps = {config.run.max_train_steps}")
    global_step = resume_global_step
    first_epoch = resume_global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    train_step_list = []
    sync_loss_list = []
    recon_loss_list = []

    val_step_list = []
    sync_conf_list = []

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            if config.model.add_audio_layer:
                if batch["mel"] != []:
                    mel = batch["mel"].to(device, dtype=torch.float16)

                audio_embeds_list = []
                try:
                    for idx in range(len(batch["video_path"])):
                        video_path = batch["video_path"][idx]
                        start_idx = batch["start_idx"][idx]

                        with torch.no_grad():
                            audio_feat = audio_encoder.audio2feat(video_path)
                        audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
                        audio_embeds_list.append(audio_embeds)
                except Exception as e:
                    logger.info(f"{type(e).__name__} - {e} - {video_path}")
                    continue
                audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
                audio_embeds = audio_embeds.to(device, dtype=torch.float16)
            else:
                audio_embeds = None

            # Convert videos to latent space
            gt_images = batch["gt"].to(device, dtype=torch.float16)
            gt_masked_images = batch["masked_gt"].to(device, dtype=torch.float16)
            mask = batch["mask"].to(device, dtype=torch.float16)
            ref_images = batch["ref"].to(device, dtype=torch.float16)

            gt_images = rearrange(gt_images, "b f c h w -> (b f) c h w")
            gt_masked_images = rearrange(gt_masked_images, "b f c h w -> (b f) c h w")
            mask = rearrange(mask, "b f c h w -> (b f) c h w")
            ref_images = rearrange(ref_images, "b f c h w -> (b f) c h w")

            with torch.no_grad():
                gt_latents = vae.encode(gt_images).latent_dist.sample()
                gt_masked_images = vae.encode(gt_masked_images).latent_dist.sample()
                ref_images = vae.encode(ref_images).latent_dist.sample()
                mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)

            gt_latents = (
                rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            gt_masked_images = (
                rearrange(gt_masked_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            ref_images = (
                rearrange(ref_images, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            mask = rearrange(mask, "(b f) c h w -> b c f h w", f=config.data.num_frames)

            # Sample noise that we'll add to the latents
            if config.run.use_mixed_noise:
                # Refer to the paper: https://arxiv.org/abs/2305.10474
                noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
                noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)

                noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
                noise = noise_ind + noise_shared
            else:
                noise = torch.randn_like(gt_latents)
                noise = noise[:, :, 0:1].repeat(
                    1, 1, config.data.num_frames, 1, 1
                )  # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716

            bsz = gt_latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_tensor = noise_scheduler.add_noise(gt_latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            unet_input = torch.cat([noisy_tensor, mask, gt_masked_images, ref_images], dim=1)

            # Predict the noise and compute loss
            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample

            if config.run.recon_loss_weight != 0:
                recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            else:
                recon_loss = 0

            pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor)

            if config.run.pixel_space_supervise:
                pred_images = vae.decode(
                    rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
                    + vae.config.shift_factor
                ).sample

            if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
                pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
                gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
                lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
            else:
                lpips_loss = 0

            if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
                trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
            else:
                trepa_loss = 0

            if config.model.add_audio_layer and config.run.use_syncnet:
                if config.run.pixel_space_supervise:
                    syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
                else:
                    syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")

                if syncnet_config.data.lower_half:
                    height = syncnet_input.shape[2]
                    syncnet_input = syncnet_input[:, :, height // 2 :, :]
                ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
                vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
                sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
                sync_loss_list.append(gather_loss(sync_loss, device))
            else:
                sync_loss = 0

            loss = (
                recon_loss * config.run.recon_loss_weight
                + sync_loss * config.run.sync_loss_weight
                + lpips_loss * config.run.perceptual_loss_weight
                + trepa_loss * config.run.trepa_loss_weight
            )

            train_step_list.append(global_step)
            if config.run.recon_loss_weight != 0:
                recon_loss_list.append(gather_loss(recon_loss, device))

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            # Check the grad of attn blocks for debugging
            # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].audio_cross_attn.attn.to_q.weight.grad)

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Save checkpoint and conduct validation
            if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
                if config.run.recon_loss_weight != 0:
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
                        ("Reconstruction loss", train_step_list, recon_loss_list),
                    )
                if config.model.add_audio_layer:
                    if sync_loss_list != []:
                        plot_loss_chart(
                            os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
                            ("Sync loss", train_step_list, sync_loss_list),
                        )
                model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                state_dict = {
                    "global_step": global_step,
                    "state_dict": unet.module.state_dict(),  # to unwrap DDP
                }
                try:
                    torch.save(state_dict, model_save_path)
                    logger.info(f"Saved checkpoint to {model_save_path}")
                except Exception as e:
                    logger.error(f"Error saving model: {e}")

                # Validation
                logger.info("Running validation... ")

                validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
                validation_video_mask_path = os.path.join(output_dir, f"val_videos/val_video_mask.mp4")

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pipeline(
                        config.data.val_video_path,
                        config.data.val_audio_path,
                        validation_video_out_path,
                        validation_video_mask_path,
                        num_frames=config.data.num_frames,
                        num_inference_steps=config.run.inference_steps,
                        guidance_scale=config.run.guidance_scale,
                        weight_dtype=torch.float16,
                        width=config.data.resolution,
                        height=config.data.resolution,
                        mask=config.data.mask,
                    )

                logger.info(f"Saved validation video output to {validation_video_out_path}")

                val_step_list.append(global_step)

                if config.model.add_audio_layer:
                    try:
                        _, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
                    except Exception as e:
                        logger.info(e)
                        conf = 0
                    sync_conf_list.append(conf)
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"),
                        ("Sync confidence", val_step_list, sync_conf_list),
                    )

            logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Config file path
    parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")

    args = parser.parse_args()
    config = OmegaConf.load(args.unet_config_path)
    config.unet_config_path = args.unet_config_path

    main(config)
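Note: train_unet.py recovers predicted clean latents from the predicted noise via reversed_forward, also imported from latentsync/utils/util.py and not shown in this diff. Below is only a sketch of what such a helper presumably does under epsilon prediction, inverting the scheduler's forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

import torch

def reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor):
    # alpha_bar_t for each sampled timestep, broadcast over the (b, c, f, h, w) latents.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=noisy_tensor.device, dtype=noisy_tensor.dtype)
    alpha_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1, 1)
    # Solve the forward equation for x_0 given x_t and the predicted noise.
    pred_x0 = (noisy_tensor - (1 - alpha_bar).sqrt() * pred_noise) / alpha_bar.sqrt()
    return pred_x0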
LatentSync/setup_env.sh (new file, 0 → 100644)
#!/bin/bash
# Create a new conda environment
conda create -y -n latentsync python=3.10.13
conda activate latentsync

# Install ffmpeg
conda install -y -c conda-forge ffmpeg

# Python dependencies
pip install -r requirements.txt

# OpenCV dependencies
sudo apt -y install libgl1

# Download all the checkpoints from HuggingFace
huggingface-cli download chunyu-li/LatentSync --local-dir checkpoints --exclude "*.git*" "README.md"

# Soft links for the auxiliary models
mkdir -p ~/.cache/torch/hub/checkpoints
ln -s $(pwd)/checkpoints/auxiliary/2DFAN4-cd938726ad.zip ~/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip
ln -s $(pwd)/checkpoints/auxiliary/s3fd-619a316812.pth ~/.cache/torch/hub/checkpoints/s3fd-619a316812.pth
ln -s $(pwd)/checkpoints/auxiliary/vgg16-397923af.pth ~/.cache/torch/hub/checkpoints/vgg16-397923af.pth
LatentSync/tools/count_videos_time.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib.pyplot as plt
from latentsync.utils.util import count_video_time, gather_video_paths_recursively
from tqdm import tqdm


def plot_histogram(data, fig_path):
    # Create histogram
    plt.hist(data, bins=30, edgecolor="black")

    # Add titles and labels
    plt.title("Histogram of Data Distribution")
    plt.xlabel("Video time")
    plt.ylabel("Frequency")

    # Save plot as an image file
    plt.savefig(fig_path)  # Save as PNG file. You can also use 'histogram.jpg', 'histogram.pdf', etc.


def main(input_dir, fig_path):
    video_paths = gather_video_paths_recursively(input_dir)

    video_times = []
    for video_path in tqdm(video_paths):
        video_times.append(count_video_time(video_path))

    plot_histogram(video_times, fig_path)


if __name__ == "__main__":
    input_dir = "validation"
    fig_path = "histogram.png"

    main(input_dir, fig_path)
LatentSync/tools/download_youtube_videos.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from tqdm import tqdm

"""
To use this python file, first install yt-dlp by:

pip install yt-dlp==2024.5.27
"""


def download_video(video_url, video_path):
    get_video_channel_command = f"yt-dlp --print channel {video_url}"
    result = subprocess.run(get_video_channel_command, shell=True, capture_output=True, text=True)
    channel = result.stdout.strip()
    if channel in unwanted_channels:
        return
    download_video_command = f"yt-dlp -f bestvideo+bestaudio --skip-unavailable-fragments --merge-output-format mp4 '{video_url}' --output '{video_path}' --external-downloader aria2c --external-downloader-args '-x 16 -k 1M'"
    try:
        subprocess.run(download_video_command, shell=True)  # ignore_security_alert_wait_for_fix RCE
    except KeyboardInterrupt:
        print("Stopped")
        exit()
    except:
        print(f"Error downloading video {video_url}")


def download_videos(num_workers, video_urls, video_paths):
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        executor.map(download_video, video_urls, video_paths)


def read_video_urls(csv_file_path: str, language_column, video_url_column):
    video_urls = []
    print("Reading video urls...")
    df = pd.read_csv(csv_file_path, sep=",")
    for row in tqdm(df.itertuples(), total=len(df)):
        language = getattr(row, language_column)
        video_url = getattr(row, video_url_column)
        if "clip" in video_url:
            continue
        video_urls.append((language, video_url))
    return video_urls


def extract_vid(video_url):
    if "watch?v=" in video_url:  # ignore_security_alert_wait_for_fix RCE
        return video_url.split("watch?v=")[1][:11]
    elif "shorts/" in video_url:
        return video_url.split("shorts/")[1][:11]
    elif "youtu.be/" in video_url:
        return video_url.split("youtu.be/")[1][:11]
    elif "&v=" in video_url:
        return video_url.split("&v=")[1][:11]
    else:
        print(f"Invalid video url: {video_url}")
        return None


def main(csv_file_path, language_column, video_url_column, output_dir, num_workers):
    os.makedirs(output_dir, exist_ok=True)
    all_video_urls = read_video_urls(csv_file_path, language_column, video_url_column)

    video_paths = []
    video_urls = []

    print("Extracting vid...")
    for language, video_url in tqdm(all_video_urls):
        vid = extract_vid(video_url)
        if vid is None:
            continue
        video_path = os.path.join(output_dir, language.lower(), f"vid_{vid}.mp4")
        if os.path.isfile(video_path):
            continue
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        video_paths.append(video_path)
        video_urls.append(video_url)

    if len(video_paths) == 0:
        print("All videos have been downloaded")
        exit()
    else:
        print(f"Downloading {len(video_paths)} videos")

    download_videos(num_workers, video_urls, video_paths)


if __name__ == "__main__":
    csv_file_path = "dcc.csv"
    language_column = "video_language"
    video_url_column = "video_link"
    output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual/raw"
    num_workers = 50
    unwanted_channels = ["TEDx Talks", "DaePyeong Mukbang", "Joeman"]

    main(csv_file_path, language_column, video_url_column, output_dir, num_workers)
LatentSync/tools/move_files_recur.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from tqdm import tqdm

paths = []


def gather_paths(input_dir, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    for video in sorted(os.listdir(input_dir)):
        if video.endswith(".mp4"):
            video_input = os.path.join(input_dir, video)
            video_output = os.path.join(output_dir, video)
            if os.path.isfile(video_output):
                continue
            paths.append([video_input, output_dir])
        elif os.path.isdir(os.path.join(input_dir, video)):
            gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))


def main(input_dir, output_dir):
    print(f"Recursively gathering video paths of {input_dir} ...")
    gather_paths(input_dir, output_dir)

    for video_input, output_dir in tqdm(paths):
        shutil.move(video_input, output_dir)


if __name__ == "__main__":
    input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual_dcc"
    output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual"

    main(input_dir, output_dir)
LatentSync/tools/occupy_gpu.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
import torch.multiprocessing as mp
import time


def check_mem(cuda_device):
    devices_info = (
        os.popen('"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader')
        .read()
        .strip()
        .split("\n")
    )
    total, used = devices_info[int(cuda_device)].split(",")
    return total, used


def loop(cuda_device):
    cuda_i = torch.device(f"cuda:{cuda_device}")
    total, used = check_mem(cuda_device)
    total = int(total)
    used = int(used)
    max_mem = int(total * 0.9)
    block_mem = max_mem - used
    while True:
        x = torch.rand(20, 512, 512, dtype=torch.float, device=cuda_i)
        y = torch.rand(20, 512, 512, dtype=torch.float, device=cuda_i)
        time.sleep(0.001)
        x = torch.matmul(x, y)


def main():
    if torch.cuda.is_available():
        num_processes = torch.cuda.device_count()
        processes = list()
        for i in range(num_processes):
            p = mp.Process(target=loop, args=(i,))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()
LatentSync/tools/remove_outdated_files.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess


def remove_outdated_files(input_dir, begin_date, end_date):
    # Remove files from a specific time period
    for subdir in os.listdir(input_dir):
        if subdir >= begin_date and subdir <= end_date:
            subdir_path = os.path.join(input_dir, subdir)
            command = f"rm -rf {subdir_path}"
            subprocess.run(command, shell=True)
            print(f"Deleted: {subdir_path}")


if __name__ == "__main__":
    input_dir = "/mnt/bn/video-datasets/output/syncnet"
    begin_date = "train-2024_06_19-16:25:44"
    end_date = "train-2024_08_03-07:39:58"

    remove_outdated_files(input_dir, begin_date, end_date)
LatentSync/tools/write_fileslist.py (new file, 0 → 100644)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm import tqdm
from latentsync.utils.util import gather_video_paths_recursively


def write_fileslist(fileslist_path):
    with open(fileslist_path, "w") as _:
        pass


def append_fileslist(fileslist_path, video_paths):
    with open(fileslist_path, "a") as f:
        for video_path in tqdm(video_paths):
            f.write(f"{video_path}\n")


def process_input_dir(fileslist_path, input_dir):
    print(f"Processing input dir: {input_dir}")
    video_paths = gather_video_paths_recursively(input_dir)
    append_fileslist(fileslist_path, video_paths)


if __name__ == "__main__":
    fileslist_path = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt"

    write_fileslist(fileslist_path)
    process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/train")
    process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/high_visual_quality/train")
    process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/high_visual_quality/train")
    process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual/high_visual_quality")
    process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/celebv_text/high_visual_quality/train")
    process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/youtube/high_visual_quality")
LatentSync/train_syncnet.sh (new file, 0 → 100644)
#!/bin/bash
torchrun --nnodes=1 --nproc_per_node=1 --master_port=25678 -m scripts.train_syncnet \
    --config_path "configs/syncnet/syncnet_16_pixel.yaml"
LatentSync/train_unet.sh (new file, 0 → 100644)
#!/bin/bash
torchrun --nnodes=1 --nproc_per_node=1 --master_port=25678 -m scripts.train_unet \
    --unet_config_path "configs/unet/first_stage.yaml"