Commit 27c5575f authored by Yang Yong(雍洋), committed by GitHub

Support Multi-Level Profiling Log (#290)

parent dd870f3f
@@ -46,7 +46,7 @@ gpu_id=0
 export CUDA_VISIBLE_DEVICES=$gpu_id
 export CUDA_LAUNCH_BLOCKING=1
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

 # ==================== Parameter Parsing ====================
...
@@ -45,7 +45,7 @@ set gpu_id=0
 REM ==================== Environment Variables Setup ====================
 set CUDA_VISIBLE_DEVICES=%gpu_id%
 set PYTHONPATH=%lightx2v_path%;%PYTHONPATH%
-set ENABLE_PROFILING_DEBUG=true
+set PROFILING_DEBUG_LEVEL=2
 set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

 REM ==================== Parameter Parsing ====================
...
@@ -21,7 +21,7 @@ from lightx2v.deploy.server.auth import AuthManager
 from lightx2v.deploy.server.metrics import MetricMonitor
 from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus
 from lightx2v.deploy.task_manager import LocalTaskManager, PostgresSQLTaskManager, TaskStatus
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.service_utils import ProcessManager

 # =========================
@@ -679,7 +679,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")

-    with ProfilingContext("Init Server Cost"):
+    with ProfilingContext4DebugL1("Init Server Cost"):
         model_pipelines = Pipeline(args.pipeline_json)
         auth_manager = AuthManager()
         if args.task_url.startswith("/"):
...
@@ -16,14 +16,14 @@ from lightx2v.deploy.common.utils import class_try_catch_async
 from lightx2v.infer import init_runner  # noqa
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.utils.envs import CHECK_ENABLE_GRAPH_MODE
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all


 class BaseWorker:
-    @ProfilingContext("Init Worker Worker Cost:")
+    @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
         config = set_config(args)
         config["mode"] = ""
...
@@ -15,7 +15,7 @@ from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  #
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
@@ -103,7 +103,7 @@ def main():
     print_config(config)

-    with ProfilingContext("Total Cost"):
+    with ProfilingContext4DebugL1("Total Cost"):
         runner = init_runner(config)
         runner.run_pipeline()
...
@@ -10,7 +10,7 @@ from requests.exceptions import RequestException
 from lightx2v.utils.envs import *
 from lightx2v.utils.generate_task_id import generate_task_id
-from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

 from .base_runner import BaseRunner
@@ -60,7 +60,7 @@ class DefaultRunner(BaseRunner):
         else:
             raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

-    @ProfilingContext("Load models")
+    @ProfilingContext4DebugL2("Load models")
     def load_model(self):
         self.model = self.load_transformer()
         self.text_encoders = self.load_text_encoder()
@@ -116,13 +116,13 @@ class DefaultRunner(BaseRunner):
             self.check_stop()
             logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

-            with ProfilingContext4Debug("step_pre"):
+            with ProfilingContext4DebugL1("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)

-            with ProfilingContext4Debug("🚀 infer_main"):
+            with ProfilingContext4DebugL1("🚀 infer_main"):
                 self.model.infer(self.inputs)

-            with ProfilingContext4Debug("step_post"):
+            with ProfilingContext4DebugL1("step_post"):
                 self.model.scheduler.step_post()

             if self.progress_callback:
@@ -155,7 +155,7 @@ class DefaultRunner(BaseRunner):
         img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
         return img, img_ori

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img, img_ori = self.read_image_input(self.config["image_path"])
@@ -166,7 +166,7 @@ class DefaultRunner(BaseRunner):
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt, None)
@@ -177,7 +177,7 @@ class DefaultRunner(BaseRunner):
             "image_encoder_output": None,
         }

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_flf2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         first_frame, _ = self.read_image_input(self.config["image_path"])
@@ -189,7 +189,7 @@ class DefaultRunner(BaseRunner):
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_vace(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         src_video = self.config.get("src_video", None)
@@ -219,12 +219,12 @@ class DefaultRunner(BaseRunner):
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None

-    @ProfilingContext("Run DiT")
+    @ProfilingContext4DebugL2("Run DiT")
     def run_main(self, total_steps=None):
         self.init_run()
         for segment_idx in range(self.video_segment_num):
-            logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
-            with ProfilingContext(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
+            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
+            with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
                 self.check_stop()
                 # 1. default do nothing
                 self.init_run_segment(segment_idx)
@@ -236,7 +236,7 @@ class DefaultRunner(BaseRunner):
             self.end_run_segment()
         self.end_run()

-    @ProfilingContext("Run VAE Decoder")
+    @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
...
 from loguru import logger

-from lightx2v.utils.profiler import ProfilingContext4Debug
+from lightx2v.utils.profiler import *


 class GraphRunner:
@@ -13,7 +13,7 @@ class GraphRunner:
         logger.info("🚀 Starting Model Compilation - Please wait, this may take a while... 🚀")
         logger.info("=" * 60)

-        with ProfilingContext4Debug("compile"):
+        with ProfilingContext4DebugL2("compile"):
             self.runner.run_step()

         logger.info("=" * 60)
...
@@ -10,7 +10,7 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
 from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -32,7 +32,7 @@ class QwenImageRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)

-    @ProfilingContext("Load models")
+    @ProfilingContext4DebugL2("Load models")
     def load_model(self):
         self.model = self.load_transformer()
         self.text_encoders = self.load_text_encoder()
@@ -69,7 +69,7 @@ class QwenImageRunner(DefaultRunner):
         else:
             assert NotImplementedError

-    @ProfilingContext("Run DiT")
+    @ProfilingContext4DebugL2("Run DiT")
     def _run_dit_local(self, total_steps=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
@@ -81,7 +81,7 @@ class QwenImageRunner(DefaultRunner):
         self.end_run()
         return latents, generator

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2i(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt)
@@ -92,7 +92,7 @@ class QwenImageRunner(DefaultRunner):
             "image_encoder_output": None,
         }

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2i(self):
         image = Image.open(self.config["image_path"])
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
@@ -125,20 +125,18 @@ class QwenImageRunner(DefaultRunner):
         return {"image_latents": image_latents}

     def run(self, total_steps=None):
-        from lightx2v.utils.profiler import ProfilingContext4Debug
-
         if total_steps is None:
             total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

-            with ProfilingContext4Debug("step_pre"):
+            with ProfilingContext4DebugL1("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)

-            with ProfilingContext4Debug("🚀 infer_main"):
+            with ProfilingContext4DebugL1("🚀 infer_main"):
                 self.model.infer(self.inputs)

-            with ProfilingContext4Debug("step_post"):
+            with ProfilingContext4DebugL1("step_post"):
                 self.model.scheduler.step_post()

             if self.progress_callback:
@@ -181,7 +179,7 @@ class QwenImageRunner(DefaultRunner):
     def run_image_encoder(self):
         pass

-    @ProfilingContext("Load models")
+    @ProfilingContext4DebugL2("Load models")
     def load_model(self):
         self.model = self.load_transformer()
         self.text_encoders = self.load_text_encoder()
@@ -189,7 +187,7 @@ class QwenImageRunner(DefaultRunner):
         self.vae = self.load_vae()
         self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

-    @ProfilingContext("Run VAE Decoder")
+    @ProfilingContext4DebugL1("Run VAE Decoder")
     def _run_vae_decoder_local(self, latents, generator):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae()
...
@@ -25,7 +25,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
@@ -368,7 +368,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         gc.collect()
         return vae_encoder_out

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_r2v_audio(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img = self.read_image_input(self.config["image_path"])
@@ -410,7 +410,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.vae_encoder = self.load_vae_encoder()
         _, nframe, height, width = self.model.scheduler.latents.shape

-        with ProfilingContext4Debug("vae_encoder in init run segment"):
+        with ProfilingContext4DebugL1("vae_encoder in init run segment"):
             if self.config.model_cls == "wan2.2_audio":
                 if prev_video is not None:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
@@ -460,7 +460,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.cut_audio_list = []
         self.prev_video = None

-    @ProfilingContext4Debug("Init run segment")
+    @ProfilingContext4DebugL1("Init run segment")
     def init_run_segment(self, segment_idx, audio_array=None):
         self.segment_idx = segment_idx
         if audio_array is not None:
@@ -485,7 +485,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if segment_idx > 0:
             self.model.scheduler.reset(self.inputs["previmg_encoder_output"])

-    @ProfilingContext4Debug("End run segment")
+    @ProfilingContext4DebugL1("End run segment")
     def end_run_segment(self):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
         useful_length = self.segment.end_frame - self.segment.start_frame
@@ -575,7 +575,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         max_fail_count = 10
         while True:

-            with ProfilingContext4Debug(f"stream segment get audio segment {segment_idx}"):
+            with ProfilingContext4DebugL1(f"stream segment get audio segment {segment_idx}"):
                 self.check_stop()
                 audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
                 if audio_array is None:
@@ -585,7 +585,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
                         raise Exception(f"Failed to get audio chunk {fail_count} times, stop reader")
                     continue

-            with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"):
+            with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
                 fail_count = 0
                 self.init_run_segment(segment_idx, audio_array)
                 latents = self.run_segment(total_steps=None)
@@ -603,7 +603,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.va_recorder.stop(wait=False)
             self.va_recorder = None

-    @ProfilingContext4Debug("Process after vae decoder")
+    @ProfilingContext4DebugL1("Process after vae decoder")
    def process_images_after_vae_decoder(self, save_video=True):
         # Merge results
         gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()
@@ -728,12 +728,12 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())

-    @ProfilingContext("Load models")
     def load_model(self):
         super().load_model()
-        self.audio_encoder = self.load_audio_encoder()
-        self.audio_adapter = self.load_audio_adapter()
-        self.model.set_audio_adapter(self.audio_adapter)
+        with ProfilingContext4DebugL2("Load audio encoder and adapter"):
+            self.audio_encoder = self.load_audio_encoder()
+            self.audio_adapter = self.load_audio_adapter()
+            self.model.set_audio_adapter(self.audio_adapter)

     def set_target_shape(self):
         """Set target shape for generation"""
...
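A note on the load_model refactor above: DefaultRunner.load_model is already decorated with ProfilingContext4DebugL2("Load models"), so keeping a second "Load models" decorator here would have logged the base loads twice; instead the commit scopes a level-2 context to just the audio components. A minimal sketch of the resulting pattern (class names here are hypothetical stand-ins, not repo code):

    from lightx2v.utils.profiler import ProfilingContext4DebugL2

    class Base:  # stands in for DefaultRunner
        @ProfilingContext4DebugL2("Load models")  # times the base model loads
        def load_model(self):
            pass

    class AudioRunner(Base):  # stands in for WanAudioRunner
        def load_model(self):
            super().load_model()  # already timed by the base decorator
            with ProfilingContext4DebugL2("Load audio encoder and adapter"):
                pass  # only the audio extras are timed here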
@@ -9,7 +9,7 @@ from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -85,11 +85,11 @@ class WanCausVidRunner(WanRunner):
             if fragment_idx > 0:
                 logger.info("recompute the kv_cache ...")

-                with ProfilingContext4Debug("step_pre"):
+                with ProfilingContext4DebugL1("step_pre"):
                     self.model.scheduler.latents = self.model.scheduler.last_sample
                     self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)

-                with ProfilingContext4Debug("🚀 infer_main"):
+                with ProfilingContext4DebugL1("🚀 infer_main"):
                     self.model.infer(self.inputs, kv_start, kv_end)

                 kv_start += self.num_frame_per_block * self.frame_seq_length
@@ -105,13 +105,13 @@ class WanCausVidRunner(WanRunner):
             for step_index in range(self.model.scheduler.infer_steps):
                 logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

-                with ProfilingContext4Debug("step_pre"):
+                with ProfilingContext4DebugL1("step_pre"):
                     self.model.scheduler.step_pre(step_index=step_index)

-                with ProfilingContext4Debug("🚀 infer_main"):
+                with ProfilingContext4DebugL1("🚀 infer_main"):
                     self.model.infer(self.inputs, kv_start, kv_end)

-                with ProfilingContext4Debug("step_post"):
+                with ProfilingContext4DebugL1("step_post"):
                     self.model.scheduler.step_post()

                 kv_start += self.num_frame_per_block * self.frame_seq_length
...
@@ -10,7 +10,7 @@ from loguru import logger
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -55,9 +55,9 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
     def run_input_encoder(self):
         image_encoder_output = None
         if os.path.isfile(self.config.image_path):
-            with ProfilingContext("Run Img Encoder"):
+            with ProfilingContext4DebugL2("Run Img Encoder"):
                 image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
-        with ProfilingContext("Run Text Encoder"):
+        with ProfilingContext4DebugL2("Run Text Encoder"):
             text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output)
         self.set_target_shape()
         self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
@@ -107,13 +107,13 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         for step_index in range(self.model.scheduler.infer_steps):
             logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

-            with ProfilingContext4Debug("step_pre"):
+            with ProfilingContext4DebugL1("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)

-            with ProfilingContext4Debug("🚀 infer_main"):
+            with ProfilingContext4DebugL1("🚀 infer_main"):
                 self.model.infer(self.inputs)

-            with ProfilingContext4Debug("step_post"):
+            with ProfilingContext4DebugL1("step_post"):
                 self.model.scheduler.step_post()

         videos = self.run_vae(self.model.scheduler.latents, self.model.scheduler.generator)
...
@@ -9,7 +9,7 @@ from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProce
 from lightx2v.models.networks.wan.vace_model import WanVaceModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -159,7 +159,7 @@ class WanVaceRunner(WanRunner):
         target_shape[0] = int(target_shape[0] / 2)
         self.config.target_shape = target_shape

-    @ProfilingContext("Run VAE Decoder")
+    @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
...
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch.nn import functional as F

-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *


 class RIFEWrapper:
@@ -25,12 +25,12 @@ class RIFEWrapper:
         from .train_log.RIFE_HDv3 import Model

         self.model = Model()
-        with ProfilingContext("Load RIFE model"):
+        with ProfilingContext4DebugL2("Load RIFE model"):
             self.model.load_model(model_path, -1)
         self.model.eval()
         self.model.device()

-    @ProfilingContext("Interpolate frames")
+    @ProfilingContext4DebugL2("Interpolate frames")
     def interpolate_frames(
         self,
         images: torch.Tensor,
...
@@ -17,9 +17,9 @@ DTYPE_MAP = {

 @lru_cache(maxsize=None)
-def CHECK_ENABLE_PROFILING_DEBUG():
-    ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
-    return ENABLE_PROFILING_DEBUG
+def CHECK_PROFILING_DEBUG_LEVEL(target_level):
+    current_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0"))
+    return current_level >= target_level


 @lru_cache(maxsize=None)
...
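Since the new check is wrapped in lru_cache, each target level is resolved from the environment only once per process, so PROFILING_DEBUG_LEVEL must be exported before the first profiling call. A minimal standalone sketch of the gating behavior (mirroring the diff above; the setdefault and print lines are illustrative only):

    import os
    from functools import lru_cache

    os.environ.setdefault("PROFILING_DEBUG_LEVEL", "2")  # must be set before the first (cached) check

    @lru_cache(maxsize=None)
    def CHECK_PROFILING_DEBUG_LEVEL(target_level):
        current_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0"))
        return current_level >= target_level

    print(CHECK_PROFILING_DEBUG_LEVEL(1), CHECK_PROFILING_DEBUG_LEVEL(2))  # True True at level 2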
@@ -12,7 +12,6 @@ from lightx2v.utils.envs import *
 class _ProfilingContext:
     def __init__(self, name):
         self.name = name
-        self.rank_info = ""
         if dist.is_initialized():
             self.rank_info = f"Rank {dist.get_rank()}"
         else:
@@ -80,5 +79,24 @@ class _NullContext:
         return func

-ProfilingContext = _ProfilingContext
-ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext
+class _ProfilingContextL1(_ProfilingContext):
+    """Level 1 profiling context with Level1_Log prefix."""
+
+    def __init__(self, name):
+        super().__init__(f"Level1_Log {name}")
+
+
+class _ProfilingContextL2(_ProfilingContext):
+    """Level 2 profiling context with Level2_Log prefix."""
+
+    def __init__(self, name):
+        super().__init__(f"Level2_Log {name}")
+
+
+"""
+PROFILING_DEBUG_LEVEL=0: [Default] disable all profiling
+PROFILING_DEBUG_LEVEL=1: enable ProfilingContext4DebugL1
+PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4DebugL2
+"""
+ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
+ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling
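Both aliases fall back to _NullContext when the configured level is too low, so call sites can use them unconditionally, either as decorators or as context managers. A hedged usage sketch (assumes lightx2v is importable and that PROFILING_DEBUG_LEVEL was exported before the process started, since the level checks are cached):

    from lightx2v.utils.profiler import ProfilingContext4DebugL1, ProfilingContext4DebugL2

    class MyRunner:  # hypothetical example class, not part of the repo
        @ProfilingContext4DebugL2("Load models")  # logged only when PROFILING_DEBUG_LEVEL >= 2
        def load_model(self):
            pass

        def run_step(self, step_index):
            with ProfilingContext4DebugL1("step_pre"):  # logged when PROFILING_DEBUG_LEVEL >= 1
                pass

    runner = MyRunner()
    runner.load_model()
    runner.run_step(0)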
@@ -4,7 +4,7 @@ import torch
 from loguru import logger
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *

 sys_prompt = """
 Transform the short prompt into a detailed video-generation caption using this structure:
@@ -40,7 +40,7 @@ class PromptEnhancer:
     def to_device(self, device):
         self.model = self.model.to(device)

-    @ProfilingContext("Run prompt enhancer")
+    @ProfilingContext4DebugL1("Run prompt enhancer")
     @torch.no_grad()
     def __call__(self, prompt):
         prompt = prompt.strip()
...
@@ -32,12 +32,12 @@ export DTYPE=BF16
 # Note: If set to FP32, it will be slower, so we recommend set ENABLE_GRAPH_MODE to true.
 export SENSITIVE_LAYER_DTYPE=FP32

-# Performance Profiling Debug Mode (Debug Only)
+# Performance Profiling Debug Level (Debug Only)
 # Enables detailed performance analysis output, such as time cost and memory usage
-# Available options: [true, false]
-# If not set, default value: false
-# Note: This option can be set to false for production.
-export ENABLE_PROFILING_DEBUG=true
+# Available options: [0, 1, 2]
+# If not set, default value: 0
+# Note: This option can be set to 0 for production.
+export PROFILING_DEBUG_LEVEL=2

 # Graph Mode Optimization (Performance Enhancement)
 # Enables torch.compile for graph optimization, can improve inference performance
@@ -56,6 +56,6 @@ echo "model_path: ${model_path}"
 echo "-------------------------------------------------------------------------------"
 echo "Model Inference Data Type: ${DTYPE}"
 echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}"
-echo "Performance Profiling Debug Mode: ${ENABLE_PROFILING_DEBUG}"
+echo "Performance Profiling Debug Level: ${PROFILING_DEBUG_LEVEL}"
 echo "Graph Mode Optimization: ${ENABLE_GRAPH_MODE}"
 echo "==============================================================================="
@@ -27,7 +27,7 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
 export SENSITIVE_LAYER_DTYPE=FP32
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export ENABLE_GRAPH_MODE=false

 python -m lightx2v.infer \
...
@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
...
@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
...