Commit cbf7820f authored by helloyongyang

support _ProfilingContext and _NullContext for speed test

parent 75c03057
lightx2v/__main__.py:

@@ -8,8 +8,12 @@ import json
 import torchvision
 import torchvision.transforms.functional as TF
 import numpy as np
-from contextlib import contextmanager
 from PIL import Image
+from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
+from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.set_config import set_config
 from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
@@ -27,19 +31,8 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
-from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
-from lightx2v.common.ops import *
-from lightx2v.utils.set_config import set_config
-
-
-@contextmanager
-def time_duration(label: str = ""):
-    torch.cuda.synchronize()
-    start_time = time.time()
-    yield
-    torch.cuda.synchronize()
-    end_time = time.time()
-    print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
+from lightx2v.common.ops import *


 def load_models(config):
@@ -63,7 +56,7 @@ def load_models(config):
         vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)
     elif config.model_cls == "wan2.1":
-        with time_duration("Load Text Encoder"):
+        with ProfilingContext("Load Text Encoder"):
             text_encoder = T5EncoderModel(
                 text_len=config["text_len"],
                 dtype=torch.bfloat16,
@@ -73,20 +66,20 @@ def load_models(config):
                 shard_fn=None,
             )
         text_encoders = [text_encoder]
-        with time_duration("Load Wan Model"):
+        with ProfilingContext("Load Wan Model"):
            model = WanModel(config.model_path, config, init_device)
         if config.lora_path:
             lora_wrapper = WanLoraWrapper(model)
-            with time_duration("Load LoRA Model"):
+            with ProfilingContext("Load LoRA Model"):
                 lora_name = lora_wrapper.load_lora(config.lora_path)
                 lora_wrapper.apply_lora(lora_name, config.strength_model)
                 print(f"Loaded LoRA: {lora_name}")
-        with time_duration("Load WAN VAE Model"):
+        with ProfilingContext("Load WAN VAE Model"):
            vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
         if config.task == "i2v":
-            with time_duration("Load Image Encoder"):
+            with ProfilingContext("Load Image Encoder"):
                 image_encoder = CLIPModel(
                     dtype=torch.float16,
                     device=init_device,
@@ -280,27 +273,16 @@ def init_scheduler(config, image_encoder_output):
 def run_main_inference(model, inputs):
     for step_index in range(model.scheduler.infer_steps):
-        torch.cuda.synchronize()
-        time1 = time.time()
-        model.scheduler.step_pre(step_index=step_index)
-        torch.cuda.synchronize()
-        time2 = time.time()
-        model.infer(inputs)
-        torch.cuda.synchronize()
-        time3 = time.time()
-        model.scheduler.step_post()
-        torch.cuda.synchronize()
-        time4 = time.time()
-        print(f"step {step_index} infer time: {time3 - time2}")
-        print(f"step {step_index} all time: {time4 - time1}")
-        print("*" * 10)
+        print(f"==> step_index: {step_index + 1} / {model.scheduler.infer_steps}")
+        with ProfilingContext4Debug("step_pre"):
+            model.scheduler.step_pre(step_index=step_index)
+        with ProfilingContext4Debug("infer"):
+            model.infer(inputs)
+        with ProfilingContext4Debug("step_post"):
+            model.scheduler.step_post()
     return model.scheduler.latents, model.scheduler.generator
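Both the deleted hand-rolled timers and the new profiling contexts bracket their clock reads with torch.cuda.synchronize(). That call is what makes the numbers meaningful: CUDA kernels are launched asynchronously, so without a sync the host reads the clock before the GPU has actually finished. A minimal sketch of the difference, assuming a CUDA device (the matmul is just a stand-in workload):

    import time
    import torch

    x = torch.randn(4096, 4096, device="cuda")

    start = time.perf_counter()
    y = x @ x  # the kernel is only launched here; the host does not wait
    launch_only = time.perf_counter() - start

    torch.cuda.synchronize()  # block until the GPU has actually finished
    total = time.perf_counter() - start
    print(f"launch: {launch_only:.6f}s, synchronized: {total:.6f}s")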
@@ -344,7 +326,7 @@ if __name__ == "__main__":
     parser.add_argument("--strength_model", type=float, default=1.0)
     args = parser.parse_args()
-    start_time = time.time()
+    start_time = time.perf_counter()
     print(f"args: {args}")
     seed_all(args.seed)
@@ -356,7 +338,7 @@ if __name__ == "__main__":
     print(f"config: {config}")
-    with time_duration("Load models"):
+    with ProfilingContext("Load models"):
         model, text_encoders, vae_model, image_encoder = load_models(config)
     if config["task"] in ["i2v"]:
@@ -364,7 +346,7 @@ if __name__ == "__main__":
     else:
         image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
-    with time_duration("Run Text Encoder"):
+    with ProfilingContext("Run Text Encoder"):
         text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)
     inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
@@ -383,15 +365,15 @@ if __name__ == "__main__":
     del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
     torch.cuda.empty_cache()
-    with time_duration("Run VAE"):
+    with ProfilingContext("Run VAE"):
         images = run_vae(latents, generator, config)
     if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
-        with time_duration("Save video"):
+        with ProfilingContext("Save video"):
             if config.model_cls == "wan2.1":
                 cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
             else:
                 save_videos_grid(images, config.save_video_path, fps=24)
-    end_time = time.time()
+    end_time = time.perf_counter()
     print(f"Total cost: {end_time - start_time}")
lightx2v/utils/profiler.py (new file):

import time
import os
import torch
from contextlib import ContextDecorator

ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"


class _ProfilingContext(ContextDecorator):
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        torch.cuda.synchronize()
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - self.start_time
        print(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
        return False


class _NullContext(ContextDecorator):
    # No-op stand-in: skips the synchronize/timing work entirely, so
    # disabled profiling adds no per-call branching overhead.
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False


ProfilingContext = _ProfilingContext
ProfilingContext4Debug = _ProfilingContext if ENABLE_PROFILING_DEBUG else _NullContext
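Because both classes derive from contextlib.ContextDecorator, each can be used either as a with-block or as a decorator. A minimal usage sketch, assuming a CUDA-capable install of lightx2v (heavy_work is a hypothetical stand-in function):

    import torch
    from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug

    # Always-on timing, as used for coarse stages such as model loading.
    with ProfilingContext("allocate"):
        x = torch.randn(1024, 1024, device="cuda")

    # Debug-only timing: resolves to the no-op _NullContext unless
    # ENABLE_PROFILING_DEBUG=true was set before the module was imported.
    @ProfilingContext4Debug("heavy_work")
    def heavy_work():
        return x @ x

    heavy_work()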
@@ -23,6 +23,8 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
     --model_path $model_path \
......
@@ -23,6 +23,8 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
     --model_path $model_path \
......
@@ -23,6 +23,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
......
@@ -23,6 +23,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
......
@@ -29,6 +29,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls wan2.1 \
......
@@ -26,6 +26,8 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+
 python -m lightx2v \
     --model_cls wan2.1 \
     --task i2v \
......
@@ -29,6 +29,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls wan2.1 \
......
@@ -29,6 +29,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls wan2.1 \
......
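One caveat: ProfilingContext4Debug is bound once, at import time, which is why these scripts export ENABLE_PROFILING_DEBUG before launching Python; setting the variable from inside a process only takes effect if it happens before the profiler module is first imported. A quick check of which class ended up bound, assuming the package is on PYTHONPATH:

    import os
    os.environ["ENABLE_PROFILING_DEBUG"] = "true"  # must run before the import below

    from lightx2v.utils import profiler
    print(profiler.ProfilingContext4Debug.__name__)
    # "_ProfilingContext" if the flag was seen at import time, else "_NullContext"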