Commit b469676c authored by helloyongyang

update log

parent 9e3680b7
import argparse
import json
import torch.distributed as dist
from loguru import logger
@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import set_config, set_parallel_config
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
from lightx2v.utils.utils import seed_all
@@ -70,17 +69,16 @@ def main():
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
    args = parser.parse_args()
    logger.info(f"args: {args}")
    # set config
    config = set_config(args)
    logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
    if config.parallel:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
        set_parallel_config(config)
    print_config(config)
    with ProfilingContext("Total Cost"):
        runner = init_runner(config)
        runner.run_pipeline()
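For context on the parallel branch above: `dist.init_process_group(backend="nccl")` joins the launcher-provided process group and `torch.cuda.set_device(dist.get_rank())` pins each process to its own GPU, using the global rank as the device index (a single-node assumption). A minimal stand-alone sketch of that setup, independent of lightx2v and meant to run under `torchrun` — the `verify_setup` helper and the all_reduce check are illustrative only:

```python
# Minimal sketch of the NCCL init pattern used in main() above.
# Launch with: torchrun --nproc_per_node=<num_gpus> check_dist.py
# (check_dist.py / verify_setup are hypothetical names for illustration.)
import torch
import torch.distributed as dist


def verify_setup():
    dist.init_process_group(backend="nccl")  # reads RANK / WORLD_SIZE from the torchrun env
    torch.cuda.set_device(dist.get_rank())   # pin this process to its own GPU (single-node assumption)

    # Each rank contributes its rank id; after all_reduce every rank sees the same sum,
    # which confirms all processes joined the same group.
    t = torch.tensor([float(dist.get_rank())], device="cuda")
    dist.all_reduce(t)
    if dist.get_rank() == 0:
        print(f"world_size={dist.get_world_size()}, sum of ranks={int(t.item())}")

    dist.destroy_process_group()


if __name__ == "__main__":
    verify_setup()
```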
@@ -256,8 +256,6 @@ class WanModel:
        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])
        else:
            dist.barrier()
        for key in sorted(synced_meta_dict.keys()):
            if is_weight_loader:
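In the synchronization code above, the CUDA branch passes `device_ids` so the NCCL barrier runs on this rank's current GPU rather than whatever device NCCL would pick by default, while the plain `dist.barrier()` branch covers non-CUDA targets. A small stand-alone sketch of that device-aware barrier pattern (`barrier_on` is a made-up helper name, not part of the repo):

```python
import torch
import torch.distributed as dist


def barrier_on(target_device: str) -> None:
    """Synchronize all ranks; pin the NCCL barrier to this rank's GPU when syncing on CUDA."""
    if target_device == "cuda":
        # NCCL needs to know which GPU hosts this rank's barrier operation.
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        # CPU path (e.g. gloo backend): no device hint needed.
        dist.barrier()
```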
@@ -69,9 +69,6 @@ def set_config(args):
def set_parallel_config(config):
    if config.parallel:
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        cfg_p_size = config.parallel.get("cfg_p_size", 1)
        seq_p_size = config.parallel.get("seq_p_size", 1)
        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
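The assertion above requires the two parallel degrees to factor the world size exactly, so every process belongs to one CFG-parallel group and one sequence-parallel group. As an illustration of what that constraint implies (the helper and the row-major rank layout are assumptions for the example, not necessarily how lightx2v builds its device mesh):

```python
# Illustrative only: map a flat rank to (cfg_rank, seq_rank) under the
# cfg_p_size * seq_p_size == world_size constraint checked in set_parallel_config.
def rank_coords(rank: int, cfg_p_size: int, seq_p_size: int, world_size: int):
    assert cfg_p_size * seq_p_size == world_size, "cfg_p_size * seq_p_size must equal world_size"
    return rank // seq_p_size, rank % seq_p_size  # (cfg_rank, seq_rank), row-major layout


# Example: 4 GPUs split as cfg_p_size=2, seq_p_size=2.
for r in range(4):
    print(r, rank_coords(r, cfg_p_size=2, seq_p_size=2, world_size=4))
# 0 -> (0, 0), 1 -> (0, 1), 2 -> (1, 0), 3 -> (1, 1)
```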
@@ -82,3 +79,13 @@ def set_parallel_config(config):
    if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
        config["cfg_parallel"] = True

def print_config(config):
    config_to_print = config.copy()
    config_to_print.pop("device_mesh", None)
    if config.parallel:
        if dist.get_rank() == 0:
            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
    else:
        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
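The new `print_config` copies the config and pops `device_mesh` before serializing — presumably because the mesh object held there is not JSON-serializable — and, in the parallel case, logs only from rank 0 so the config dump appears once instead of once per process. A minimal stand-alone sketch of the same pattern with a plain dict (`DummyMesh` and `print_config_once` are made-up names for illustration):

```python
import json
from loguru import logger


class DummyMesh:
    """Stand-in for a device-mesh object that json.dumps cannot serialize."""


def print_config_once(config: dict, rank: int, parallel: bool) -> None:
    # Same idea as print_config above: drop non-serializable entries, then log once.
    config_to_print = config.copy()
    config_to_print.pop("device_mesh", None)
    if not parallel or rank == 0:
        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")


cfg = {"parallel": {"cfg_p_size": 2, "seq_p_size": 2}, "device_mesh": DummyMesh()}
print_config_once(cfg, rank=0, parallel=True)  # logs the config without the device_mesh entry
print_config_once(cfg, rank=1, parallel=True)  # silent: only rank 0 logs in the parallel case
```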