Unverified commit 58095cb0 authored by yinghui, committed by GitHub

Add timing metrics for requests (#12646)


Co-authored-by: Scott Lee <scottjlee@users.noreply.github.com>
parent fd3034da
@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
import time
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -84,10 +85,14 @@ class OpenAIServingBase(ABC):
async def handle_request(
self, request: OpenAIServingRequest, raw_request: Request
) -> Union[Any, StreamingResponse, ErrorResponse]:
"""Handle the specific request type with common pattern""" """Handle the specific request type with common pattern
If you want to override this method, you should be careful to record the validation time.
"""
try:
# Validate request
validation_start = time.perf_counter()
error_msg = self._validate_request(request)
validation_time = time.perf_counter() - validation_start
if error_msg:
return self.create_error_response(error_msg)
@@ -95,6 +100,8 @@ class OpenAIServingBase(ABC):
adapted_request, processed_request = self._convert_to_internal_request(
request, raw_request
)
if hasattr(adapted_request, "validation_time"):
adapted_request.validation_time = validation_time
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
if hasattr(request, "stream") and request.stream:
@@ -157,6 +164,7 @@ class OpenAIServingBase(ABC):
self,
request: OpenAIServingRequest,
raw_request: Request = None,
validation_time: float = None,
) -> tuple[GenerateReqInput, OpenAIServingRequest]:
"""Convert OpenAI request to internal format"""
pass
......
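The hunk above times `_validate_request` with `time.perf_counter()` and attaches the result to the internal request only when that request type carries a `validation_time` field. A minimal, self-contained sketch of the same pattern; the `_FakeAdaptedRequest` class and `handle` function are illustrative stand-ins, not SGLang APIs:

```python
import time
from dataclasses import dataclass
from typing import Optional

@dataclass
class _FakeAdaptedRequest:
    # Stand-in for an internal request type; only the field relevant to this sketch.
    validation_time: Optional[float] = None

def handle(prompt: str) -> _FakeAdaptedRequest:
    # Time only the validation step with a monotonic clock.
    validation_start = time.perf_counter()
    error_msg = None if prompt.strip() else "prompt must not be empty"
    validation_time = time.perf_counter() - validation_start
    if error_msg:
        raise ValueError(error_msg)
    adapted = _FakeAdaptedRequest()
    # Attach the duration only if the target request type has the field.
    if hasattr(adapted, "validation_time"):
        adapted.validation_time = validation_time
    return adapted

print(handle("hello").validation_time)  # tiny value, e.g. ~1e-06 seconds
```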
@@ -80,6 +80,10 @@ class GrpcReqState:
last_time: float = 0.0
last_completion_tokens: int = 1
# perf_counter equivalents for accurate time calculations
finished_time_perf: float = 0.0
first_token_time_perf: float = 0.0
# Streaming state
stream_finished: bool = False
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
@@ -536,6 +540,7 @@ class GrpcRequestManager:
put_tasks = []
cleanup_tasks = []
now = time.time()
now_perf_counter = time.perf_counter()
# Process each request in the batch
for i, rid in enumerate(batch_out.rids):
@@ -552,6 +557,7 @@ class GrpcRequestManager:
# Update metrics
if state.first_token_time == 0.0:
state.first_token_time = now
state.first_token_time_perf = now_perf_counter
state.last_time = now
# Extract output for this request
@@ -650,6 +656,7 @@ class GrpcRequestManager:
if output_data["finished"]:
state.finished = True
state.finished_time = now
state.finished_time_perf = now_perf_counter
state.stream_finished = True
state.event.set()
@@ -691,6 +698,7 @@ class GrpcRequestManager:
# Mark as finished
state.finished = True
state.finished_time = time.time()
state.finished_time_perf = time.perf_counter()
state.event.set()
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
@@ -723,6 +731,7 @@ class GrpcRequestManager:
# Mark as finished
state.finished = True
state.finished_time = time.time()
state.finished_time_perf = time.perf_counter()
state.event.set()
async def _handle_abort_req(self, recv_obj: AbortReq):
......
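The `*_perf` fields added above mirror the existing wall-clock timestamps with `time.perf_counter()` readings, which are monotonic and therefore safer for computing durations. A quick illustration of the difference, in plain Python with no SGLang code involved:

```python
import time

start_wall = time.time()          # wall-clock: meaningful as an absolute timestamp
start_perf = time.perf_counter()  # monotonic, high resolution: meaningful only as a difference

time.sleep(0.05)  # stand-in for real work

elapsed_perf = time.perf_counter() - start_perf  # reliable interval, immune to clock adjustments
elapsed_wall = time.time() - start_wall          # usually similar, but can jump if the clock is reset
print(f"perf_counter interval: {elapsed_perf:.4f}s, time.time interval: {elapsed_wall:.4f}s")
```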
@@ -277,6 +277,10 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
placeholder_tokens_val=None,
retraction_counts=recv_obj.retraction_counts,
token_steps=recv_obj.token_steps,
queue_time=recv_obj.queue_time,
forward_entry_time=recv_obj.forward_entry_time,
prefill_delay=recv_obj.prefill_delay,
prefill_latency=recv_obj.prefill_latency,
)
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -291,6 +295,10 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
cached_tokens=recv_obj.cached_tokens,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
queue_time=recv_obj.queue_time,
forward_entry_time=recv_obj.forward_entry_time,
prefill_delay=recv_obj.prefill_delay,
prefill_latency=recv_obj.prefill_latency,
)
def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
......
@@ -61,6 +61,55 @@ class BaseBatchReq(ABC):
return self.rids
@dataclass
class RequestTimingMetricsMixin:
"""
Mixin class containing common request-level timing metrics.
This class consolidates the timing metrics that are shared across all batch output types
to avoid code duplication and ensure consistency.
"""
# Queue duration: time spent waiting in queue before request is scheduled.
queue_time: Optional[List[Optional[float]]]
# Forward entry time: timestamp when the request enters the forward pass stage.
# This corresponds to `forward_entry_time` in TimeStats.
# In different modes:
# - Unified/PD-colocate: timestamp when forward computation begins (covers prefill + decode)
# - Prefill instance (P): timestamp when prefill forward pass begins
# - Decode instance (D): timestamp when decode forward pass begins
# Note: This is NOT the same as prefill_start_time. There may be a delay between
# forward_entry_time and prefill_start_time (see prefill_delay).
forward_entry_time: Optional[List[Optional[float]]]
# Prefill delay: time spent waiting between forward entry and prefill start.
# Calculated as: prefill_start_time - forward_entry_time
# This represents the delay between when the request enters the forward stage
# and when prefill computation actually begins.
prefill_delay: Optional[List[Optional[float]]]
# Prefill latency: time spent during prefill computation.
# Calculated as: prefill_end_time - prefill_start_time
prefill_latency: Optional[List[Optional[float]]]
@dataclass
class SpeculativeDecodingMetricsMixin:
"""
Mixin class containing speculative decoding metrics.
This class consolidates speculative decoding metrics that are shared across
batch output types that support speculative decoding to avoid code duplication.
"""
# Verify count: number of verification forward passes
spec_verify_ct: List[int]
# Accepted tokens: Number of accepted tokens during speculative decoding
spec_accepted_tokens: List[int]
# Parameters for a session
@dataclass
class SessionParams:
@@ -148,6 +197,9 @@ class GenerateReqInput(BaseReq):
bootstrap_room: Optional[Union[List[int], int]] = None
bootstrap_pair_key: Optional[Union[List[str], str]] = None
# Validation step duration
validation_time: Optional[float] = None
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
@@ -564,6 +616,7 @@ class GenerateReqInput(BaseReq):
if self.bootstrap_pair_key is not None
else None
),
validation_time=self.validation_time,
data_parallel_rank=(
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
@@ -684,6 +737,8 @@ class EmbeddingReqInput(BaseReq):
log_metrics: bool = True
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
# Validation step duration
validation_time: Optional[float] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
# Priority for the request
@@ -774,6 +829,7 @@ class EmbeddingReqInput(BaseReq):
video_data=self.video_data[i] if self.video_data is not None else None,
sampling_params=self.sampling_params[i],
rid=self.rid[i],
validation_time=self.validation_time,
dimensions=self.dimensions,
http_worker_ipc=self.http_worker_ipc,
)
@@ -815,7 +871,9 @@ class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
@dataclass
class BatchTokenIDOutput(
BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin
):
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
@@ -833,8 +891,6 @@ class BatchTokenIDOutput(BaseBatchReq):
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
- spec_verify_ct: List[int]
- spec_accepted_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
@@ -868,7 +924,7 @@ class BatchTokenIDOutput(BaseBatchReq):
@dataclass
class BatchMultimodalDecodeReq(BaseBatchReq, RequestTimingMetricsMixin):
decoded_ids: List[int]
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
@@ -900,7 +956,9 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
@dataclass
class BatchStrOutput(
BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin
):
# The finish reason
finished_reasons: List[dict]
# The output decoded strings
@@ -912,8 +970,6 @@ class BatchStrOutput(BaseBatchReq):
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
- spec_verify_ct: List[int]
- spec_accepted_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
@@ -947,7 +1003,7 @@ class BatchStrOutput(BaseBatchReq):
@dataclass
class BatchMultimodalOutput(BaseBatchReq, RequestTimingMetricsMixin):
# The finish reason
finished_reasons: List[dict]
decoded_ids: List[List[int]]
@@ -972,7 +1028,7 @@ class BatchMultimodalOutput(BaseBatchReq):
@dataclass
class BatchEmbeddingOutput(BaseBatchReq, RequestTimingMetricsMixin):
# The finish reason
finished_reasons: List[BaseFinishReason]
# The output embedding
......
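To make the relationships in `RequestTimingMetricsMixin` concrete, here is a toy example with invented timestamps for a two-request batch; request 1 has no prefill data, so its entries stay `None`. The `_Timing` dataclass only mirrors the mixin's fields and is not part of the codebase:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _Timing:
    queue_time: Optional[List[Optional[float]]]
    forward_entry_time: Optional[List[Optional[float]]]
    prefill_delay: Optional[List[Optional[float]]]
    prefill_latency: Optional[List[Optional[float]]]

# Request 0: waited 3 ms in the queue, entered forward at t=100.000,
# prefill ran from t=100.002 to t=100.050 (all numbers are made up).
forward_entry, prefill_start, prefill_end = 100.000, 100.002, 100.050
metrics = _Timing(
    queue_time=[0.003, None],
    forward_entry_time=[forward_entry, None],
    prefill_delay=[prefill_start - forward_entry, None],   # ~0.002 s
    prefill_latency=[prefill_end - prefill_start, None],   # ~0.048 s
)
print(metrics.prefill_delay[0], metrics.prefill_latency[0])
```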
@@ -91,6 +91,26 @@ def _handle_output_by_index(output, i):
if isinstance(output, BatchTokenIDOutput):
new_output = BatchTokenIDOutput(
rids=[output.rids[i]],
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
spec_accepted_tokens=(
[output.spec_accepted_tokens[i]]
if len(output.spec_accepted_tokens) > i
else None
),
queue_time=[output.queue_time[i]] if len(output.queue_time) > i else None,
forward_entry_time=(
[output.forward_entry_time[i]]
if len(output.forward_entry_time) > i
else None
),
prefill_delay=(
[output.prefill_delay[i]] if len(output.prefill_delay) > i else None
),
prefill_latency=(
[output.prefill_latency[i]] if len(output.prefill_latency) > i else None
),
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
@@ -132,9 +152,6 @@ def _handle_output_by_index(output, i):
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
- spec_verify_ct=(
-     [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
- ),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
@@ -230,6 +247,26 @@ def _handle_output_by_index(output, i):
elif isinstance(output, BatchStrOutput):
new_output = BatchStrOutput(
rids=[output.rids[i]],
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
spec_accepted_tokens=(
[output.spec_accepted_tokens[i]]
if len(output.spec_accepted_tokens) > i
else None
),
queue_time=[output.queue_time[i]] if len(output.queue_time) > i else None,
forward_entry_time=(
[output.forward_entry_time[i]]
if len(output.forward_entry_time) > i
else None
),
prefill_delay=(
[output.prefill_delay[i]] if len(output.prefill_delay) > i else None
),
prefill_latency=(
[output.prefill_latency[i]] if len(output.prefill_latency) > i else None
),
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
@@ -254,14 +291,6 @@ def _handle_output_by_index(output, i):
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
- spec_verify_ct=(
-     [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
- ),
- spec_accepted_tokens=(
-     [output.spec_accepted_tokens[i]]
-     if len(output.spec_accepted_tokens) > i
-     else None
- ),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
......
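The splitter above repeats the same guard for every per-request list: take the i-th entry as a one-element list, or fall back to `None` when the field was not populated. A distilled version of that expression; the helper name below is illustrative, not part of the codebase:

```python
from typing import List, Optional

def slice_or_none(values: List[float], i: int) -> Optional[List[float]]:
    # Mirrors `[output.field[i]] if len(output.field) > i else None`.
    return [values[i]] if len(values) > i else None

print(slice_or_none([0.010, 0.025], 1))  # [0.025]
print(slice_or_none([], 1))              # None (field not populated for this batch)
```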
@@ -152,6 +152,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
@@ -1952,6 +1953,12 @@ class Scheduler(
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
time.sleep(self.forward_sleep_time)
# Capture prefill start time for EXTEND mode
if batch.forward_mode == ForwardMode.EXTEND:
current_time = time.perf_counter()
for req in batch.reqs:
req.time_stats.prefill_start_time = current_time
# Run forward
if self.is_generation:
batch_or_worker_batch = batch
@@ -2045,11 +2052,18 @@ class Scheduler(
batch_result.extend_logprob_start_len_per_req = (
extend_logprob_start_len_per_req
)
ret = batch_result
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(embeddings=embeddings)
# Capture prefill end time for EXTEND mode
if batch.forward_mode == ForwardMode.EXTEND:
current_time = time.perf_counter()
for req in batch.reqs:
req.time_stats.prefill_end_time = current_time
return ret
def launch_batch_sample_if_needed(
......
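The scheduler change brackets the forward pass of EXTEND (prefill) batches with `perf_counter()` stamps on each request's `time_stats`. A simplified mirror of that flow; `ForwardMode`, `TimeStats`, and `Req` here are toy stand-ins for the real classes:

```python
import time
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List

class ForwardMode(Enum):
    EXTEND = auto()   # prefill
    DECODE = auto()

@dataclass
class TimeStats:
    prefill_start_time: float = 0.0
    prefill_end_time: float = 0.0

@dataclass
class Req:
    time_stats: TimeStats = field(default_factory=TimeStats)

def run_batch(reqs: List[Req], mode: ForwardMode) -> None:
    # Only prefill (EXTEND) batches get start/end stamps; decode batches are untouched.
    if mode == ForwardMode.EXTEND:
        start = time.perf_counter()
        for req in reqs:
            req.time_stats.prefill_start_time = start
    time.sleep(0.01)  # stand-in for the actual forward pass
    if mode == ForwardMode.EXTEND:
        end = time.perf_counter()
        for req in reqs:
            req.time_stats.prefill_end_time = end

reqs = [Req(), Req()]
run_batch(reqs, ForwardMode.EXTEND)
stats = reqs[0].time_stats
print(f"prefill latency: {stats.prefill_end_time - stats.prefill_start_time:.4f}s")
```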
@@ -275,6 +275,7 @@ class SchedulerOutputProcessorMixin:
next_token_ids[i * stride : i * stride + accept_lens[i]]
)
req.spec_verify_ct += 1
req.spec_accepted_tokens += accept_lens[i] - 1
return predict_tokens
@@ -760,6 +761,11 @@ class SchedulerOutputProcessorMixin:
retraction_counts = []
output_hidden_states = None
queue_times = []
forward_entry_times = []
prefill_delays = []
prefill_latencies = []
if return_logprob:
input_token_logprobs_val = []
input_token_logprobs_idx = []
@@ -860,6 +866,28 @@ class SchedulerOutputProcessorMixin:
cached_tokens.append(req.cached_tokens)
retraction_counts.append(req.retraction_count)
queue_times.append(req.time_stats.get_queueing_time())
forward_entry_times.append(req.time_stats.forward_entry_time)
if req.time_stats.prefill_start_time > 0.0:
prefill_delays.append(
req.time_stats.prefill_start_time
- req.time_stats.forward_entry_time
)
else:
prefill_delays.append(None)
if (
req.time_stats.prefill_start_time > 0.0
and req.time_stats.prefill_end_time > 0.0
):
prefill_latencies.append(
req.time_stats.prefill_end_time
- req.time_stats.prefill_start_time
)
else:
prefill_latencies.append(None)
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
spec_accepted_tokens.append(req.spec_accepted_tokens)
@@ -951,31 +979,35 @@ class SchedulerOutputProcessorMixin:
self.send_to_detokenizer.send_output(
BatchTokenIDOutput(
spec_verify_ct=spec_verify_ct,
spec_accepted_tokens=spec_accepted_tokens,
queue_time=queue_times,
forward_entry_time=forward_entry_times,
prefill_delay=prefill_delays,
prefill_latency=prefill_latencies,
finished_reasons=finished_reasons,
decoded_texts=decoded_texts,
decode_ids=decode_ids_list,
read_offsets=read_offsets,
output_ids=output_ids,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
no_stop_trim=no_stop_trim,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cached_tokens=cached_tokens,
input_token_logprobs_val=input_token_logprobs_val,
input_token_logprobs_idx=input_token_logprobs_idx,
output_token_logprobs_val=output_token_logprobs_val,
output_token_logprobs_idx=output_token_logprobs_idx,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
output_top_logprobs_val=output_top_logprobs_val,
output_top_logprobs_idx=output_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
output_token_ids_logprobs_val=output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=output_token_ids_logprobs_idx,
output_token_entropy_val=None,
output_hidden_states=output_hidden_states,
rids=rids,
@@ -994,6 +1026,10 @@ class SchedulerOutputProcessorMixin:
embeddings = []
prompt_tokens = []
cached_tokens = []
queue_times = []
forward_entry_times = []
prefill_delays = []
prefill_latencies = []
retraction_counts = []
for req in reqs:
if req.finished():
@@ -1003,17 +1039,43 @@ class SchedulerOutputProcessorMixin:
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
queue_times.append(req.time_stats.get_queueing_time())
forward_entry_times.append(req.time_stats.forward_entry_time)
if req.time_stats.prefill_start_time > 0.0:
prefill_delays.append(
req.time_stats.prefill_start_time
- req.time_stats.forward_entry_time
)
else:
prefill_delays.append(None)
if (
req.time_stats.prefill_start_time > 0.0
and req.time_stats.prefill_end_time > 0.0
):
prefill_latencies.append(
req.time_stats.prefill_end_time
- req.time_stats.prefill_start_time
)
else:
prefill_latencies.append(None)
retraction_counts.append(req.retraction_count)
self.send_to_detokenizer.send_output(
BatchEmbeddingOutput(
queue_time=queue_times,
forward_entry_time=forward_entry_times,
prefill_delay=prefill_delays,
prefill_latency=prefill_latencies,
finished_reasons=finished_reasons,
embeddings=embeddings,
prompt_tokens=prompt_tokens,
cached_tokens=cached_tokens,
http_worker_ipcs=http_worker_ipcs,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
retraction_counts=retraction_counts,
rids=rids,
)
)
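Both output builders above derive `prefill_delay` and `prefill_latency` from `time_stats`, falling back to `None` when the prefill timestamps were never set (for example on a decode-only path). The guard logic, pulled out into illustrative helpers that are not part of the codebase:

```python
from typing import Optional

def derive_prefill_delay(forward_entry_time: float, prefill_start_time: float) -> Optional[float]:
    # Delay between entering the forward stage and prefill actually starting.
    return prefill_start_time - forward_entry_time if prefill_start_time > 0.0 else None

def derive_prefill_latency(prefill_start_time: float, prefill_end_time: float) -> Optional[float]:
    # Duration of the prefill computation itself.
    if prefill_start_time > 0.0 and prefill_end_time > 0.0:
        return prefill_end_time - prefill_start_time
    return None

print(derive_prefill_delay(10.000, 10.004))    # ~0.004
print(derive_prefill_latency(10.004, 10.250))  # ~0.246
print(derive_prefill_latency(0.0, 0.0))        # None: prefill never ran here
```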
@@ -136,6 +136,13 @@ class ReqState:
last_time: float = 0.0
last_completion_tokens: int = 1
# perf_counter equivalents for accurate time calculations
finished_time_perf: float = 0.0
first_token_time_perf: float = 0.0
request_scheduled_ts: float = 0.0
response_sent_ts: float = 0.0
# For streaming output
last_output_offset: int = 0
@@ -911,6 +918,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
self.send_to_scheduler.send_pyobj(tokenized_obj)
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
state.request_scheduled_ts = time.time()
self.rid_to_state[obj.rid] = state
trace_slice_end(
RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
@@ -968,6 +976,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.out_list = []
if state.finished:
# For non-streaming cases, response has not been sent yet (`response_sent_ts` has not been set yet).
# Record response sent time right before we log finished results and metrics.
if not state.response_sent_ts:
state.response_sent_ts = time.time()
out["meta_info"]["response_sent_ts"] = state.response_sent_ts
if self.log_requests:
max_length, skip_names, out_skip_names = self.log_request_metadata
if self.model_config.is_multimodal_gen:
@@ -1011,6 +1024,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.event.clear()
if obj.stream:
# Record the response sent time right before we send the response.
if not state.response_sent_ts:
state.response_sent_ts = time.time()
out["meta_info"]["response_sent_ts"] = state.response_sent_ts
yield out
else:
if (
@@ -1418,6 +1435,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"total_retractions": recv_obj.retraction_counts[i],
}
if (
hasattr(recv_obj, "queue_time")
and recv_obj.queue_time
and recv_obj.queue_time[i] is not None
):
meta_info["queue_time"] = recv_obj.queue_time[i]
if (
hasattr(recv_obj, "prefill_delay")
and recv_obj.prefill_delay
and recv_obj.prefill_delay[i] is not None
):
meta_info["prefill_delay"] = recv_obj.prefill_delay[i]
if (
hasattr(recv_obj, "prefill_latency")
and recv_obj.prefill_latency
and recv_obj.prefill_latency[i] is not None
):
meta_info["prefill_latency"] = recv_obj.prefill_latency[i]
if getattr(state.obj, "return_logprob", False): if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style( self.convert_logprob_style(
meta_info, meta_info,
...@@ -1483,8 +1521,12 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1483,8 +1521,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if self.server_args.speculative_algorithm: if self.server_args.speculative_algorithm:
self._calculate_spec_decoding_metrics(meta_info, recv_obj, i) self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
state.finished_time = time.time() state.finished_time = time.time()
state.finished_time_perf = time.perf_counter()
meta_info["e2e_latency"] = state.finished_time - state.created_time meta_info["e2e_latency"] = state.finished_time - state.created_time
# Calculate timing metrics
self._calculate_timing_metrics(meta_info, state, recv_obj, i)
trace_req_finish(rid, ts=int(state.finished_time * 1e9)) trace_req_finish(rid, ts=int(state.finished_time * 1e9))
del self.rid_to_state[rid] del self.rid_to_state[rid]
...@@ -1687,6 +1729,57 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1687,6 +1729,57 @@ class TokenizerManager(TokenizerCommunicatorMixin):
recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i] recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
) )
def _calculate_timing_metrics(
self,
meta_info: Dict[str, Any],
state: ReqState,
recv_obj: Union[
BatchStrOutput,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchTokenIDOutput,
],
i: int,
) -> None:
"""Calculate request-level timing metrics, such as inference time, decode throughput, and time per token."""
# Request timing timestamps.
if state.created_time > 0:
meta_info["request_received_ts"] = state.created_time
if state.request_scheduled_ts > 0:
meta_info["request_scheduled_ts"] = state.request_scheduled_ts
# For embeddings, there's no separate prefill phase, so omit `prefill_finished_ts`.
if (
not isinstance(recv_obj, BatchEmbeddingOutput)
and state.first_token_time > 0
):
meta_info["prefill_finished_ts"] = state.first_token_time
if state.response_sent_ts > 0:
meta_info["response_sent_ts"] = state.response_sent_ts
if state.finished_time > 0:
meta_info["decode_finished_ts"] = state.finished_time
# Inference time calculation.
if (
hasattr(recv_obj, "forward_entry_time")
and recv_obj.forward_entry_time
and recv_obj.forward_entry_time[i] is not None
and state.finished_time_perf > 0.0
):
forward_time = state.finished_time_perf - recv_obj.forward_entry_time[i]
meta_info["forward_time"] = forward_time
# Decode throughput, time per token calculation. Only calculated if TTFT is available.
if (
state.first_token_time_perf > 0.0
and state.finished_time_perf > 0.0
and not isinstance(recv_obj, BatchEmbeddingOutput)
and recv_obj.completion_tokens[i] > 0
):
decode_time = state.finished_time_perf - state.first_token_time_perf
completion_tokens = recv_obj.completion_tokens[i]
meta_info["decode_throughput"] = completion_tokens / decode_time
meta_info["time_per_token"] = decode_time / completion_tokens
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
completion_tokens = (
recv_obj.completion_tokens[i]
@@ -1705,6 +1798,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
state.first_token_time = state.last_time = time.time()
state.first_token_time_perf = time.perf_counter()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
labels, state.first_token_time - state.created_time
......
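For `_calculate_timing_metrics`, the derived quantities reduce to a little arithmetic over perf_counter readings. A worked example with invented numbers for a single finished request:

```python
# All values are made-up perf_counter readings.
forward_entry_time = 5.000     # request entered the forward stage
first_token_time_perf = 5.180  # first output token produced
finished_time_perf = 6.980     # request finished
completion_tokens = 90

forward_time = finished_time_perf - forward_entry_time    # 1.98 s spent in forward
decode_time = finished_time_perf - first_token_time_perf  # 1.80 s of decoding
decode_throughput = completion_tokens / decode_time       # 50 tokens/s
time_per_token = decode_time / completion_tokens           # 0.02 s per token
print(forward_time, decode_throughput, time_per_token)
```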
@@ -46,6 +46,8 @@ class TimeStats:
# TODO: correct set them
bootstrap_duration: float = 0.0
alloc_waiting_duration: float = 0.0
prefill_start_time: float = 0.0
prefill_end_time: float = 0.0
def get_queueing_time(self) -> float:
return self.forward_entry_time - self.wait_queue_entry_time
......
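Finally, the two new `TimeStats` fields sit next to the existing queueing calculation. A trimmed-down mirror of the dataclass (only the fields relevant here; the real `TimeStats` has more) showing how the pieces fit together:

```python
from dataclasses import dataclass

@dataclass
class _TimeStatsSketch:
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    prefill_start_time: float = 0.0
    prefill_end_time: float = 0.0

    def get_queueing_time(self) -> float:
        return self.forward_entry_time - self.wait_queue_entry_time

ts = _TimeStatsSketch(
    wait_queue_entry_time=1.00,
    forward_entry_time=1.25,
    prefill_start_time=1.26,
    prefill_end_time=1.50,
)
print(ts.get_queueing_time())                       # ~0.25 s in the waiting queue
print(ts.prefill_end_time - ts.prefill_start_time)  # ~0.24 s of prefill
```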