Commit d7206e69 authored by helloyongyang

fix gpu mem not balanced bug

parent c0b36010
@@ -15,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.set_config import print_config, set_config
+from lightx2v.utils.set_config import set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
@@ -58,11 +58,17 @@ def main():
     logger.info(f"args: {args}")

+    # set config
+    config = set_config(args)
+    logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+    if "parallel" in config:
+        dist.init_process_group(backend="nccl")
+        torch.cuda.set_device(dist.get_rank())
+        set_parallel_config(config)

     with ProfilingContext("Total Cost"):
-        config = set_config(args)
-        print_config(config)
         runner = init_runner(config)
         runner.run_pipeline()

     # Clean up distributed process group
...
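The likely source of the imbalance: previously `torch.cuda.set_device` ran later, inside the runner's `set_init_device`, so anything allocated before that point defaulted to `cuda:0` on every rank. This change pins each rank's device in `main()`, immediately after `init_process_group` and before `init_runner` loads anything. A minimal standalone sketch of the pattern (the launch command and tensor size are illustrative, not from the repo):

```python
import torch
import torch.distributed as dist

# Launch (illustrative): torchrun --nproc_per_node=2 sketch.py
dist.init_process_group(backend="nccl")

# Pin the device *before* loading weights or creating CUDA tensors;
# otherwise every rank allocates on cuda:0 and memory piles up there.
torch.cuda.set_device(dist.get_rank())

# From here on, plain device="cuda" resolves to the per-rank GPU.
x = torch.zeros(1024, 1024, device="cuda")
print(f"rank {dist.get_rank()} -> {x.device}")

dist.destroy_process_group()
```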
@@ -43,9 +43,6 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
-        if self.config.parallel:
-            cur_rank = dist.get_rank()
-            torch.cuda.set_device(cur_rank)
         if self.config.cpu_offload:
             self.init_device = torch.device("cpu")
         else:
...
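With the device already pinned in `main()`, the runner needs no `torch.distributed` calls here; `set_init_device` only chooses between CPU offload and the current GPU. A sketch of how the method plausibly reads after this hunk (the `else` branch is an assumption, since the diff is truncated):

```python
def set_init_device(self):
    # The per-rank CUDA device is already pinned in main(), before the
    # runner is constructed, so no torch.distributed calls are needed here.
    if self.config.cpu_offload:
        self.init_device = torch.device("cpu")
    else:
        self.init_device = torch.device("cuda")  # assumed: resolves to the pinned per-rank GPU
```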
@@ -3,6 +3,7 @@ import time
 from functools import wraps

 import torch
+import torch.distributed as dist
 from loguru import logger

 from lightx2v.utils.envs import *
@@ -12,9 +13,10 @@ class _ProfilingContext:
     def __init__(self, name):
         self.name = name
         self.rank_info = ""
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            self.rank_info = f"Rank {rank} - "
+        if dist.is_initialized():
+            self.rank_info = f"Rank {dist.get_rank()}"
+        else:
+            self.rank_info = "Single GPU"

     def __enter__(self):
         torch.cuda.synchronize()
@@ -27,11 +29,11 @@ class _ProfilingContext:
         torch.cuda.synchronize()
         if torch.cuda.is_available():
             peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
-            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
         else:
-            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
         elapsed = time.perf_counter() - self.start_time
-        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
         return False

     async def __aenter__(self):
@@ -45,11 +47,11 @@ class _ProfilingContext:
         torch.cuda.synchronize()
         if torch.cuda.is_available():
             peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
-            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
         else:
-            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
         elapsed = time.perf_counter() - self.start_time
-        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
         return False

     def __call__(self, func):
...
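Because `rank_info` is now computed once in `__init__`, every profiling line carries the emitting rank (or `Single GPU` when `torch.distributed` was never initialized), which makes per-rank peak-memory numbers easy to compare. Usage is unchanged; a small sketch of both forms, assuming a CUDA device is present and that the exported `ProfilingContext` keeps the `_ProfilingContext` interface shown above:

```python
import time
from lightx2v.utils.profiler import ProfilingContext

# Context-manager form, as main() uses it. With torch.distributed initialized
# this logs e.g. "[Profile] Rank 0 - Total Cost cost 0.100123 seconds";
# without it the tag is "Single GPU".
with ProfilingContext("Total Cost"):
    time.sleep(0.1)  # stand-in for the real pipeline

# Decorator form via __call__, wrapping a function with the same timing/logging.
@ProfilingContext("warmup step")
def warmup():
    time.sleep(0.1)

warmup()
```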
@@ -61,8 +61,6 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
         config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

-    set_parallel_config(config)  # parallel config
-
     return config
@@ -83,9 +81,3 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
         config["cfg_parallel"] = True
-
-
-def print_config(config):
-    config_to_print = config.copy()
-    config_to_print.pop("device_mesh", None)  # Remove device_mesh if it exists
-    logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")