Unverified Commit f600866a authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve the metrics for PD (#12580)


Co-authored-by: default avatarKan Wu <wukanustc@gmail.com>
Co-authored-by: default avatarcctry <shiyang@x.ai>
parent 93be7e86
...@@ -234,6 +234,7 @@ class ModelConfig: ...@@ -234,6 +234,7 @@ class ModelConfig:
model_impl=server_args.model_impl, model_impl=server_args.model_impl,
sampling_defaults=server_args.sampling_defaults, sampling_defaults=server_args.sampling_defaults,
quantize_and_serve=server_args.quantize_and_serve, quantize_and_serve=server_args.quantize_and_serve,
override_config_file=server_args.decrypted_config_file,
**kwargs, **kwargs,
) )
......
...@@ -439,13 +439,9 @@ class Scheduler( ...@@ -439,13 +439,9 @@ class Scheduler(
self.forward_ct_decode = 0 self.forward_ct_decode = 0
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.last_prefill_tokens = 0 self.last_prefill_tokens = 0
self.last_decode_stats_tic = time.perf_counter()
self.last_prefill_stats_tic = time.perf_counter()
self.return_health_check_ct = 0 self.return_health_check_ct = 0
self.num_retracted_reqs: int = 0 self.num_retracted_reqs: int = 0
self.num_paused_reqs: int = 0 self.num_paused_reqs: int = 0
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {} self.sessions: Dict[str, Session] = {}
self.default_stream: CudaStream = torch.get_device_module( self.default_stream: CudaStream = torch.get_device_module(
self.device self.device
...@@ -1426,7 +1422,8 @@ class Scheduler( ...@@ -1426,7 +1422,8 @@ class Scheduler(
def _add_request_to_queue(self, req: Req, is_retracted: bool = False): def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.NULL: if self.disaggregation_mode == DisaggregationMode.NULL:
self._set_or_validate_priority(req) if not self._set_or_validate_priority(req):
return
if self._abort_on_queued_limit(req): if self._abort_on_queued_limit(req):
return return
self._prefetch_kvcache(req) self._prefetch_kvcache(req)
...@@ -1446,7 +1443,7 @@ class Scheduler( ...@@ -1446,7 +1443,7 @@ class Scheduler(
else: else:
raise ValueError(f"Invalid {self.disaggregation_mode=}") raise ValueError(f"Invalid {self.disaggregation_mode=}")
def _set_or_validate_priority(self, req: Req): def _set_or_validate_priority(self, req: Req) -> bool:
"""Set the default priority value, or abort the request based on the priority scheduling mode.""" """Set the default priority value, or abort the request based on the priority scheduling mode."""
if self.enable_priority_scheduling and req.priority is None: if self.enable_priority_scheduling and req.priority is None:
if self.schedule_low_priority_values_first: if self.schedule_low_priority_values_first:
...@@ -1467,6 +1464,8 @@ class Scheduler( ...@@ -1467,6 +1464,8 @@ class Scheduler(
rid=req.rid, rid=req.rid,
) )
self.send_to_tokenizer.send_output(abort_req, req) self.send_to_tokenizer.send_output(abort_req, req)
return False
return True
def _abort_on_queued_limit(self, recv_req: Req) -> bool: def _abort_on_queued_limit(self, recv_req: Req) -> bool:
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted.""" """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
......
...@@ -38,6 +38,9 @@ class SchedulerMetricsMixin: ...@@ -38,6 +38,9 @@ class SchedulerMetricsMixin:
def init_metrics( def init_metrics(
self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int] self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
): ):
self.last_decode_stats_tic = time.perf_counter()
self.last_prefill_stats_tic = time.perf_counter()
self.last_gen_throughput: float = 0.0 self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0 self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
...@@ -50,6 +53,8 @@ class SchedulerMetricsMixin: ...@@ -50,6 +53,8 @@ class SchedulerMetricsMixin:
self.spec_total_num_forward_ct = 0 self.spec_total_num_forward_ct = 0
self.kv_transfer_speed_gb_s: float = 0.0 self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0 self.kv_transfer_latency_ms: float = 0.0
self.kv_transfer_bootstrap_ms: float = 0.0
self.kv_transfer_alloc_ms: float = 0.0
self.stats = SchedulerStats() self.stats = SchedulerStats()
...@@ -178,6 +183,8 @@ class SchedulerMetricsMixin: ...@@ -178,6 +183,8 @@ class SchedulerMetricsMixin:
) )
self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
self.stats.kv_transfer_bootstrap_ms = self.kv_transfer_bootstrap_ms
self.stats.kv_transfer_alloc_ms = self.kv_transfer_alloc_ms
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.stats.num_decode_prealloc_queue_reqs = len( self.stats.num_decode_prealloc_queue_reqs = len(
self.disagg_decode_prealloc_queue.queue self.disagg_decode_prealloc_queue.queue
......
...@@ -43,6 +43,9 @@ class TimeStats: ...@@ -43,6 +43,9 @@ class TimeStats:
prefill_transfer_queue_entry_time: float = 0.0 prefill_transfer_queue_entry_time: float = 0.0
decode_prealloc_queue_entry_time: float = 0.0 decode_prealloc_queue_entry_time: float = 0.0
decode_transfer_queue_entry_time: float = 0.0 decode_transfer_queue_entry_time: float = 0.0
# TODO: set these correctly
bootstrap_duration: float = 0.0
alloc_waiting_duration: float = 0.0
def get_queueing_time(self) -> float: def get_queueing_time(self) -> float:
return self.forward_entry_time - self.wait_queue_entry_time return self.forward_entry_time - self.wait_queue_entry_time
...@@ -73,7 +76,20 @@ class TimeStats: ...@@ -73,7 +76,20 @@ class TimeStats:
and forward_duration >= 0 and forward_duration >= 0
), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}" other = max(
0.0,
bootstrap_duration
- (self.alloc_waiting_duration + self.bootstrap_duration),
)
return (
f"bootstrap_queue_duration({self.format_duration(bootstrap_duration)}) "
f"= alloc_wait({self.format_duration(self.alloc_waiting_duration)}) "
f"+ bootstrap({self.format_duration(self.bootstrap_duration)}) "
f"+ other({self.format_duration(other)}); "
f"queue_duration={self.format_duration(queue_duration)}, "
f"forward_duration={self.format_duration(forward_duration)}, "
f"start={self.prefill_bootstrap_queue_entry_time:.3f}"
)
elif self.disagg_mode == DisaggregationMode.DECODE: elif self.disagg_mode == DisaggregationMode.DECODE:
prealloc_duration = ( prealloc_duration = (
self.decode_transfer_queue_entry_time self.decode_transfer_queue_entry_time
...@@ -94,7 +110,21 @@ class TimeStats: ...@@ -94,7 +110,21 @@ class TimeStats:
and forward_duration >= 0 and forward_duration >= 0
), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}" ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}"
return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}" other = max(
0.0,
prealloc_duration
- (self.alloc_waiting_duration + self.bootstrap_duration),
)
return (
f"prealloc_queue_duration({self.format_duration(prealloc_duration)}) "
f"= alloc_wait({self.format_duration(self.alloc_waiting_duration)}) "
f"+ bootstrap({self.format_duration(self.bootstrap_duration)}) "
f"+ other({self.format_duration(other)}); "
f"transfer_duration={self.format_duration(transfer_duration)}; "
f"queue_duration={self.format_duration(queue_duration)}, "
f"forward_duration={self.format_duration(forward_duration)}, "
f"start={self.decode_prealloc_queue_entry_time:.3f}"
)
else: else:
return "Unknown Time Stats" return "Unknown Time Stats"
...@@ -141,6 +171,8 @@ class SchedulerStats: ...@@ -141,6 +171,8 @@ class SchedulerStats:
num_decode_transfer_queue_reqs: int = 0 num_decode_transfer_queue_reqs: int = 0
kv_transfer_speed_gb_s: float = 0.0 kv_transfer_speed_gb_s: float = 0.0
kv_transfer_latency_ms: float = 0.0 kv_transfer_latency_ms: float = 0.0
kv_transfer_bootstrap_ms: float = 0.0
kv_transfer_alloc_ms: float = 0.0
# Utilization # Utilization
utilization: float = 0.0 utilization: float = 0.0
...@@ -297,6 +329,18 @@ class SchedulerMetricsCollector: ...@@ -297,6 +329,18 @@ class SchedulerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
) )
self.kv_transfer_bootstrap_ms = Gauge(
name="sglang:kv_transfer_bootstrap_ms",
documentation="The bootstrap time of the KV transfer in ms.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.kv_transfer_alloc_ms = Gauge(
name="sglang:kv_transfer_alloc_ms",
documentation="The allocation waiting time of the KV transfer in ms.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
# Utilization # Utilization
self.utilization = Gauge( self.utilization = Gauge(
...@@ -564,6 +608,8 @@ class SchedulerMetricsCollector: ...@@ -564,6 +608,8 @@ class SchedulerMetricsCollector:
) )
self._log_gauge(self.kv_transfer_speed_gb_s, stats.kv_transfer_speed_gb_s) self._log_gauge(self.kv_transfer_speed_gb_s, stats.kv_transfer_speed_gb_s)
self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms) self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
self._log_gauge(self.kv_transfer_bootstrap_ms, stats.kv_transfer_bootstrap_ms)
self._log_gauge(self.kv_transfer_alloc_ms, stats.kv_transfer_alloc_ms)
# Retract # Retract
self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs) self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
......
...@@ -552,6 +552,10 @@ class ServerArgs: ...@@ -552,6 +552,10 @@ class ServerArgs:
mm_max_concurrent_calls: int = 32 mm_max_concurrent_calls: int = 32
mm_per_request_timeout: float = 10.0 mm_per_request_timeout: float = 10.0
# For checkpoint decryption
decrypted_config_file: Optional[str] = None
decrypted_draft_config_file: Optional[str] = None
def __post_init__(self): def __post_init__(self):
""" """
Orchestrates the handling of various server arguments, ensuring proper configuration and validation. Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
...@@ -3551,6 +3555,20 @@ class ServerArgs: ...@@ -3551,6 +3555,20 @@ class ServerArgs:
help="The timeout for each multi-modal request in seconds.", help="The timeout for each multi-modal request in seconds.",
) )
# For checkpoint decryption
parser.add_argument(
"--decrypted-config-file",
type=str,
default=ServerArgs.decrypted_config_file,
help="The path of the decrypted config file.",
)
parser.add_argument(
"--decrypted-draft-config-file",
type=str,
default=ServerArgs.decrypted_draft_config_file,
help="The path of the decrypted draft config file.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment