Unverified Commit 4403e3ed authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

[Metrics] Add labeled prompt token metrics for P/D disaggregation (#33290)



Add labeled Prometheus metrics to distinguish where prompt tokens come
from in P/D disaggregated deployments.

In P/D disaggregation, decode instances receive KV cache from prefill instances.
Currently, decode reports inflated prompt throughput because it counts all
prompt tokens as "computed", even though most were transferred.

This PR adds labeled metrics so users can understand actual compute work vs
transferred work:

vllm:prompt_tokens_by_source_total{source="local_compute"}        # Tokens prefilled locally
vllm:prompt_tokens_by_source_total{source="external_kv_transfer"} # Tokens received via KV transfer  
vllm:prompt_tokens_by_source_total{source="local_cache_hit"}      # Tokens from local prefix cache
vllm:prompt_tokens_cached_total                                    # Total cached (local + external, -1 when all 
Signed-off-by: default avatarZhanqiu Hu <zh338@cornell.edu>
parent 08e09499
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats
def test_iteration_stats_repr():
......@@ -107,3 +107,105 @@ def test_prefill_kv_computed_edge_cases():
finished_req2.num_cached_tokens, 0
)
assert prefill_kv_computed2 == 0 # All cached, nothing computed
def test_prompt_token_stats_all_computed():
"""Test all tokens computed locally, no caching."""
stats = PromptTokenStats()
# Case 1: No caching (All tokens computed locally)
stats.update_from_output(
num_cached_tokens=0,
num_external_computed_tokens=0,
prompt_len=1000,
)
assert stats.computed == 1000
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 0
assert stats.total == 1000
def test_prompt_token_stats_partial_local_cache():
"""Test partial local prefix cache hit."""
stats = PromptTokenStats()
# Case 2: Partial local cache
stats.update_from_output(
num_cached_tokens=300,
num_external_computed_tokens=0,
prompt_len=1000,
)
assert stats.computed == 700
assert stats.local_cache_hit == 300
assert stats.external_kv_transfer == 0
def test_prompt_token_stats_partial_external_transfer():
"""Test partial external KV transfer."""
stats = PromptTokenStats()
# Case 3: Partial external transfer
stats.update_from_output(
num_cached_tokens=500,
num_external_computed_tokens=500,
prompt_len=1000,
)
assert stats.computed == 500
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 500
def test_prompt_token_stats_mixed_sources():
"""Test mix of local cache and external transfer."""
stats = PromptTokenStats()
# Case 4: Mixed sources
stats.update_from_output(
num_cached_tokens=600,
num_external_computed_tokens=200,
prompt_len=1000,
)
assert stats.computed == 400
assert stats.local_cache_hit == 400
assert stats.external_kv_transfer == 200
def test_prompt_token_stats_full_local_cache_recompute():
"""Test full local cache triggers last token recomputation.
When all tokens are cached, the scheduler reduces num_cached_tokens by 1
to force the model to recompute the last token.
"""
stats = PromptTokenStats()
# Case 5: Full local cache (999 cached after reduction, 1 recomputed)
stats.update_from_output(
num_cached_tokens=999,
num_external_computed_tokens=0,
prompt_len=1000,
)
assert stats.computed == 1
assert stats.local_cache_hit == 1000
assert stats.recomputed_tokens == 1
def test_prompt_token_stats_full_external_transfer_recompute():
"""Test full external transfer triggers last token recomputation."""
stats = PromptTokenStats()
# Case 6: Full external transfer (999 cached after reduction, 1 recomputed)
stats.update_from_output(
num_cached_tokens=999,
num_external_computed_tokens=1000,
prompt_len=1000,
)
assert stats.computed == 1
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 1000
assert stats.recomputed_tokens == 1
......@@ -1378,6 +1378,7 @@ class Scheduler(SchedulerInterface):
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
)
......
......@@ -139,8 +139,10 @@ class EngineCoreOutput(
kv_transfer_params: dict[str, Any] | None = None
trace_headers: Mapping[str, str] | None = None
# The number of tokens with prefix cache hits.
# The number of tokens with prefix cache hits (local + external).
num_cached_tokens: int = 0
# The number of tokens computed remotely (original count from connector).
num_external_computed_tokens: int = 0
routed_experts: np.ndarray | None = None
# The number of NaNs in logits.
# A value greater than 0 indicates that the output is corrupted.
......
......@@ -25,6 +25,7 @@ from vllm.v1.metrics.stats import (
CachingMetrics,
IterationStats,
MultiModalCacheStats,
PromptTokenStats,
SchedulerStats,
)
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
......@@ -136,7 +137,8 @@ class LoggingStatLogger(StatLoggerBase):
def _track_iteration_stats(self, iteration_stats: IterationStats):
# Save tracked stats for token counters.
self.num_prompt_tokens += iteration_stats.num_prompt_tokens
# Use computed tokens for prompt throughput (excludes cached/transferred)
self.num_prompt_tokens += iteration_stats.prompt_token_stats.computed
self.num_generation_tokens += iteration_stats.num_generation_tokens
self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs
self.num_preemptions += iteration_stats.num_preempted_reqs
......@@ -590,6 +592,41 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
counter_prompt_tokens, engine_indexes, model_name
)
# Labeled prompt token counters by source
counter_prompt_tokens_by_source = self._counter_cls(
name="vllm:prompt_tokens_by_source",
documentation="Number of prompt tokens by source.",
labelnames=labelnames + ["source"],
)
self.counter_prompt_tokens_by_source: dict[str, dict[int, Counter]] = {}
for source in PromptTokenStats.ALL_SOURCES:
self.counter_prompt_tokens_by_source[source] = {
idx: counter_prompt_tokens_by_source.labels(
model_name, str(idx), source
)
for idx in engine_indexes
}
# Cached prompt tokens counter
counter_prompt_tokens_cached = self._counter_cls(
name="vllm:prompt_tokens_cached",
documentation="Number of cached prompt tokens (local + external).",
labelnames=labelnames,
)
self.counter_prompt_tokens_cached = make_per_engine(
counter_prompt_tokens_cached, engine_indexes, model_name
)
# Recomputed tokens (last token recomputed when entire prompt is cached)
counter_prompt_tokens_recomputed = self._counter_cls(
name="vllm:prompt_tokens_recomputed",
documentation="Number of cached tokens recomputed for forward pass.",
labelnames=labelnames,
)
self.counter_prompt_tokens_recomputed = make_per_engine(
counter_prompt_tokens_recomputed, engine_indexes, model_name
)
counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens",
documentation="Number of generation tokens processed.",
......@@ -1070,6 +1107,14 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
iteration_stats.num_preempted_reqs
)
self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens)
# Labeled prompt token counters by source
pts = iteration_stats.prompt_token_stats
for source in PromptTokenStats.ALL_SOURCES:
self.counter_prompt_tokens_by_source[source][engine_idx].inc(
pts.get_by_source(source)
)
self.counter_prompt_tokens_cached[engine_idx].inc(pts.cached_tokens)
self.counter_prompt_tokens_recomputed[engine_idx].inc(pts.recomputed_tokens)
self.counter_generation_tokens[engine_idx].inc(
iteration_stats.num_generation_tokens
)
......
......@@ -231,13 +231,76 @@ class FinishedRequestStats:
num_cached_tokens: int = 0
@dataclass
class PromptTokenStats:
"""Breakdown of prompt tokens by source.
Fields:
computed: Tokens prefilled locally (actual compute work).
local_cache_hit: Tokens from local prefix cache.
external_kv_transfer: Tokens from external KV transfer.
cached_tokens: Tokens skipped during prefill (from scheduler).
recomputed_tokens: Cached tokens that were recomputed (see below).
total: Total prompt tokens.
Invariants:
computed + local_cache_hit + external_kv_transfer - recomputed_tokens = total
local_cache_hit + external_kv_transfer - recomputed_tokens = cached_tokens
"""
ALL_SOURCES: tuple[str, ...] = (
"local_compute",
"local_cache_hit",
"external_kv_transfer",
)
computed: int = 0
local_cache_hit: int = 0
external_kv_transfer: int = 0
cached_tokens: int = 0
recomputed_tokens: int = 0
total: int = 0
def update_from_output(
self,
num_cached_tokens: int,
num_external_computed_tokens: int,
prompt_len: int,
) -> None:
"""Update stats from a prefill output."""
# When all tokens are cached, the scheduler reduces num_cached_tokens
# by 1 to force the model to recompute the last token, since the model
# needs at least one input token to run a forward pass.
recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
self.computed += prompt_len - num_cached_tokens
self.external_kv_transfer += num_external_computed_tokens
self.local_cache_hit += (
num_cached_tokens + recomputed - num_external_computed_tokens
)
self.cached_tokens += num_cached_tokens
self.recomputed_tokens += recomputed
self.total += prompt_len
def get_by_source(self, source: str) -> int:
"""Get token count by source label."""
source_map = {
"local_compute": self.computed,
"local_cache_hit": self.local_cache_hit,
"external_kv_transfer": self.external_kv_transfer,
}
if source not in source_map:
raise ValueError(f"Unknown source: {source}")
return source_map[source]
class IterationStats:
"""Stats associated with a single set of EngineCoreOutputs."""
def __init__(self):
self.iteration_timestamp = time.time()
self.num_generation_tokens = 0
self.num_prompt_tokens = 0
self.prompt_token_stats = PromptTokenStats()
self.num_preempted_reqs = 0
self.finished_requests: list[FinishedRequestStats] = []
self.max_num_generation_tokens_iter: list[int] = []
......@@ -250,6 +313,11 @@ class IterationStats:
field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
return f"{self.__class__.__name__}({field_to_value_str})"
@property
def num_prompt_tokens(self) -> int:
"""Total prompt tokens (for backward compatibility)."""
return self.prompt_token_stats.total
def _time_since(self, start: float) -> float:
"""Calculate an interval relative to this iteration's timestamp."""
return self.iteration_timestamp - start
......@@ -268,7 +336,11 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
self.num_prompt_tokens += prompt_len
self.prompt_token_stats.update_from_output(
num_cached_tokens=output.num_cached_tokens,
num_external_computed_tokens=output.num_external_computed_tokens,
prompt_len=prompt_len,
)
first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment