Unverified Commit 891f3bf1 authored by yihuiwen's avatar yihuiwen Committed by GitHub
Browse files

add worker metrics (#343)


Co-authored-by: default avataryihuiwen <yihuiwen@sensetime.com>
parent 5ffdbeb6
......@@ -15,6 +15,7 @@ from loguru import logger
from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager
from lightx2v.deploy.task_manager import TaskStatus
from lightx2v.deploy.worker.hub import DiTWorker, ImageEncoderWorker, PipelineWorker, SegmentDiTWorker, TextEncoderWorker, VaeDecoderWorker, VaeEncoderWorker
from lightx2v.server.metrics import metrics
RUNNER_MAP = {
"pipeline": PipelineWorker,
......@@ -205,6 +206,8 @@ async def main(args):
args.task_name = args.task
worker_keys = [args.task_name, args.model_name, args.stage, args.worker]
metrics.server_process(args.metric_port)
data_manager = None
if args.data_url.startswith("/"):
data_manager = LocalDataManager(args.data_url, None)
......@@ -329,6 +332,8 @@ if __name__ == "__main__":
parser.add_argument("--timeout", type=int, default=300)
parser.add_argument("--ping_interval", type=int, default=10)
parser.add_argument("--metric_port", type=int, default=8001)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
......
......@@ -8,6 +8,7 @@ from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.memory_profiler import peak_memory_decorator
......@@ -167,6 +168,9 @@ class DefaultRunner(BaseRunner):
img_ori = img_path
else:
img_ori = Image.open(img_path).convert("RGB")
if GET_RECORDER_MODE():
width, height = img_ori.size
monitor_cli.lightx2v_input_image_len.observe(width*height)
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
self.input_info.original_size = img_ori.size
return img, img_ori
......@@ -252,7 +256,10 @@ class DefaultRunner(BaseRunner):
self.model.select_graph_for_compile(self.input_info)
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}", \
recorder_mode=GET_RECORDER_MODE(), \
metrics_func=monitor_cli.lightx2v_run_pre_step_dit_duration, \
metrics_labels=[segment_idx+1, self.video_segment_num]):
self.check_stop()
# 1. default do nothing
self.init_run_segment(segment_idx)
......@@ -266,7 +273,12 @@ class DefaultRunner(BaseRunner):
self.end_run()
return {"video": gen_video_final}
@ProfilingContext4DebugL1("Run VAE Decoder")
@ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=labels=["DefaultRunner"]
)
def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder()
......@@ -321,7 +333,15 @@ class DefaultRunner(BaseRunner):
logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
return {"video": None}
@ProfilingContext4DebugL1(
"RUN pipeline",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_worker_request_duration,
metrics_labels=["DefaultRunner"]
)
def run_pipeline(self, input_info):
if GET_RECORDER_MODE():
monitor_cli.lightx2v_worker_request_count.inc()
self.input_info = input_info
if self.config["use_prompt_enhancer"]:
......@@ -331,4 +351,6 @@ class DefaultRunner(BaseRunner):
gen_video_final = self.run_main()
if GET_RECORDER_MODE():
monitor_cli.lightx2v_worker_request_success.inc()
return gen_video_final
......@@ -10,8 +10,10 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.server.metrics import monitor_cli
def calculate_dimensions(target_area, ratio):
......@@ -106,7 +108,15 @@ class QwenImageRunner(DefaultRunner):
"image_encoder_output": image_encoder_output,
}
@ProfilingContext4DebugL1(
"Run Text Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_text_encode_duration,
metrics_labels=["QwenImageRunner"]
)
def run_text_encoder(self, text, image=None):
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_prompt_len.observe(len(text))
text_encoder_output = {}
if self.config["task"] == "t2i":
prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text])
......@@ -120,6 +130,12 @@ class QwenImageRunner(DefaultRunner):
text_encoder_output["image_info"] = image_info
return text_encoder_output
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["QwenImageRunner"]
)
def run_vae_encoder(self, image):
image_latents = self.vae.encode_vae_image(image)
return {"image_latents": image_latents}
......@@ -183,7 +199,12 @@ class QwenImageRunner(DefaultRunner):
self.vae = self.load_vae()
self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
@ProfilingContext4DebugL1("Run VAE Decoder")
@ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["QwenImageRunner"],
)
def _run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae()
......
......@@ -22,6 +22,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
from lightx2v.server.metrics import monitor_cli
@RUNNER_REGISTER("wan2.2_animate")
......@@ -150,6 +151,12 @@ class WanAnimateRunner(WanRunner):
)
return {"image_encoder_output": {"clip_encoder_out": clip_encoder_out, "vae_encoder_out": vae_encoder_out, "pose_latents": pose_latents, "face_pixel_values": face_pixel_values}}
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanAnimateRunner"],
)
def run_vae_encoder(
self,
conditioning_pixel_values,
......@@ -269,7 +276,12 @@ class WanAnimateRunner(WanRunner):
self.prepare_input()
super().init_run()
@ProfilingContext4DebugL1("Run VAE Decoder")
@ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["WanAnimateRunner"],
)
def run_vae_decoder(self, latents):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_decoder = self.load_vae_decoder()
......@@ -351,6 +363,12 @@ class WanAnimateRunner(WanRunner):
gc.collect()
super().process_images_after_vae_decoder()
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanAnimateRunner"],
)
def run_image_encoder(self, img): # CHW
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.image_encoder = self.load_image_encoder()
......
......@@ -27,6 +27,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
......@@ -359,6 +360,9 @@ class WanAudioRunner(WanRunner): # type:ignore
video_duration = self.config.get("video_duration", 5)
audio_len = int(audio_array.shape[1] / audio_sr * target_fps)
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_audio_len.observe(audio_len)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio
......@@ -447,6 +451,12 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
return ref_img, latent_shape, target_shape
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanAudioRunner"],
)
def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
......@@ -457,6 +467,12 @@ class WanAudioRunner(WanRunner): # type:ignore
gc.collect()
return clip_encoder_out
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanAudioRunner"],
)
def run_vae_encoder(self, img):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
......
......@@ -29,6 +29,7 @@ from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size
from lightx2v.server.metrics import monitor_cli
@RUNNER_REGISTER("wan2.1")
......@@ -206,11 +207,19 @@ class WanRunner(DefaultRunner):
else:
self.scheduler = scheduler_class(self.config)
@ProfilingContext4DebugL1(
"Run Text Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_text_encode_duration,
metrics_labels=["WanRunner"],
)
def run_text_encoder(self, input_info):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.text_encoders = self.load_text_encoder()
prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_prompt_len.observe(len(prompt))
neg_prompt = input_info.negative_prompt
if self.config["cfg_parallel"]:
......@@ -241,6 +250,12 @@ class WanRunner(DefaultRunner):
return text_encoder_output
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanRunner"],
)
def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
......@@ -254,6 +269,12 @@ class WanRunner(DefaultRunner):
gc.collect()
return clip_encoder_out
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanRunner"],
)
def run_vae_encoder(self, first_frame, last_frame=None):
h, w = first_frame.shape[2:]
aspect_ratio = h / w
......@@ -469,6 +490,12 @@ class Wan22DenseRunner(WanRunner):
self.vae_name = "Wan2.2_VAE.pth"
self.tiny_vae_name = "taew2_2.pth"
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["Wan22DenseRunner"],
)
def run_vae_encoder(self, img):
max_area = self.config.target_height * self.config.target_width
ih, iw = img.height, img.width
......
......@@ -11,6 +11,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.server.metrics import monitor_cli
@RUNNER_REGISTER("wan2.1_vace")
......@@ -88,6 +89,12 @@ class WanVaceRunner(WanRunner):
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
metrics_labels=["WanVaceRunner"],
)
def run_vae_encoder(self, frames, ref_images, masks):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_encoder = self.load_vae_encoder()
......@@ -159,7 +166,12 @@ class WanVaceRunner(WanRunner):
latent_shape[0] = int(latent_shape[0] / 2)
return latent_shape
@ProfilingContext4DebugL1("Run VAE Decoder")
@ProfilingContext4DebugL1(
"Run VAE Decoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["WanVaceRunner"],
)
def run_vae_decoder(self, latents):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_decoder = self.load_vae_decoder()
......
# -*-coding=utf-8-*-
from .metrics import server_process
from .monitor import Monitor

# Process-wide metrics client. Monitor is a singleton (see monitor.py), so every
# importer of `monitor_cli` shares the same registered Prometheus metric objects.
monitor_cli = Monitor()
# -*-coding=utf-8-*-
from prometheus_client import start_http_server, Counter, Gauge, Histogram
from prometheus_client.metrics import MetricWrapperBase
import time
from loguru import logger
import threading
from pydantic import BaseModel
from typing import Optional, List, Dict, Any, Tuple
from functools import wraps
class MetricsConfig(BaseModel):
    """Declarative definition of a single Prometheus metric.

    Consumed by MetricsClient.init_metrics(), which instantiates a Counter,
    Histogram, or Gauge from each config and exposes it as an attribute
    named after ``name``.
    """

    # Prometheus metric name; also becomes the attribute name on MetricsClient.
    name: str
    # Help text passed to the Prometheus client (shown on the /metrics endpoint).
    desc: str
    # Metric kind: "counter", "histogram", or "gauge"; any other value is
    # warned about and skipped by MetricsClient.init_metrics().
    type_: str
    # Label names (dimensions) for this metric; empty list means unlabeled.
    labels: List[str] = []
    # Histogram bucket upper bounds (seconds); only used when type_ == "histogram".
    buckets: Tuple[float, ...] = (
        0.1,
        0.5,
        1.0,
        2.5,
        5.0,
        10.0,
        30.0,
        60.0,
        120.0,
        300.0,
        600.0,
    )
# Registry of every metric this package exposes. Keys are documentation /
# lookup handles and MUST equal each config's ``name``, because MetricsClient
# sets the metric object as an attribute named after ``config.name`` — a
# mismatched key would suggest an attribute that does not exist.
#
# Fix: the per-step DiT entry used key "lightx2v_run_per_step_dit_duration"
# while its name (and therefore the attribute callers use) is
# "lightx2v_run_pre_step_dit_duration". The key is aligned to the name; the
# name itself is kept unchanged ("pre" looks like a typo for "per", but
# renaming it would break existing callers and recorded time series).
METRICS_INFO = {
    "lightx2v_worker_request_count": MetricsConfig(
        name="lightx2v_worker_request_count",
        desc="The total number of requests",
        type_="counter",
    ),
    "lightx2v_worker_request_success": MetricsConfig(
        name="lightx2v_worker_request_success",
        desc="The number of successful requests",
        type_="counter",
    ),
    "lightx2v_worker_request_failure": MetricsConfig(
        name="lightx2v_worker_request_failure",
        desc="The number of failed requests",
        type_="counter",
        labels=["error_type"],
    ),
    "lightx2v_worker_request_duration": MetricsConfig(
        name="lightx2v_worker_request_duration",
        desc="Duration of the request (s)",
        type_="histogram",
    ),
    "lightx2v_input_audio_len": MetricsConfig(
        name="lightx2v_input_audio_len",
        desc="Length of the input audio",
        type_="histogram",
        # Audio length buckets (frames/seconds per caller), wider than the
        # default latency buckets.
        buckets=(
            1.0,
            2.0,
            3.0,
            5.0,
            7.0,
            10.0,
            20.0,
            30.0,
            45.0,
            60.0,
            75.0,
            90.0,
            105.0,
            120.0,
        ),
    ),
    "lightx2v_input_image_len": MetricsConfig(
        name="lightx2v_input_image_len",
        desc="Length of the input image",
        type_="histogram",
    ),
    "lightx2v_input_prompt_len": MetricsConfig(
        name="lightx2v_input_prompt_len",
        desc="Length of the input prompt",
        type_="histogram",
    ),
    "lightx2v_load_model_duration": MetricsConfig(
        name="lightx2v_load_model_duration",
        desc="Duration of load model (s)",
        type_="histogram",
    ),
    "lightx2v_run_pre_step_dit_duration": MetricsConfig(
        name="lightx2v_run_pre_step_dit_duration",
        desc="Duration of run per step Dit (s)",
        type_="histogram",
        labels=["step_no", "total_steps"],
    ),
    "lightx2v_run_text_encode_duration": MetricsConfig(
        name="lightx2v_run_text_encode_duration",
        desc="Duration of run text encode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
    "lightx2v_run_img_encode_duration": MetricsConfig(
        name="lightx2v_run_img_encode_duration",
        desc="Duration of run img encode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
    "lightx2v_run_vae_encode_duration": MetricsConfig(
        name="lightx2v_run_vae_encode_duration",
        desc="Duration of run vae encode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
    "lightx2v_run_vae_decode_duration": MetricsConfig(
        name="lightx2v_run_vae_decode_duration",
        desc="Duration of run vae decode (s)",
        type_="histogram",
        labels=["model_cls"],
    ),
}
class MetricsClient:
    """Instantiates every metric declared in METRICS_INFO.

    Each metric object (Counter / Histogram / Gauge) is attached to the
    instance as an attribute named after the metric's ``name``.
    """

    def __init__(self):
        self.init_metrics()

    def init_metrics(self):
        """Create one Prometheus metric per METRICS_INFO entry; warn on unknown types."""
        for metric_name, cfg in METRICS_INFO.items():
            kind = cfg.type_
            if kind == "counter":
                self.register_counter(cfg.name, cfg.desc, cfg.labels)
            elif kind == "histogram":
                self.register_histogram(cfg.name, cfg.desc, cfg.labels, buckets=cfg.buckets)
            elif kind == "gauge":
                self.register_gauge(cfg.name, cfg.desc, cfg.labels)
            else:
                logger.warning(f"Unsupported metric type: {kind} for {metric_name}")

    def register_counter(self, name, desc, labels):
        """Attach a Counter as attribute ``name``."""
        setattr(self, name, Counter(name, desc, labels))

    def register_histogram(self, name, desc, labels, buckets=None):
        """Attach a Histogram as attribute ``name``; fall back to default latency buckets."""
        if not buckets:
            buckets = (
                0.1,
                0.5,
                1.0,
                2.5,
                5.0,
                10.0,
                30.0,
                60.0,
                120.0,
                300.0,
                600.0,
            )
        setattr(self, name, Histogram(name, desc, labels, buckets=buckets))

    def register_gauge(self, name, desc, labels):
        """Attach a Gauge as attribute ``name``."""
        setattr(self, name, Gauge(name, desc, labels))
class MetricsServer:
    """Runs the Prometheus HTTP exporter on a background daemon thread."""

    def __init__(self, port=8000):
        # Port for the /metrics HTTP endpoint.
        self.port = port
        # Set once start_server() is called; daemon thread, dies with the process.
        self.server_thread = None

    def start_server(self):
        """Start the exporter thread (non-blocking)."""

        def _serve():
            start_http_server(self.port)
            logger.info(f"Metrics server started on port {self.port}")

        thread = threading.Thread(target=_serve, daemon=True)
        self.server_thread = thread
        thread.start()
def server_process(metric_port=8001):
    """Start a Prometheus metrics HTTP server on *metric_port* (non-blocking)."""
    MetricsServer(port=metric_port).start_server()
# -*-coding=utf-8-*-
import threading
from .metrics import MetricsClient
class Monitor(MetricsClient):
    """Process-wide singleton MetricsClient.

    Uses double-checked locking in __new__ so only one instance ever exists,
    and an init-once flag so repeated ``Monitor()`` calls do not re-register
    metrics (Prometheus raises on duplicate registration).
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False  # init-once flag: True after the first __init__ completes

    def __new__(cls, *args, **kwargs):
        # Double-checked locking: cheap unlocked check, then re-check under the lock.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, *args, **kwargs):
        cls = self.__class__
        if cls._initialized:
            # Already initialized on a previous construction — do nothing.
            return
        super().__init__(*args, **kwargs)
        cls._initialized = True
......@@ -41,3 +41,10 @@ def GET_SENSITIVE_DTYPE():
if RUNNING_FLAG == "None":
return GET_DTYPE()
return DTYPE_MAP[RUNNING_FLAG]
@lru_cache(maxsize=None)
def GET_RECORDER_MODE():
    """Return the RECORDER_MODE env var as an int (0 = disabled).

    Cached after the first read, so changing the environment variable later
    in the process has no effect unless the cache is cleared.
    """
    return int(os.getenv("RECORDER_MODE", "0"))
......@@ -10,12 +10,21 @@ from lightx2v.utils.envs import *
class _ProfilingContext:
def __init__(self, name):
    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        """
        recorder_mode = 0: disable recorder
        recorder_mode = 1: enable recorder
        recorder_mode = 2: enable recorder and force disable logger

        Args:
            name: human-readable label used in the profiling log line.
            recorder_mode: see table above; > 0 enables metrics recording,
                <= 1 keeps the logger output enabled.
            metrics_func: optional metric object to record elapsed time on
                (observed on context exit; presumably a Prometheus
                Histogram — confirm against lightx2v.server.metrics).
            metrics_labels: optional label values applied to metrics_func
                before observing.
        """
        self.name = name
        # Rank tag for the log prefix: distributed rank when torch.distributed
        # is initialized, otherwise a single-GPU marker.
        if dist.is_initialized():
            self.rank_info = f"Rank {dist.get_rank()}"
        else:
            self.rank_info = "Single GPU"
        self.enable_recorder = recorder_mode > 0
        self.enable_logger = recorder_mode <= 1
        self.metrics_func = metrics_func
        self.metrics_labels = metrics_labels
def __enter__(self):
torch.cuda.synchronize()
......@@ -25,6 +34,12 @@ class _ProfilingContext:
def __exit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
elapsed = time.perf_counter() - self.start_time
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
metrics_func.labels(self.metrics_labels).observe(elapsed)
else:
metrics_func.observe(elapsed)
if self.enable_logger:
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
return False
......@@ -36,6 +51,12 @@ class _ProfilingContext:
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
elapsed = time.perf_counter() - self.start_time
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
metrics_func.labels(self.metrics_labels).observe(elapsed)
else:
metrics_func.observe(elapsed)
if self.enable_logger:
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment