"vscode:/vscode.git/clone" did not exist on "d46b69723a82402db096d579dec123ef3f3dba86"
Commit 5f9377aa authored by mashun1's avatar mashun1
Browse files

latte

parent 5bd891c3
File mode changed from 100644 to 100755
......@@ -35,8 +35,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from utils import (clip_grad_norm_, create_logger, update_ema,
requires_grad, cleanup, create_tensorboard,
write_tensorboard, setup_distributed, fetch_files_by_numbers,
get_experiment_dir, separation_content_motion,)
write_tensorboard, setup_distributed, get_experiment_dir)
#################################################################################
......@@ -166,42 +165,10 @@ def main(args):
# Freeze vae and text_encoder
vae.requires_grad_(False)
if args.extras == 78:
text_encoder.requires_grad_(False)
if args.dataset == 'webvideo2mlaion':
# Setup video dataset:
file_list = os.listdir(args.image_data_path) # all file format must be the same!
file_count = int(len(file_list) / dist.get_world_size())
args.laion_meta_files = fetch_files_by_numbers(rank * file_count, file_count, file_list)
file_list = os.listdir(args.webvideo_data_path) # all file format must be the same!
file_count = int(len(file_list) / dist.get_world_size())
args.webvideo_meta_files = fetch_files_by_numbers(rank * file_count, file_count, file_list)
if args.test_run:
args.laion_meta_files = ['file_000.csv']
args.webvideo_meta_files = ['file_000.csv']
# Setup data:
dataset = get_dataset(args)
if args.dataset == 'webvideo2mlaion':
sampler = DistributedSampler(
dataset,
num_replicas=1, # important
rank=0, # important
shuffle=True,
seed=args.global_seed
)
else:
sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
rank=rank,
shuffle=True,
seed=args.global_seed
)
sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
......@@ -218,8 +185,9 @@ def main(args):
pin_memory=True,
drop_last=True
)
logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})")
# logger.info(f"Dataset contains {len(dataset):,} videos ({args.webvideo_data_path})")
logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")
# Scheduler
lr_scheduler = get_scheduler(
name="constant",
......@@ -358,4 +326,4 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
args = parser.parse_args()
main(OmegaConf.load(args.config))
main(OmegaConf.load(args.config))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment