Unverified Commit 891f3bf1 authored by yihuiwen's avatar yihuiwen Committed by GitHub
Browse files

add worker metrics (#343)


Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
parent 5ffdbeb6
...@@ -15,6 +15,7 @@ from loguru import logger ...@@ -15,6 +15,7 @@ from loguru import logger
from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager
from lightx2v.deploy.task_manager import TaskStatus from lightx2v.deploy.task_manager import TaskStatus
from lightx2v.deploy.worker.hub import DiTWorker, ImageEncoderWorker, PipelineWorker, SegmentDiTWorker, TextEncoderWorker, VaeDecoderWorker, VaeEncoderWorker from lightx2v.deploy.worker.hub import DiTWorker, ImageEncoderWorker, PipelineWorker, SegmentDiTWorker, TextEncoderWorker, VaeDecoderWorker, VaeEncoderWorker
from lightx2v.server.metrics import metrics
RUNNER_MAP = { RUNNER_MAP = {
"pipeline": PipelineWorker, "pipeline": PipelineWorker,
...@@ -205,6 +206,8 @@ async def main(args): ...@@ -205,6 +206,8 @@ async def main(args):
args.task_name = args.task args.task_name = args.task
worker_keys = [args.task_name, args.model_name, args.stage, args.worker] worker_keys = [args.task_name, args.model_name, args.stage, args.worker]
metrics.server_process(args.metric_port)
data_manager = None data_manager = None
if args.data_url.startswith("/"): if args.data_url.startswith("/"):
data_manager = LocalDataManager(args.data_url, None) data_manager = LocalDataManager(args.data_url, None)
...@@ -329,6 +332,8 @@ if __name__ == "__main__": ...@@ -329,6 +332,8 @@ if __name__ == "__main__":
parser.add_argument("--timeout", type=int, default=300) parser.add_argument("--timeout", type=int, default=300)
parser.add_argument("--ping_interval", type=int, default=10) parser.add_argument("--ping_interval", type=int, default=10)
parser.add_argument("--metric_port", type=int, default=8001)
parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True) parser.add_argument("--config_json", type=str, required=True)
......
...@@ -8,6 +8,7 @@ from PIL import Image ...@@ -8,6 +8,7 @@ from PIL import Image
from loguru import logger from loguru import logger
from requests.exceptions import RequestException from requests.exceptions import RequestException
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.memory_profiler import peak_memory_decorator from lightx2v.utils.memory_profiler import peak_memory_decorator
...@@ -167,6 +168,9 @@ class DefaultRunner(BaseRunner): ...@@ -167,6 +168,9 @@ class DefaultRunner(BaseRunner):
img_ori = img_path img_ori = img_path
else: else:
img_ori = Image.open(img_path).convert("RGB") img_ori = Image.open(img_path).convert("RGB")
if GET_RECORDER_MODE():
width, height = img_ori.size
monitor_cli.lightx2v_input_image_len.observe(width*height)
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda() img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
self.input_info.original_size = img_ori.size self.input_info.original_size = img_ori.size
return img, img_ori return img, img_ori
...@@ -252,7 +256,10 @@ class DefaultRunner(BaseRunner): ...@@ -252,7 +256,10 @@ class DefaultRunner(BaseRunner):
self.model.select_graph_for_compile(self.input_info) self.model.select_graph_for_compile(self.input_info)
for segment_idx in range(self.video_segment_num): for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}") logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"): with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}", \
recorder_mode=GET_RECORDER_MODE(), \
metrics_func=monitor_cli.lightx2v_run_pre_step_dit_duration, \
metrics_labels=[segment_idx+1, self.video_segment_num]):
self.check_stop() self.check_stop()
# 1. default do nothing # 1. default do nothing
self.init_run_segment(segment_idx) self.init_run_segment(segment_idx)
...@@ -266,7 +273,12 @@ class DefaultRunner(BaseRunner): ...@@ -266,7 +273,12 @@ class DefaultRunner(BaseRunner):
self.end_run() self.end_run()
return {"video": gen_video_final} return {"video": gen_video_final}
@ProfilingContext4DebugL1("Run VAE Decoder") @ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=labels=["DefaultRunner"]
)
def run_vae_decoder(self, latents): def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder() self.vae_decoder = self.load_vae_decoder()
...@@ -321,7 +333,15 @@ class DefaultRunner(BaseRunner): ...@@ -321,7 +333,15 @@ class DefaultRunner(BaseRunner):
logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅") logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
return {"video": None} return {"video": None}
@ProfilingContext4DebugL1(
"RUN pipeline",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_worker_request_duration,
metrics_labels=["DefaultRunner"]
)
def run_pipeline(self, input_info): def run_pipeline(self, input_info):
if GET_RECORDER_MODE():
monitor_cli.lightx2v_worker_request_count.inc()
self.input_info = input_info self.input_info = input_info
if self.config["use_prompt_enhancer"]: if self.config["use_prompt_enhancer"]:
...@@ -331,4 +351,6 @@ class DefaultRunner(BaseRunner): ...@@ -331,4 +351,6 @@ class DefaultRunner(BaseRunner):
gen_video_final = self.run_main() gen_video_final = self.run_main()
if GET_RECORDER_MODE():
monitor_cli.lightx2v_worker_request_success.inc()
return gen_video_final return gen_video_final
...@@ -10,8 +10,10 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel ...@@ -10,8 +10,10 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.server.metrics import monitor_cli
def calculate_dimensions(target_area, ratio): def calculate_dimensions(target_area, ratio):
...@@ -106,7 +108,15 @@ class QwenImageRunner(DefaultRunner): ...@@ -106,7 +108,15 @@ class QwenImageRunner(DefaultRunner):
"image_encoder_output": image_encoder_output, "image_encoder_output": image_encoder_output,
} }
@ProfilingContext4DebugL1(
"Run Text Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_text_encode_duration,
metrics_labels=["QwenImageRunner"]
)
def run_text_encoder(self, text, image=None): def run_text_encoder(self, text, image=None):
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_prompt_len.observe(len(text))
text_encoder_output = {} text_encoder_output = {}
if self.config["task"] == "t2i": if self.config["task"] == "t2i":
prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text]) prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text])
...@@ -120,6 +130,12 @@ class QwenImageRunner(DefaultRunner): ...@@ -120,6 +130,12 @@ class QwenImageRunner(DefaultRunner):
text_encoder_output["image_info"] = image_info text_encoder_output["image_info"] = image_info
return text_encoder_output return text_encoder_output
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["QwenImageRunner"]
)
def run_vae_encoder(self, image): def run_vae_encoder(self, image):
image_latents = self.vae.encode_vae_image(image) image_latents = self.vae.encode_vae_image(image)
return {"image_latents": image_latents} return {"image_latents": image_latents}
...@@ -183,7 +199,12 @@ class QwenImageRunner(DefaultRunner): ...@@ -183,7 +199,12 @@ class QwenImageRunner(DefaultRunner):
self.vae = self.load_vae() self.vae = self.load_vae()
self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
@ProfilingContext4DebugL1("Run VAE Decoder") @ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["QwenImageRunner"],
)
def _run_vae_decoder_local(self, latents, generator): def _run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae() self.vae_decoder = self.load_vae()
......
...@@ -22,6 +22,7 @@ from lightx2v.utils.envs import * ...@@ -22,6 +22,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
from lightx2v.server.metrics import monitor_cli
@RUNNER_REGISTER("wan2.2_animate") @RUNNER_REGISTER("wan2.2_animate")
...@@ -150,6 +151,12 @@ class WanAnimateRunner(WanRunner): ...@@ -150,6 +151,12 @@ class WanAnimateRunner(WanRunner):
) )
return {"image_encoder_output": {"clip_encoder_out": clip_encoder_out, "vae_encoder_out": vae_encoder_out, "pose_latents": pose_latents, "face_pixel_values": face_pixel_values}} return {"image_encoder_output": {"clip_encoder_out": clip_encoder_out, "vae_encoder_out": vae_encoder_out, "pose_latents": pose_latents, "face_pixel_values": face_pixel_values}}
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanAnimateRunner"],
)
def run_vae_encoder( def run_vae_encoder(
self, self,
conditioning_pixel_values, conditioning_pixel_values,
...@@ -269,7 +276,12 @@ class WanAnimateRunner(WanRunner): ...@@ -269,7 +276,12 @@ class WanAnimateRunner(WanRunner):
self.prepare_input() self.prepare_input()
super().init_run() super().init_run()
@ProfilingContext4DebugL1("Run VAE Decoder") @ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["WanAnimateRunner"],
)
def run_vae_decoder(self, latents): def run_vae_decoder(self, latents):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_decoder = self.load_vae_decoder() self.vae_decoder = self.load_vae_decoder()
...@@ -351,6 +363,12 @@ class WanAnimateRunner(WanRunner): ...@@ -351,6 +363,12 @@ class WanAnimateRunner(WanRunner):
gc.collect() gc.collect()
super().process_images_after_vae_decoder() super().process_images_after_vae_decoder()
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanAnimateRunner"],
)
def run_image_encoder(self, img): # CHW def run_image_encoder(self, img): # CHW
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.image_encoder = self.load_image_encoder() self.image_encoder = self.load_image_encoder()
......
...@@ -27,6 +27,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper ...@@ -27,6 +27,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.runners.wan.wan_runner import WanRunner from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
...@@ -359,6 +360,9 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -359,6 +360,9 @@ class WanAudioRunner(WanRunner): # type:ignore
video_duration = self.config.get("video_duration", 5) video_duration = self.config.get("video_duration", 5)
audio_len = int(audio_array.shape[1] / audio_sr * target_fps) audio_len = int(audio_array.shape[1] / audio_sr * target_fps)
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_audio_len.observe(audio_len)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len) expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio # Segment audio
...@@ -447,6 +451,12 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -447,6 +451,12 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic") ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
return ref_img, latent_shape, target_shape return ref_img, latent_shape, target_shape
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanAudioRunner"],
)
def run_image_encoder(self, first_frame, last_frame=None): def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder() self.image_encoder = self.load_image_encoder()
...@@ -457,6 +467,12 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -457,6 +467,12 @@ class WanAudioRunner(WanRunner): # type:ignore
gc.collect() gc.collect()
return clip_encoder_out return clip_encoder_out
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanAudioRunner"],
)
def run_vae_encoder(self, img): def run_vae_encoder(self, img):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder() self.vae_encoder = self.load_vae_encoder()
......
...@@ -29,6 +29,7 @@ from lightx2v.utils.profiler import * ...@@ -29,6 +29,7 @@ from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import * from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size from lightx2v.utils.utils import best_output_size
from lightx2v.server.metrics import monitor_cli
@RUNNER_REGISTER("wan2.1") @RUNNER_REGISTER("wan2.1")
...@@ -206,11 +207,19 @@ class WanRunner(DefaultRunner): ...@@ -206,11 +207,19 @@ class WanRunner(DefaultRunner):
else: else:
self.scheduler = scheduler_class(self.config) self.scheduler = scheduler_class(self.config)
@ProfilingContext4DebugL1(
"Run Text Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_text_encode_duration,
metrics_labels=["WanRunner"],
)
def run_text_encoder(self, input_info): def run_text_encoder(self, input_info):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.text_encoders = self.load_text_encoder() self.text_encoders = self.load_text_encoder()
prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_prompt_len.observe(len(prompt))
neg_prompt = input_info.negative_prompt neg_prompt = input_info.negative_prompt
if self.config["cfg_parallel"]: if self.config["cfg_parallel"]:
...@@ -241,6 +250,12 @@ class WanRunner(DefaultRunner): ...@@ -241,6 +250,12 @@ class WanRunner(DefaultRunner):
return text_encoder_output return text_encoder_output
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanRunner"],
)
def run_image_encoder(self, first_frame, last_frame=None): def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder() self.image_encoder = self.load_image_encoder()
...@@ -254,6 +269,12 @@ class WanRunner(DefaultRunner): ...@@ -254,6 +269,12 @@ class WanRunner(DefaultRunner):
gc.collect() gc.collect()
return clip_encoder_out return clip_encoder_out
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanRunner"],
)
def run_vae_encoder(self, first_frame, last_frame=None): def run_vae_encoder(self, first_frame, last_frame=None):
h, w = first_frame.shape[2:] h, w = first_frame.shape[2:]
aspect_ratio = h / w aspect_ratio = h / w
...@@ -469,6 +490,12 @@ class Wan22DenseRunner(WanRunner): ...@@ -469,6 +490,12 @@ class Wan22DenseRunner(WanRunner):
self.vae_name = "Wan2.2_VAE.pth" self.vae_name = "Wan2.2_VAE.pth"
self.tiny_vae_name = "taew2_2.pth" self.tiny_vae_name = "taew2_2.pth"
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["Wan22DenseRunner"],
)
def run_vae_encoder(self, img): def run_vae_encoder(self, img):
max_area = self.config.target_height * self.config.target_width max_area = self.config.target_height * self.config.target_width
ih, iw = img.height, img.width ih, iw = img.height, img.width
......
...@@ -11,6 +11,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner ...@@ -11,6 +11,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.server.metrics import monitor_cli
@RUNNER_REGISTER("wan2.1_vace") @RUNNER_REGISTER("wan2.1_vace")
...@@ -88,6 +89,12 @@ class WanVaceRunner(WanRunner): ...@@ -88,6 +89,12 @@ class WanVaceRunner(WanRunner):
src_ref_images[i][j] = ref_img.to(device) src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images return src_video, src_mask, src_ref_images
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanVaceRunner"],
)
def run_vae_encoder(self, frames, ref_images, masks): def run_vae_encoder(self, frames, ref_images, masks):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_encoder = self.load_vae_encoder() self.vae_encoder = self.load_vae_encoder()
...@@ -159,7 +166,12 @@ class WanVaceRunner(WanRunner): ...@@ -159,7 +166,12 @@ class WanVaceRunner(WanRunner):
latent_shape[0] = int(latent_shape[0] / 2) latent_shape[0] = int(latent_shape[0] / 2)
return latent_shape return latent_shape
@ProfilingContext4DebugL1("Run VAE Decoder") @ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["WanVaceRunner"],
)
def run_vae_decoder(self, latents): def run_vae_decoder(self, latents):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False): if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_decoder = self.load_vae_decoder() self.vae_decoder = self.load_vae_decoder()
......
# -*-coding=utf-8-*-
# Package entry point for lightx2v.server.metrics: re-exports the
# metrics HTTP server bootstrap and exposes a process-wide metrics client.
from .metrics import server_process
from .monitor import Monitor

# Module-level singleton metrics client; Monitor enforces that the
# underlying metric registration happens exactly once per process.
monitor_cli = Monitor()
# -*-coding=utf-8-*-
from prometheus_client import start_http_server, Counter, Gauge, Histogram
from prometheus_client.metrics import MetricWrapperBase
import time
from loguru import logger
import threading
from pydantic import BaseModel
from typing import Optional, List, Dict, Any, Tuple
from functools import wraps
class MetricsConfig(BaseModel):
    """Declarative description of a single prometheus metric.

    Instances of this model are consumed by MetricsClient.init_metrics,
    which instantiates the matching prometheus_client metric type.
    """

    # Metric name as exported to prometheus; also used as the attribute
    # name set on MetricsClient.
    name: str
    # Human-readable help text shown on the /metrics endpoint.
    desc: str
    # One of "counter", "histogram" or "gauge" (anything else is logged
    # and skipped by MetricsClient.init_metrics).
    type_: str
    # Label names; an empty list means an unlabelled metric.
    labels: List[str] = []
    # Histogram buckets (seconds for the duration metrics); ignored for
    # counters and gauges.
    buckets: Tuple[float, ...] = (
        0.1,
        0.5,
        1.0,
        2.5,
        5.0,
        10.0,
        30.0,
        60.0,
        120.0,
        300.0,
        600.0,
    )
# Registry of every metric the workers export. MetricsClient.init_metrics
# iterates this dict and attaches one prometheus metric per entry, using
# each config's ``name`` as the attribute name. Keys mirror the names so
# lookups and the registered metrics stay in sync.
METRICS_INFO = {
    "lightx2v_worker_request_count": MetricsConfig(
        name="lightx2v_worker_request_count",
        desc="The total number of requests",
        type_="counter",
    ),
    "lightx2v_worker_request_success": MetricsConfig(
        name="lightx2v_worker_request_success",
        desc="The number of successful requests",
        type_="counter",
    ),
    "lightx2v_worker_request_failure": MetricsConfig(
        name="lightx2v_worker_request_failure",
        desc="The number of failed requests",
        type_="counter",
        labels=["error_type"],
    ),
    "lightx2v_worker_request_duration": MetricsConfig(
        name="lightx2v_worker_request_duration",
        desc="Duration of the request (s)",
        type_="histogram",
    ),
    "lightx2v_input_audio_len": MetricsConfig(
        name="lightx2v_input_audio_len",
        desc="Length of the input audio",
        type_="histogram",
        # Audio lengths need coarser, wider buckets than the default
        # sub-second-to-10-minute duration buckets.
        buckets=(
            1.0,
            2.0,
            3.0,
            5.0,
            7.0,
            10.0,
            20.0,
            30.0,
            45.0,
            60.0,
            75.0,
            90.0,
            105.0,
            120.0,
        ),
    ),
    "lightx2v_input_image_len": MetricsConfig(
        name="lightx2v_input_image_len",
        desc="Length of the input image",
        type_="histogram",
    ),
    "lightx2v_input_prompt_len": MetricsConfig(
        name="lightx2v_input_prompt_len",
        desc="Length of the input prompt",
        type_="histogram",
    ),
    "lightx2v_load_model_duration": MetricsConfig(
        name="lightx2v_load_model_duration",
        desc="Duration of load model (s)",
        type_="histogram",
    ),
    # NOTE(review): "pre_step" looks like a typo for "per_step", but callers
    # already reference monitor_cli.lightx2v_run_pre_step_dit_duration, so
    # the exported name is kept as-is; the dict key (previously
    # "lightx2v_run_per_step_dit_duration") is aligned to the name to avoid
    # a silent key/name mismatch.
    "lightx2v_run_pre_step_dit_duration": MetricsConfig(
        name="lightx2v_run_pre_step_dit_duration",
        desc="Duration of run per step Dit (s)",
        type_="histogram",
        labels=["step_no", "total_steps"],
    ),
    "lightx2v_run_text_encode_duration": MetricsConfig(
        name="lightx2v_run_text_encode_duration",
        desc="Duration of run text encode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
    "lightx2v_run_img_encode_duration": MetricsConfig(
        name="lightx2v_run_img_encode_duration",
        desc="Duration of run img encode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
    "lightx2v_run_vae_encode_duration": MetricsConfig(
        name="lightx2v_run_vae_encode_duration",
        desc="Duration of run vae encode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
    "lightx2v_run_vae_decode_duration": MetricsConfig(
        name="lightx2v_run_vae_decode_duration",
        desc="Duration of run vae decode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
}
class MetricsClient:
    """Instantiates every metric declared in METRICS_INFO.

    Each prometheus metric is attached to the instance as an attribute
    named after its config's ``name`` field, e.g.
    ``client.lightx2v_worker_request_count``.
    """

    def __init__(self):
        self.init_metrics()

    def init_metrics(self):
        """Create and attach one metric per METRICS_INFO entry."""
        for key, cfg in METRICS_INFO.items():
            kind = cfg.type_
            if kind == "counter":
                self.register_counter(cfg.name, cfg.desc, cfg.labels)
            elif kind == "histogram":
                self.register_histogram(cfg.name, cfg.desc, cfg.labels, buckets=cfg.buckets)
            elif kind == "gauge":
                self.register_gauge(cfg.name, cfg.desc, cfg.labels)
            else:
                # Unknown types are skipped rather than raising, so a bad
                # config entry cannot take the worker down.
                logger.warning(f"Unsupported metric type: {kind} for {key}")

    def register_counter(self, name, desc, labels):
        """Attach a Counter metric as ``self.<name>``."""
        setattr(self, name, Counter(name, desc, labels))

    def register_histogram(self, name, desc, labels, buckets=None):
        """Attach a Histogram metric as ``self.<name>``.

        Falls back to duration-style buckets (0.1s .. 600s) when no
        (or an empty) bucket tuple is supplied.
        """
        buckets = buckets or (
            0.1,
            0.5,
            1.0,
            2.5,
            5.0,
            10.0,
            30.0,
            60.0,
            120.0,
            300.0,
            600.0,
        )
        setattr(self, name, Histogram(name, desc, labels, buckets=buckets))

    def register_gauge(self, name, desc, labels):
        """Attach a Gauge metric as ``self.<name>``."""
        setattr(self, name, Gauge(name, desc, labels))
class MetricsServer:
    """Serves prometheus metrics over HTTP from a background daemon thread."""

    def __init__(self, port=8000):
        # Port the exporter listens on.
        self.port = port
        # Handle to the background thread once start_server() has run.
        self.server_thread = None

    def start_server(self):
        """Launch the HTTP exporter without blocking the caller.

        The thread is a daemon so it never prevents process shutdown.
        """

        def _serve():
            start_http_server(self.port)
            logger.info(f"Metrics server started on port {self.port}")

        thread = threading.Thread(target=_serve)
        thread.daemon = True
        self.server_thread = thread
        thread.start()
def server_process(metric_port=8001):
    """Start the metrics HTTP exporter for this process on *metric_port*.

    Convenience wrapper used by worker entry points; returns None and
    leaves the server running on a background daemon thread.
    """
    MetricsServer(port=metric_port).start_server()
# -*-coding=utf-8-*-
import threading
from .metrics import MetricsClient
class Monitor(MetricsClient):
    """Process-wide singleton over MetricsClient.

    Every instantiation returns the same object, and the metric
    registration performed by MetricsClient.__init__ runs exactly once
    no matter how many times Monitor() is called.
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False  # guards against re-running MetricsClient.__init__

    def __new__(cls, *args, **kwargs):
        # Double-checked locking: the cheap unlocked read handles the
        # common path; the lock only matters for the first construction.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, *args, **kwargs):
        if self.__class__._initialized:
            return
        super().__init__(*args, **kwargs)
        self.__class__._initialized = True
...@@ -41,3 +41,10 @@ def GET_SENSITIVE_DTYPE(): ...@@ -41,3 +41,10 @@ def GET_SENSITIVE_DTYPE():
if RUNNING_FLAG == "None": if RUNNING_FLAG == "None":
return GET_DTYPE() return GET_DTYPE()
return DTYPE_MAP[RUNNING_FLAG] return DTYPE_MAP[RUNNING_FLAG]
@lru_cache(maxsize=None)
def GET_RECORDER_MODE():
    """Return the RECORDER_MODE environment variable as an int (default 0).

    The value is read once and cached for the process lifetime; 0 disables
    metric recording, non-zero enables it (see the recorder_mode handling
    in the profiling context).
    """
    return int(os.getenv("RECORDER_MODE", "0"))
...@@ -10,12 +10,21 @@ from lightx2v.utils.envs import * ...@@ -10,12 +10,21 @@ from lightx2v.utils.envs import *
class _ProfilingContext: class _ProfilingContext:
def __init__(self, name): def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
"""
recorder_mode = 0: disable recorder
recorder_mode = 1: enable recorder
recorder_mode = 2: enable recorder and force disable logger
"""
self.name = name self.name = name
if dist.is_initialized(): if dist.is_initialized():
self.rank_info = f"Rank {dist.get_rank()}" self.rank_info = f"Rank {dist.get_rank()}"
else: else:
self.rank_info = "Single GPU" self.rank_info = "Single GPU"
self.enable_recorder = recorder_mode > 0
self.enable_logger = recorder_mode <= 1
self.metrics_func = metrics_func
self.metrics_labels = metrics_labels
def __enter__(self): def __enter__(self):
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -25,7 +34,13 @@ class _ProfilingContext: ...@@ -25,7 +34,13 @@ class _ProfilingContext:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize() torch.cuda.synchronize()
elapsed = time.perf_counter() - self.start_time elapsed = time.perf_counter() - self.start_time
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds") if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
metrics_func.labels(self.metrics_labels).observe(elapsed)
else:
metrics_func.observe(elapsed)
if self.enable_logger:
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
return False return False
async def __aenter__(self): async def __aenter__(self):
...@@ -36,7 +51,13 @@ class _ProfilingContext: ...@@ -36,7 +51,13 @@ class _ProfilingContext:
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize() torch.cuda.synchronize()
elapsed = time.perf_counter() - self.start_time elapsed = time.perf_counter() - self.start_time
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds") if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
metrics_func.labels(self.metrics_labels).observe(elapsed)
else:
metrics_func.observe(elapsed)
if self.enable_logger:
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
return False return False
def __call__(self, func): def __call__(self, func):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment