"examples/vscode:/vscode.git/clone" did not exist on "978dec9014667f394ab11f79dfc54a9c9a7290c7"
Unverified commit 58095cb0 authored by yinghui, committed by GitHub

Add timing metrics for requests (#12646)


Co-authored-by: Scott Lee <scottjlee@users.noreply.github.com>
parent fd3034da
@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
import time
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -84,10 +85,14 @@ class OpenAIServingBase(ABC):
async def handle_request(
self, request: OpenAIServingRequest, raw_request: Request
) -> Union[Any, StreamingResponse, ErrorResponse]:
"""Handle the specific request type with common pattern"""
"""Handle the specific request type with common pattern
If you want to override this method, you should be careful to record the validation time.
"""
try:
# Validate request
validation_start = time.perf_counter()
error_msg = self._validate_request(request)
validation_time = time.perf_counter() - validation_start
if error_msg:
return self.create_error_response(error_msg)
@@ -95,6 +100,8 @@ class OpenAIServingBase(ABC):
adapted_request, processed_request = self._convert_to_internal_request(
request, raw_request
)
if hasattr(adapted_request, "validation_time"):
adapted_request.validation_time = validation_time
# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
if hasattr(request, "stream") and request.stream:
@@ -157,6 +164,7 @@ class OpenAIServingBase(ABC):
self,
request: OpenAIServingRequest,
raw_request: Request = None,
validation_time: float = None,
) -> tuple[GenerateReqInput, OpenAIServingRequest]:
"""Convert OpenAI request to internal format"""
pass
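The new docstring above warns that any subclass overriding `handle_request` must record the validation time itself. A minimal, self-contained sketch of what such an override could look like; `_BaseSketch` and `MyServingHandler` are hypothetical stand-ins, and only the `perf_counter` bracketing mirrors the commit:

```python
import time
from typing import Any, Optional


class _BaseSketch:
    """Stand-in for OpenAIServingBase, reduced to the pieces this sketch needs."""

    def _validate_request(self, request: Any) -> Optional[str]:
        return None  # no validation error in this toy example

    def create_error_response(self, msg: str) -> dict:
        return {"error": msg}

    def _convert_to_internal_request(self, request: Any, raw_request: Any):
        return request, request


class MyServingHandler(_BaseSketch):  # hypothetical subclass
    async def handle_request(self, request: Any, raw_request: Any) -> Any:
        # Overrides must time validation themselves, as the new docstring warns.
        validation_start = time.perf_counter()
        error_msg = self._validate_request(request)
        validation_time = time.perf_counter() - validation_start
        if error_msg:
            return self.create_error_response(error_msg)

        adapted_request, processed_request = self._convert_to_internal_request(
            request, raw_request
        )
        # GenerateReqInput / EmbeddingReqInput now carry `validation_time`.
        if hasattr(adapted_request, "validation_time"):
            adapted_request.validation_time = validation_time

        # ... continue with the streaming / non-streaming dispatch exactly as
        # the base implementation does ...
        return processed_request
```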
@@ -80,6 +80,10 @@ class GrpcReqState:
last_time: float = 0.0
last_completion_tokens: int = 1
# perf_counter equivalents for accurate time calculations
finished_time_perf: float = 0.0
first_token_time_perf: float = 0.0
# Streaming state
stream_finished: bool = False
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
@@ -536,6 +540,7 @@ class GrpcRequestManager:
put_tasks = []
cleanup_tasks = []
now = time.time()
now_perf_counter = time.perf_counter()
# Process each request in the batch
for i, rid in enumerate(batch_out.rids):
@@ -552,6 +557,7 @@
# Update metrics
if state.first_token_time == 0.0:
state.first_token_time = now
state.first_token_time_perf = now_perf_counter
state.last_time = now
# Extract output for this request
@@ -650,6 +656,7 @@
if output_data["finished"]:
state.finished = True
state.finished_time = now
state.finished_time_perf = now_perf_counter
state.stream_finished = True
state.event.set()
@@ -691,6 +698,7 @@
# Mark as finished
state.finished = True
state.finished_time = time.time()
state.finished_time_perf = time.perf_counter()
state.event.set()
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
@@ -723,6 +731,7 @@
# Mark as finished
state.finished = True
state.finished_time = time.time()
state.finished_time_perf = time.perf_counter()
state.event.set()
async def _handle_abort_req(self, recv_obj: AbortReq):
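Both `GrpcReqState` above and `ReqState` later in this commit keep `*_perf` twins of their wall-clock timestamps. A brief illustration of why, using only the standard library: `time.time()` gives timestamps that are meaningful to report but can shift if the system clock is adjusted, whereas `time.perf_counter()` is monotonic, so wall-clock values are kept for reporting and perf-counter values for computing durations.

```python
import time

# Wall-clock timestamp: good for logs and reported *_ts fields, not for durations.
first_token_time = time.time()
# Monotonic counter: only differences are meaningful, but they are immune to
# clock adjustments, so durations are derived from these *_perf fields.
first_token_time_perf = time.perf_counter()

# ... generation finishes ...
finished_time = time.time()
finished_time_perf = time.perf_counter()

decode_time = finished_time_perf - first_token_time_perf  # robust duration
```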
@@ -277,6 +277,10 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
placeholder_tokens_val=None,
retraction_counts=recv_obj.retraction_counts,
token_steps=recv_obj.token_steps,
queue_time=recv_obj.queue_time,
forward_entry_time=recv_obj.forward_entry_time,
prefill_delay=recv_obj.prefill_delay,
prefill_latency=recv_obj.prefill_latency,
)
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -291,6 +295,10 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
cached_tokens=recv_obj.cached_tokens,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
queue_time=recv_obj.queue_time,
forward_entry_time=recv_obj.forward_entry_time,
prefill_delay=recv_obj.prefill_delay,
prefill_latency=recv_obj.prefill_latency,
)
def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
@@ -61,6 +61,55 @@ class BaseBatchReq(ABC):
return self.rids
@dataclass
class RequestTimingMetricsMixin:
"""
Mixin class containing common request-level timing metrics.
This class consolidates the timing metrics that are shared across all batch output types
to avoid code duplication and ensure consistency.
"""
# Queue duration: time spent waiting in queue before request is scheduled.
queue_time: Optional[List[Optional[float]]]
# Forward entry time: timestamp when the request enters the forward pass stage.
# This corresponds to `forward_entry_time` in TimeStats.
# In different modes:
# - Unified/PD-colocate: timestamp when forward computation begins (covers prefill + decode)
# - Prefill instance (P): timestamp when prefill forward pass begins
# - Decode instance (D): timestamp when decode forward pass begins
# Note: This is NOT the same as prefill_start_time. There may be a delay between
# forward_entry_time and prefill_start_time (see prefill_delay).
forward_entry_time: Optional[List[Optional[float]]]
# Prefill delay: time spent waiting between forward entry and prefill start.
# Calculated as: prefill_start_time - forward_entry_time
# This represents the delay between when the request enters the forward stage
# and when prefill computation actually begins.
prefill_delay: Optional[List[Optional[float]]]
# Prefill latency: time spent during prefill computation.
# Calculated as: prefill_end_time - prefill_start_time
prefill_latency: Optional[List[Optional[float]]]
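To make the relationships between these four fields concrete, here is a small worked example with invented perf-counter readings; the attribute names mirror `TimeStats`, and the numbers are illustrative only:

```python
# Hypothetical perf_counter() readings for a single request, in seconds.
wait_queue_entry_time = 100.000  # request enters the scheduler wait queue
forward_entry_time = 100.250     # request enters the forward stage
prefill_start_time = 100.260     # prefill forward pass actually begins
prefill_end_time = 100.410       # prefill forward pass ends

queue_time = forward_entry_time - wait_queue_entry_time   # 0.250 s
prefill_delay = prefill_start_time - forward_entry_time   # 0.010 s
prefill_latency = prefill_end_time - prefill_start_time   # 0.150 s
```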
@dataclass
class SpeculativeDecodingMetricsMixin:
"""
Mixin class containing speculative decoding metrics.
This class consolidates speculative decoding metrics that are shared across
batch output types that support speculative decoding to avoid code duplication.
"""
# Verify count: number of verification forward passes
spec_verify_ct: List[int]
# Accepted tokens: Number of accepted tokens during speculative decoding
spec_accepted_tokens: List[int]
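These two counters enable per-request acceptance statistics downstream; one possible derived quantity, shown here as a hedged sketch (it is not computed anywhere in this diff):

```python
# Illustrative values for a single request.
spec_verify_ct = 40          # verification forward passes
spec_accepted_tokens = 100   # draft tokens accepted across those passes

# Average number of accepted draft tokens per verification pass.
avg_accepted_per_verify = spec_accepted_tokens / spec_verify_ct  # 2.5
```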
# Parameters for a session
@dataclass
class SessionParams:
@@ -148,6 +197,9 @@ class GenerateReqInput(BaseReq):
bootstrap_room: Optional[Union[List[int], int]] = None
bootstrap_pair_key: Optional[Union[List[str], str]] = None
# Validation step duration
validation_time: Optional[float] = None
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
@@ -564,6 +616,7 @@ class GenerateReqInput(BaseReq):
if self.bootstrap_pair_key is not None
else None
),
validation_time=self.validation_time,
data_parallel_rank=(
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
@@ -684,6 +737,8 @@ class EmbeddingReqInput(BaseReq):
log_metrics: bool = True
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
# Validation step duration
validation_time: Optional[float] = None
# For cross-encoder requests
is_cross_encoder_request: bool = False
# Priority for the request
@@ -774,6 +829,7 @@ class EmbeddingReqInput(BaseReq):
video_data=self.video_data[i] if self.video_data is not None else None,
sampling_params=self.sampling_params[i],
rid=self.rid[i],
validation_time=self.validation_time,
dimensions=self.dimensions,
http_worker_ipc=self.http_worker_ipc,
)
@@ -815,7 +871,9 @@ class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
@dataclass
class BatchTokenIDOutput(BaseBatchReq):
class BatchTokenIDOutput(
BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin
):
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
@@ -833,8 +891,6 @@ class BatchTokenIDOutput(BaseBatchReq):
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
spec_accepted_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
@@ -868,7 +924,7 @@
@dataclass
class BatchMultimodalDecodeReq(BaseBatchReq):
class BatchMultimodalDecodeReq(BaseBatchReq, RequestTimingMetricsMixin):
decoded_ids: List[int]
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
@@ -900,7 +956,9 @@
@dataclass
class BatchStrOutput(BaseBatchReq):
class BatchStrOutput(
BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin
):
# The finish reason
finished_reasons: List[dict]
# The output decoded strings
@@ -912,8 +970,6 @@
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
spec_verify_ct: List[int]
spec_accepted_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
@@ -947,7 +1003,7 @@
@dataclass
class BatchMultimodalOutput(BaseBatchReq):
class BatchMultimodalOutput(BaseBatchReq, RequestTimingMetricsMixin):
# The finish reason
finished_reasons: List[dict]
decoded_ids: List[List[int]]
@@ -972,7 +1028,7 @@
@dataclass
class BatchEmbeddingOutput(BaseBatchReq):
class BatchEmbeddingOutput(BaseBatchReq, RequestTimingMetricsMixin):
# The finish reason
finished_reasons: List[BaseFinishReason]
# The output embedding
@@ -91,6 +91,26 @@ def _handle_output_by_index(output, i):
if isinstance(output, BatchTokenIDOutput):
new_output = BatchTokenIDOutput(
rids=[output.rids[i]],
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
spec_accepted_tokens=(
[output.spec_accepted_tokens[i]]
if len(output.spec_accepted_tokens) > i
else None
),
queue_time=[output.queue_time[i]] if len(output.queue_time) > i else None,
forward_entry_time=(
[output.forward_entry_time[i]]
if len(output.forward_entry_time) > i
else None
),
prefill_delay=(
[output.prefill_delay[i]] if len(output.prefill_delay) > i else None
),
prefill_latency=(
[output.prefill_latency[i]] if len(output.prefill_latency) > i else None
),
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
@@ -132,9 +152,6 @@ def _handle_output_by_index(output, i):
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
@@ -230,6 +247,26 @@ def _handle_output_by_index(output, i):
elif isinstance(output, BatchStrOutput):
new_output = BatchStrOutput(
rids=[output.rids[i]],
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
spec_accepted_tokens=(
[output.spec_accepted_tokens[i]]
if len(output.spec_accepted_tokens) > i
else None
),
queue_time=[output.queue_time[i]] if len(output.queue_time) > i else None,
forward_entry_time=(
[output.forward_entry_time[i]]
if len(output.forward_entry_time) > i
else None
),
prefill_delay=(
[output.prefill_delay[i]] if len(output.prefill_delay) > i else None
),
prefill_latency=(
[output.prefill_latency[i]] if len(output.prefill_latency) > i else None
),
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
@@ -254,14 +291,6 @@ def _handle_output_by_index(output, i):
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
spec_accepted_tokens=(
[output.spec_accepted_tokens[i]]
if len(output.spec_accepted_tokens) > i
else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
@@ -152,6 +152,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
@@ -1952,6 +1953,12 @@ class Scheduler(
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
time.sleep(self.forward_sleep_time)
# Capture prefill start time for EXTEND mode
if batch.forward_mode == ForwardMode.EXTEND:
current_time = time.perf_counter()
for req in batch.reqs:
req.time_stats.prefill_start_time = current_time
# Run forward
if self.is_generation:
batch_or_worker_batch = batch
@@ -2045,11 +2052,18 @@ class Scheduler(
batch_result.extend_logprob_start_len_per_req = (
extend_logprob_start_len_per_req
)
return batch_result
ret = batch_result
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(embeddings=embeddings)
# Capture prefill end time for EXTEND mode
if batch.forward_mode == ForwardMode.EXTEND:
current_time = time.perf_counter()
for req in batch.reqs:
req.time_stats.prefill_end_time = current_time
return ret
def launch_batch_sample_if_needed(
@@ -275,6 +275,7 @@ class SchedulerOutputProcessorMixin:
next_token_ids[i * stride : i * stride + accept_lens[i]]
)
req.spec_verify_ct += 1
req.spec_accepted_tokens += accept_lens[i] - 1
return predict_tokens
@@ -760,6 +761,11 @@ class SchedulerOutputProcessorMixin:
retraction_counts = []
output_hidden_states = None
queue_times = []
forward_entry_times = []
prefill_delays = []
prefill_latencies = []
if return_logprob:
input_token_logprobs_val = []
input_token_logprobs_idx = []
@@ -860,6 +866,28 @@ class SchedulerOutputProcessorMixin:
cached_tokens.append(req.cached_tokens)
retraction_counts.append(req.retraction_count)
queue_times.append(req.time_stats.get_queueing_time())
forward_entry_times.append(req.time_stats.forward_entry_time)
if req.time_stats.prefill_start_time > 0.0:
prefill_delays.append(
req.time_stats.prefill_start_time
- req.time_stats.forward_entry_time
)
else:
prefill_delays.append(None)
if (
req.time_stats.prefill_start_time > 0.0
and req.time_stats.prefill_end_time > 0.0
):
prefill_latencies.append(
req.time_stats.prefill_end_time
- req.time_stats.prefill_start_time
)
else:
prefill_latencies.append(None)
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
spec_accepted_tokens.append(req.spec_accepted_tokens)
@@ -951,31 +979,35 @@ class SchedulerOutputProcessorMixin:
self.send_to_detokenizer.send_output(
BatchTokenIDOutput(
finished_reasons,
decoded_texts,
decode_ids_list,
read_offsets,
output_ids,
skip_special_tokens,
spaces_between_special_tokens,
no_stop_trim,
prompt_tokens,
completion_tokens,
cached_tokens,
spec_verify_ct,
spec_accepted_tokens,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,
output_token_logprobs_idx,
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
spec_verify_ct=spec_verify_ct,
spec_accepted_tokens=spec_accepted_tokens,
queue_time=queue_times,
forward_entry_time=forward_entry_times,
prefill_delay=prefill_delays,
prefill_latency=prefill_latencies,
finished_reasons=finished_reasons,
decoded_texts=decoded_texts,
decode_ids=decode_ids_list,
read_offsets=read_offsets,
output_ids=output_ids,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
no_stop_trim=no_stop_trim,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cached_tokens=cached_tokens,
input_token_logprobs_val=input_token_logprobs_val,
input_token_logprobs_idx=input_token_logprobs_idx,
output_token_logprobs_val=output_token_logprobs_val,
output_token_logprobs_idx=output_token_logprobs_idx,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
output_top_logprobs_val=output_top_logprobs_val,
output_top_logprobs_idx=output_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
output_token_ids_logprobs_val=output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=output_token_ids_logprobs_idx,
output_token_entropy_val=None,
output_hidden_states=output_hidden_states,
rids=rids,
@@ -994,6 +1026,10 @@ class SchedulerOutputProcessorMixin:
embeddings = []
prompt_tokens = []
cached_tokens = []
queue_times = []
forward_entry_times = []
prefill_delays = []
prefill_latencies = []
retraction_counts = []
for req in reqs:
if req.finished():
@@ -1003,17 +1039,43 @@ class SchedulerOutputProcessorMixin:
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
queue_times.append(req.time_stats.get_queueing_time())
forward_entry_times.append(req.time_stats.forward_entry_time)
if req.time_stats.prefill_start_time > 0.0:
prefill_delays.append(
req.time_stats.prefill_start_time
- req.time_stats.forward_entry_time
)
else:
prefill_delays.append(None)
if (
req.time_stats.prefill_start_time > 0.0
and req.time_stats.prefill_end_time > 0.0
):
prefill_latencies.append(
req.time_stats.prefill_end_time
- req.time_stats.prefill_start_time
)
else:
prefill_latencies.append(None)
retraction_counts.append(req.retraction_count)
self.send_to_detokenizer.send_output(
BatchEmbeddingOutput(
finished_reasons,
embeddings,
prompt_tokens,
cached_tokens,
rids=rids,
queue_time=queue_times,
forward_entry_time=forward_entry_times,
prefill_delay=prefill_delays,
prefill_latency=prefill_latencies,
finished_reasons=finished_reasons,
embeddings=embeddings,
prompt_tokens=prompt_tokens,
cached_tokens=cached_tokens,
http_worker_ipcs=http_worker_ipcs,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
retraction_counts=retraction_counts,
rids=rids,
)
)
@@ -136,6 +136,13 @@ class ReqState:
last_time: float = 0.0
last_completion_tokens: int = 1
# perf_counter equivalents for accurate time calculations
finished_time_perf: float = 0.0
first_token_time_perf: float = 0.0
request_scheduled_ts: float = 0.0
response_sent_ts: float = 0.0
# For streaming output
last_output_offset: int = 0
@@ -911,6 +918,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
self.send_to_scheduler.send_pyobj(tokenized_obj)
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
state.request_scheduled_ts = time.time()
self.rid_to_state[obj.rid] = state
trace_slice_end(
RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
@@ -968,6 +976,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.out_list = []
if state.finished:
# For non-streaming cases, response has not been sent yet (`response_sent_ts` has not been set yet).
# Record response sent time right before we log finished results and metrics.
if not state.response_sent_ts:
state.response_sent_ts = time.time()
out["meta_info"]["response_sent_ts"] = state.response_sent_ts
if self.log_requests:
max_length, skip_names, out_skip_names = self.log_request_metadata
if self.model_config.is_multimodal_gen:
@@ -1011,6 +1024,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state.event.clear()
if obj.stream:
# Record response sent time right before we send response.
if not state.response_sent_ts:
state.response_sent_ts = time.time()
out["meta_info"]["response_sent_ts"] = state.response_sent_ts
yield out
else:
if (
@@ -1418,6 +1435,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"total_retractions": recv_obj.retraction_counts[i],
}
if (
hasattr(recv_obj, "queue_time")
and recv_obj.queue_time
and recv_obj.queue_time[i] is not None
):
meta_info["queue_time"] = recv_obj.queue_time[i]
if (
hasattr(recv_obj, "prefill_delay")
and recv_obj.prefill_delay
and recv_obj.prefill_delay[i] is not None
):
meta_info["prefill_delay"] = recv_obj.prefill_delay[i]
if (
hasattr(recv_obj, "prefill_latency")
and recv_obj.prefill_latency
and recv_obj.prefill_latency[i] is not None
):
meta_info["prefill_latency"] = recv_obj.prefill_latency[i]
if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
meta_info,
@@ -1483,8 +1521,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if self.server_args.speculative_algorithm:
self._calculate_spec_decoding_metrics(meta_info, recv_obj, i)
state.finished_time = time.time()
state.finished_time_perf = time.perf_counter()
meta_info["e2e_latency"] = state.finished_time - state.created_time
# Calculate timing metrics
self._calculate_timing_metrics(meta_info, state, recv_obj, i)
trace_req_finish(rid, ts=int(state.finished_time * 1e9))
del self.rid_to_state[rid]
@@ -1687,6 +1729,57 @@ class TokenizerManager(TokenizerCommunicatorMixin):
recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i]
)
def _calculate_timing_metrics(
self,
meta_info: Dict[str, Any],
state: ReqState,
recv_obj: Union[
BatchStrOutput,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchTokenIDOutput,
],
i: int,
) -> None:
"""Calculate request-level timing metrics, such as inference time, decode throughput, and time per token."""
# Request timing timestamps.
if state.created_time > 0:
meta_info["request_received_ts"] = state.created_time
if state.request_scheduled_ts > 0:
meta_info["request_scheduled_ts"] = state.request_scheduled_ts
# For embeddings, there's no separate prefill phase, so omit `prefill_finished_ts`.
if (
not isinstance(recv_obj, BatchEmbeddingOutput)
and state.first_token_time > 0
):
meta_info["prefill_finished_ts"] = state.first_token_time
if state.response_sent_ts > 0:
meta_info["response_sent_ts"] = state.response_sent_ts
if state.finished_time > 0:
meta_info["decode_finished_ts"] = state.finished_time
# Inference time calculation.
if (
hasattr(recv_obj, "forward_entry_time")
and recv_obj.forward_entry_time
and recv_obj.forward_entry_time[i] is not None
and state.finished_time_perf > 0.0
):
forward_time = state.finished_time_perf - recv_obj.forward_entry_time[i]
meta_info["forward_time"] = forward_time
# Decode throughput, time per token calculation. Only calculated if TTFT is available.
if (
state.first_token_time_perf > 0.0
and state.finished_time_perf > 0.0
and not isinstance(recv_obj, BatchEmbeddingOutput)
and recv_obj.completion_tokens[i] > 0
):
decode_time = state.finished_time_perf - state.first_token_time_perf
completion_tokens = recv_obj.completion_tokens[i]
meta_info["decode_throughput"] = completion_tokens / decode_time
meta_info["time_per_token"] = decode_time / completion_tokens
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
completion_tokens = (
recv_obj.completion_tokens[i]
@@ -1705,6 +1798,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
state.first_token_time = state.last_time = time.time()
state.first_token_time_perf = time.perf_counter()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
labels, state.first_token_time - state.created_time
@@ -46,6 +46,8 @@ class TimeStats:
# TODO: correct set them
bootstrap_duration: float = 0.0
alloc_waiting_duration: float = 0.0
prefill_start_time: float = 0.0
prefill_end_time: float = 0.0
def get_queueing_time(self) -> float:
return self.forward_entry_time - self.wait_queue_entry_time