Commit cbf7820f authored by helloyongyang

support _ProfilingContext and _NullContext for speed test

parent 75c03057
lightx2v/__main__.py
@@ -8,8 +8,12 @@ import json
 import torchvision
 import torchvision.transforms.functional as TF
 import numpy as np
-from contextlib import contextmanager
 from PIL import Image
+from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
+from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.set_config import set_config
 from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
 from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
@@ -27,19 +31,8 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
-from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
-from lightx2v.common.ops import *
-from lightx2v.utils.set_config import set_config
+from lightx2v.common.ops import *
-
-@contextmanager
-def time_duration(label: str = ""):
-    torch.cuda.synchronize()
-    start_time = time.time()
-    yield
-    torch.cuda.synchronize()
-    end_time = time.time()
-    print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
 def load_models(config):
@@ -63,7 +56,7 @@ def load_models(config):
         vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)
     elif config.model_cls == "wan2.1":
-        with time_duration("Load Text Encoder"):
+        with ProfilingContext("Load Text Encoder"):
             text_encoder = T5EncoderModel(
                 text_len=config["text_len"],
                 dtype=torch.bfloat16,
@@ -73,20 +66,20 @@ def load_models(config):
                 shard_fn=None,
             )
         text_encoders = [text_encoder]
-        with time_duration("Load Wan Model"):
+        with ProfilingContext("Load Wan Model"):
             model = WanModel(config.model_path, config, init_device)
         if config.lora_path:
             lora_wrapper = WanLoraWrapper(model)
-            with time_duration("Load LoRA Model"):
+            with ProfilingContext("Load LoRA Model"):
                 lora_name = lora_wrapper.load_lora(config.lora_path)
                 lora_wrapper.apply_lora(lora_name, config.strength_model)
                 print(f"Loaded LoRA: {lora_name}")
-        with time_duration("Load WAN VAE Model"):
+        with ProfilingContext("Load WAN VAE Model"):
             vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
         if config.task == "i2v":
-            with time_duration("Load Image Encoder"):
+            with ProfilingContext("Load Image Encoder"):
                 image_encoder = CLIPModel(
                     dtype=torch.float16,
                     device=init_device,
@@ -280,27 +273,16 @@ def init_scheduler(config, image_encoder_output):
 def run_main_inference(model, inputs):
     for step_index in range(model.scheduler.infer_steps):
-        torch.cuda.synchronize()
-        time1 = time.time()
-        model.scheduler.step_pre(step_index=step_index)
-        torch.cuda.synchronize()
-        time2 = time.time()
-        model.infer(inputs)
-        torch.cuda.synchronize()
-        time3 = time.time()
-        model.scheduler.step_post()
-        torch.cuda.synchronize()
-        time4 = time.time()
-        print(f"step {step_index} infer time: {time3 - time2}")
-        print(f"step {step_index} all time: {time4 - time1}")
-        print("*" * 10)
+        print(f"==> step_index: {step_index + 1} / {model.scheduler.infer_steps}")
+
+        with ProfilingContext4Debug("step_pre"):
+            model.scheduler.step_pre(step_index=step_index)
+
+        with ProfilingContext4Debug("infer"):
+            model.infer(inputs)
+
+        with ProfilingContext4Debug("step_post"):
+            model.scheduler.step_post()
+
     return model.scheduler.latents, model.scheduler.generator
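
With ENABLE_PROFILING_DEBUG=true, each denoising step now prints one progress line plus one timing line per phase, using the format strings visible above and in the new profiler module. Illustrative output (the step count and timings below are made up):

    ==> step_index: 1 / 50
    [Profile] step_pre cost 0.001842 seconds
    [Profile] infer cost 1.274301 seconds
    [Profile] step_post cost 0.002115 seconds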
@@ -344,7 +326,7 @@ if __name__ == "__main__":
     parser.add_argument("--strength_model", type=float, default=1.0)
     args = parser.parse_args()
-    start_time = time.time()
+    start_time = time.perf_counter()
     print(f"args: {args}")
     seed_all(args.seed)
@@ -356,7 +338,7 @@ if __name__ == "__main__":
     print(f"config: {config}")
-    with time_duration("Load models"):
+    with ProfilingContext("Load models"):
         model, text_encoders, vae_model, image_encoder = load_models(config)
     if config["task"] in ["i2v"]:
@@ -364,7 +346,7 @@ if __name__ == "__main__":
     else:
         image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
-    with time_duration("Run Text Encoder"):
+    with ProfilingContext("Run Text Encoder"):
         text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)
     inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
@@ -383,15 +365,15 @@ if __name__ == "__main__":
     del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
     torch.cuda.empty_cache()
-    with time_duration("Run VAE"):
+    with ProfilingContext("Run VAE"):
         images = run_vae(latents, generator, config)
     if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
-        with time_duration("Save video"):
+        with ProfilingContext("Save video"):
             if config.model_cls == "wan2.1":
                 cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
             else:
                 save_videos_grid(images, config.save_video_path, fps=24)
-    end_time = time.time()
+    end_time = time.perf_counter()
     print(f"Total cost: {end_time - start_time}")
lightx2v/utils/profiler.py (new file)

import time
import os
import torch
from contextlib import ContextDecorator

ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"


class _ProfilingContext(ContextDecorator):
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        torch.cuda.synchronize()
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - self.start_time
        print(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
        return False


class _NullContext(ContextDecorator):
    # No-op stand-in: skips synchronization and timing entirely, so call
    # sites carry no profiling overhead when debugging is disabled.
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False


ProfilingContext = _ProfilingContext
ProfilingContext4Debug = _ProfilingContext if ENABLE_PROFILING_DEBUG else _NullContext
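
Because both classes derive from contextlib.ContextDecorator, each alias works either as a with-block or as a function decorator. A minimal usage sketch (the helper names here are hypothetical, and a CUDA device is assumed since _ProfilingContext calls torch.cuda.synchronize()):

    from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug

    # Coarse phase timing: always prints "[Profile] warmup cost ... seconds".
    with ProfilingContext("warmup"):
        run_warmup()  # hypothetical helper

    # Per-step debug timing: a no-op unless ENABLE_PROFILING_DEBUG=true
    # was exported before the interpreter started.
    @ProfilingContext4Debug("denoise_step")
    def denoise_step():
        ...  # hypothetical stand-in for model.infer(inputs)

    denoise_step()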
@@ -23,6 +23,8 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
     --model_path $model_path \
...
@@ -23,6 +23,8 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
     --model_path $model_path \
...
@@ -23,6 +23,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
...
@@ -23,6 +23,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls hunyuan \
...
@@ -29,6 +29,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls wan2.1 \
...
@@ -26,6 +26,8 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+
 python -m lightx2v \
     --model_cls wan2.1 \
     --task i2v \
...
@@ -29,6 +29,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 python ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls wan2.1 \
...
@@ -29,6 +29,7 @@ fi
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
 torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
     --model_cls wan2.1 \
...
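
Note that the scripts export ENABLE_PROFILING_DEBUG rather than setting it from Python: ProfilingContext4Debug is bound once, when lightx2v.utils.profiler is first imported, so the variable must already be in the environment at that point. A small sketch of the ordering that matters:

    import os

    # Effective: set before the module is first imported.
    os.environ["ENABLE_PROFILING_DEBUG"] = "true"
    from lightx2v.utils.profiler import ProfilingContext4Debug  # -> _ProfilingContext

    # Ineffective: the alias is already resolved to _NullContext at import time.
    # from lightx2v.utils.profiler import ProfilingContext4Debug
    # os.environ["ENABLE_PROFILING_DEBUG"] = "true"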