"vscode:/vscode.git/clone" did not exist on "71c08f3927926de64a9054e327d04539f10d4564"
Commit b469676c authored by helloyongyang

update log

parent 9e3680b7
 import argparse
-import json
 import torch.distributed as dist
 from loguru import logger
@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.set_config import set_config, set_parallel_config
+from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
@@ -70,17 +69,16 @@ def main():
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
-    logger.info(f"args: {args}")

     # set config
     config = set_config(args)
-    logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")

     if config.parallel:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
+    print_config(config)

     with ProfilingContext("Total Cost"):
         runner = init_runner(config)
         runner.run_pipeline()
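Why the call order matters: print_config reads the process rank via torch.distributed when config.parallel is set, and dist.get_rank() only works once the default process group exists, so the call sits after the dist.init_process_group(...) branch in main(). A minimal sketch of that constraint (the helper name is_rank_zero is hypothetical, for illustration only):

import torch.distributed as dist

# Sketch only: dist.get_rank() raises unless the default process group is
# initialized, so any rank-gated logging must run after dist.init_process_group().
def is_rank_zero() -> bool:
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True  # single-process run: behave as rank 0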
@@ -256,8 +256,6 @@ class WanModel:
         if target_device == "cuda":
             dist.barrier(device_ids=[torch.cuda.current_device()])
-        else:
-            dist.barrier()
         for key in sorted(synced_meta_dict.keys()):
             if is_weight_loader:

@@ -69,9 +69,6 @@ def set_config(args):
 def set_parallel_config(config):
     if config.parallel:
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
         cfg_p_size = config.parallel.get("cfg_p_size", 1)
         seq_p_size = config.parallel.get("seq_p_size", 1)
         assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
@@ -82,3 +79,13 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
         config["cfg_parallel"] = True
+
+
+def print_config(config):
+    config_to_print = config.copy()
+    config_to_print.pop("device_mesh", None)
+    if config.parallel:
+        if dist.get_rank() == 0:
+            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
+    else:
+        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
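For reference, a rough sketch of how these pieces fit together in the entry script after this change; it mirrors the diff above rather than reproducing the script verbatim, and args stands for the parsed CLI namespace:

config = set_config(args)                  # build the run config from CLI args
if config.parallel:
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    set_parallel_config(config)            # no longer initializes the process group itself
print_config(config)                       # logs the config once; only rank 0 when parallel
with ProfilingContext("Total Cost"):
    runner = init_runner(config)
    runner.run_pipeline()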