"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "d1b837f0ae6a0152d820194a181e809ffaef6864"
Unverified commit d3af8c18, authored by Mark McLoughlin, committed by GitHub
Browse files

[Core][Metrics][BugFix] Replace num_cached_tokens/num_external_computed_tokens...


[Core][Metrics][BugFix] Replace num_cached_tokens/num_external_computed_tokens with PrefillStats (#37460)

Related to `Counters can only be incremented by non-negative amounts`
error with the `vllm:prompt_tokens_by_source_total` metric.
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Or Ozeri <or@ozery.com>
parent 25b3242d
...@@ -153,7 +153,6 @@ def test_prefix_caching_for_prefill_dedup(): ...@@ -153,7 +153,6 @@ def test_prefix_caching_for_prefill_dedup():
same_prompt=True, same_prompt=True,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
) )
requests_copy = requests.copy()
# Two requests with the same prompt. # Two requests with the same prompt.
req0 = requests.pop(0) req0 = requests.pop(0)
...@@ -167,26 +166,31 @@ def test_prefix_caching_for_prefill_dedup(): ...@@ -167,26 +166,31 @@ def test_prefix_caching_for_prefill_dedup():
# Make sure prefix caching de-duplicates the prompts in the same step, # Make sure prefix caching de-duplicates the prompts in the same step,
# so all the blocks except the last are shared between the two requests. # so all the blocks except the last are shared between the two requests.
assert len(sched_output.num_scheduled_tokens) == 2 assert len(sched_output.num_scheduled_tokens) == 2
num_blocks = num_prompt_tokens // BLOCK_SIZE assert sched_output.num_scheduled_tokens[req0.request_id] == num_prompt_tokens
assert req0.num_cached_tokens == 0 assert (
assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE sched_output.num_scheduled_tokens[req1.request_id]
== num_prompt_tokens % BLOCK_SIZE
)
sched_outputs.append(scheduler.schedule()) sched_outputs.append(scheduler.schedule())
while sched_outputs: while sched_outputs:
added_req = None
if requests: if requests:
scheduler.add_request(requests.pop(0)) added_req = requests.pop(0)
scheduler.add_request(added_req)
sched_output = sched_outputs.popleft() sched_output = sched_outputs.popleft()
model_runner_output = _make_model_runner_output(sched_output) model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output) scheduler.update_from_output(sched_output, model_runner_output)
sched_output = scheduler.schedule() sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens: if sched_output.num_scheduled_tokens:
sched_outputs.append(sched_output) sched_outputs.append(sched_output)
if added_req:
assert (
sched_output.num_scheduled_tokens[added_req.request_id]
== num_prompt_tokens % BLOCK_SIZE
)
# Other requests scheduled after the two requests should also get
# prefix cache hit.
assert scheduler.get_num_unfinished_requests() == 0 assert scheduler.get_num_unfinished_requests() == 0
for req in requests_copy[1:]:
assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE
def test_prefix_caching_for_multi_turn(): def test_prefix_caching_for_multi_turn():
...@@ -243,12 +247,15 @@ def test_prefix_caching_for_multi_turn(): ...@@ -243,12 +247,15 @@ def test_prefix_caching_for_multi_turn():
# Schedule the next-turn requests. # Schedule the next-turn requests.
for req in next_turn_requests: for req in next_turn_requests:
scheduler.add_request(req) scheduler.add_request(req)
sched_outputs.append(scheduler.schedule()) sched_output = scheduler.schedule()
sched_outputs.append(sched_output)
# Make sure the next-turn requests get prefix cache hit by the previous # Make sure the next-turn requests get prefix cache hit by the previous
# requests. # requests.
for req in next_turn_requests: for req in next_turn_requests:
assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE assert sched_output.num_scheduled_tokens[req.request_id] == (
req.num_prompt_tokens % BLOCK_SIZE
)
def test_abort_request_when_structured_output_fsm_cannot_advance(): def test_abort_request_when_structured_output_fsm_cannot_advance():
......
...@@ -84,6 +84,7 @@ def test_incremental_detokenization( ...@@ -84,6 +84,7 @@ def test_incremental_detokenization(
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
prompts_list=dummy_test_vectors.prompt_tokens,
request_ids=[req.request_id for req in requests], request_ids=[req.request_id for req in requests],
) )
...@@ -506,6 +507,7 @@ def test_logprobs_processor( ...@@ -506,6 +507,7 @@ def test_logprobs_processor(
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
prompts_list=dummy_test_vectors.prompt_tokens,
generated_logprobs_raw=None generated_logprobs_raw=None
if num_sample_logprobs is None if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs, else dummy_test_vectors.generation_logprobs,
...@@ -691,6 +693,7 @@ def test_stop_token( ...@@ -691,6 +693,7 @@ def test_stop_token(
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=[generation_tokens], tokens_list=[generation_tokens],
prompts_list=dummy_test_vectors.prompt_tokens,
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None, generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None, prompt_logprobs_raw=None,
eos_token_id=sampling_params.eos_token_id, eos_token_id=sampling_params.eos_token_id,
...@@ -794,6 +797,7 @@ def test_stop_string( ...@@ -794,6 +797,7 @@ def test_stop_string(
engine_core = MockEngineCore( engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens, tokens_list=dummy_test_vectors.generation_tokens,
prompts_list=dummy_test_vectors.prompt_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs if num_sample_logprobs
else None, else None,
...@@ -917,6 +921,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -917,6 +921,7 @@ def test_iteration_stats(dummy_test_vectors):
engine_core = MockEngineCore( engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens, dummy_test_vectors.generation_tokens,
dummy_test_vectors.prompt_tokens,
request_ids=[req.request_id for req in requests], request_ids=[req.request_id for req in requests],
) )
...@@ -927,7 +932,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -927,7 +932,7 @@ def test_iteration_stats(dummy_test_vectors):
inactive_request = requests[num_active] inactive_request = requests[num_active]
# First iteration has 2 prefills. # First iteration has 2 prefills.
outputs = engine_core.get_outputs()[:num_active] outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats() iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
total_prompt_tokens = sum( total_prompt_tokens = sum(
...@@ -941,7 +946,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -941,7 +946,7 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_generation_tokens == num_active assert iteration_stats.num_generation_tokens == num_active
# Just decodes in this step. # Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active] outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats() iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
...@@ -951,7 +956,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -951,7 +956,7 @@ def test_iteration_stats(dummy_test_vectors):
# Add a new request - prefill and 2 decodes in this step. # Add a new request - prefill and 2 decodes in this step.
output_processor.add_request(inactive_request, None) output_processor.add_request(inactive_request, None)
num_active += 1 num_active += 1
outputs = engine_core.get_outputs()[:num_active] outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats() iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
...@@ -960,7 +965,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -960,7 +965,7 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_generation_tokens == num_active assert iteration_stats.num_generation_tokens == num_active
# Just decodes in this step. # Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active] outputs = engine_core.get_outputs(num_active)
iteration_stats = IterationStats() iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
...@@ -1003,6 +1008,7 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors): ...@@ -1003,6 +1008,7 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
engine_core = MockEngineCore( engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens, dummy_test_vectors.generation_tokens,
dummy_test_vectors.prompt_tokens,
request_ids=[req.request_id for req in requests], request_ids=[req.request_id for req in requests],
) )
......
...@@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast ...@@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.metrics.stats import PrefillStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
...@@ -330,6 +331,7 @@ class MockEngineCore: ...@@ -330,6 +331,7 @@ class MockEngineCore:
def __init__( def __init__(
self, self,
tokens_list: list[list[int]], tokens_list: list[list[int]],
prompts_list: list[list[int]],
# For each request, for each sampled token offset, # For each request, for each sampled token offset,
# a tuple of # a tuple of
# (list of topk token ids, list of sample logprob vals, rank) # (list of topk token ids, list of sample logprob vals, rank)
...@@ -346,12 +348,13 @@ class MockEngineCore: ...@@ -346,12 +348,13 @@ class MockEngineCore:
) -> None: ) -> None:
self.num_requests = len(tokens_list) self.num_requests = len(tokens_list)
self.tokens_list = tokens_list self.tokens_list = tokens_list
self.current_idx = 0 self.prompts_list = prompts_list
self.generated_logprobs_raw = generated_logprobs_raw self.generated_logprobs_raw = generated_logprobs_raw
self.do_logprobs = generated_logprobs_raw is not None self.do_logprobs = generated_logprobs_raw is not None
self.prompt_logprobs_raw = prompt_logprobs_raw self.prompt_logprobs_raw = prompt_logprobs_raw
self.do_prompt_logprobs = prompt_logprobs_raw is not None self.do_prompt_logprobs = prompt_logprobs_raw is not None
self.request_finished = [False for _ in range(self.num_requests)] self.request_finished = [False for _ in range(self.num_requests)]
self.request_token_idx = [0 for _ in range(self.num_requests)]
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids self.stop_token_ids = stop_token_ids
self.request_ids = ( self.request_ids = (
...@@ -360,14 +363,18 @@ class MockEngineCore: ...@@ -360,14 +363,18 @@ class MockEngineCore:
else [f"request-{i}" for i in range(self.num_requests)] else [f"request-{i}" for i in range(self.num_requests)]
) )
def get_outputs(self) -> list[EngineCoreOutput]: def get_outputs(self, num_active: int = -1) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs do_logprobs = self.do_logprobs
do_prompt_logprobs = self.do_prompt_logprobs do_prompt_logprobs = self.do_prompt_logprobs
token_idx = self.current_idx
outputs = [] outputs = []
for req_idx, token_ids in enumerate(self.tokens_list): for req_idx, (token_ids, prompt_token_ids) in enumerate(
zip(self.tokens_list, self.prompts_list)
):
if num_active != -1 and req_idx >= num_active:
break
if not self.request_finished[req_idx]: if not self.request_finished[req_idx]:
token_idx = self.request_token_idx[req_idx]
if do_logprobs: if do_logprobs:
assert self.generated_logprobs_raw is not None assert self.generated_logprobs_raw is not None
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = ( (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
...@@ -381,19 +388,32 @@ class MockEngineCore: ...@@ -381,19 +388,32 @@ class MockEngineCore:
else: else:
logprobs = None logprobs = None
if do_prompt_logprobs: if do_prompt_logprobs:
if self.current_idx == 0: if token_idx == 0:
assert self.prompt_logprobs_raw is not None assert self.prompt_logprobs_raw is not None
prompt_logprobs = self.prompt_logprobs_raw[req_idx] prompt_logprobs = self.prompt_logprobs_raw[req_idx]
else: else:
prompt_logprobs = None prompt_logprobs = None
else: else:
prompt_logprobs = None prompt_logprobs = None
# Add prefill_stats on first output (prefill) for this request
if token_idx == 0:
prefill_stats = PrefillStats()
prefill_stats.set(
num_prompt_tokens=len(prompt_token_ids),
num_local_cached_tokens=0,
num_external_cached_tokens=0,
)
else:
prefill_stats = None
new_token_id = token_ids[token_idx] new_token_id = token_ids[token_idx]
output = EngineCoreOutput( output = EngineCoreOutput(
request_id=self.request_ids[req_idx], request_id=self.request_ids[req_idx],
new_token_ids=[new_token_id], new_token_ids=[new_token_id],
new_logprobs=logprobs, new_logprobs=logprobs,
new_prompt_logprobs_tensors=prompt_logprobs, new_prompt_logprobs_tensors=prompt_logprobs,
prefill_stats=prefill_stats,
) )
if token_idx == len(token_ids) - 1: if token_idx == len(token_ids) - 1:
output.finish_reason = FinishReason.LENGTH output.finish_reason = FinishReason.LENGTH
...@@ -407,5 +427,6 @@ class MockEngineCore: ...@@ -407,5 +427,6 @@ class MockEngineCore:
self.request_finished[req_idx] = True self.request_finished[req_idx] = True
outputs.append(output) outputs.append(output)
self.current_idx += 1 self.request_token_idx[req_idx] += 1
return outputs return outputs
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats from vllm.v1.metrics.stats import (
IterationStats,
PrefillStats,
PromptTokenStats,
RequestStateStats,
)
def test_iteration_stats_repr(): def test_iteration_stats_repr():
...@@ -114,15 +119,18 @@ def test_prompt_token_stats_all_computed(): ...@@ -114,15 +119,18 @@ def test_prompt_token_stats_all_computed():
stats = PromptTokenStats() stats = PromptTokenStats()
# Case 1: No caching (All tokens computed locally) # Case 1: No caching (All tokens computed locally)
stats.update_from_output( prefill_stats = PrefillStats()
num_cached_tokens=0, prefill_stats.set(
num_external_computed_tokens=0, num_prompt_tokens=1000,
prompt_len=1000, num_local_cached_tokens=0,
num_external_cached_tokens=0,
) )
stats.update_from_output(prefill_stats)
assert stats.computed == 1000 assert stats.computed == 1000
assert stats.local_cache_hit == 0 assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 0 assert stats.external_kv_transfer == 0
assert stats.cached_tokens == 0
assert stats.total == 1000 assert stats.total == 1000
...@@ -131,15 +139,19 @@ def test_prompt_token_stats_partial_local_cache(): ...@@ -131,15 +139,19 @@ def test_prompt_token_stats_partial_local_cache():
stats = PromptTokenStats() stats = PromptTokenStats()
# Case 2: Partial local cache # Case 2: Partial local cache
stats.update_from_output( prefill_stats = PrefillStats()
num_cached_tokens=300, prefill_stats.set(
num_external_computed_tokens=0, num_prompt_tokens=1000,
prompt_len=1000, num_local_cached_tokens=300,
num_external_cached_tokens=0,
) )
stats.update_from_output(prefill_stats)
assert stats.computed == 700 assert stats.computed == 700
assert stats.local_cache_hit == 300 assert stats.local_cache_hit == 300
assert stats.external_kv_transfer == 0 assert stats.external_kv_transfer == 0
assert stats.cached_tokens == 300
assert stats.total == 1000
def test_prompt_token_stats_partial_external_transfer(): def test_prompt_token_stats_partial_external_transfer():
...@@ -147,15 +159,19 @@ def test_prompt_token_stats_partial_external_transfer(): ...@@ -147,15 +159,19 @@ def test_prompt_token_stats_partial_external_transfer():
stats = PromptTokenStats() stats = PromptTokenStats()
# Case 3: Partial external transfer # Case 3: Partial external transfer
stats.update_from_output( prefill_stats = PrefillStats()
num_cached_tokens=500, prefill_stats.set(
num_external_computed_tokens=500, num_prompt_tokens=1000,
prompt_len=1000, num_local_cached_tokens=0,
num_external_cached_tokens=500,
) )
stats.update_from_output(prefill_stats)
assert stats.computed == 500 assert stats.computed == 500
assert stats.local_cache_hit == 0 assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 500 assert stats.external_kv_transfer == 500
assert stats.cached_tokens == 500
assert stats.total == 1000
def test_prompt_token_stats_mixed_sources(): def test_prompt_token_stats_mixed_sources():
...@@ -163,47 +179,60 @@ def test_prompt_token_stats_mixed_sources(): ...@@ -163,47 +179,60 @@ def test_prompt_token_stats_mixed_sources():
stats = PromptTokenStats() stats = PromptTokenStats()
# Case 4: Mixed sources # Case 4: Mixed sources
stats.update_from_output( prefill_stats = PrefillStats()
num_cached_tokens=600, prefill_stats.set(
num_external_computed_tokens=200, num_prompt_tokens=1000,
prompt_len=1000, num_local_cached_tokens=400,
num_external_cached_tokens=200,
) )
stats.update_from_output(prefill_stats)
assert stats.computed == 400 assert stats.computed == 400
assert stats.local_cache_hit == 400 assert stats.local_cache_hit == 400
assert stats.external_kv_transfer == 200 assert stats.external_kv_transfer == 200
assert stats.cached_tokens == 600
assert stats.total == 1000
def test_prompt_token_stats_full_local_cache_recompute(): def test_prompt_token_stats_full_local_cache_recompute():
"""Test full local cache triggers last token recomputation. """Test full local cache triggers last token recomputation.
When all tokens are cached, the scheduler reduces num_cached_tokens by 1 When all tokens are cached, the scheduler forces the model to recompute
to force the model to recompute the last token. the last token (num_computed_tokens=1), with the rest from cache.
""" """
stats = PromptTokenStats() stats = PromptTokenStats()
# Case 5: Full local cache (999 cached after reduction, 1 recomputed) # Case 5: Full local cache (999 cached, 1 recomputed)
stats.update_from_output( prefill_stats = PrefillStats()
num_cached_tokens=999, prefill_stats.set(
num_external_computed_tokens=0, num_prompt_tokens=1000,
prompt_len=1000, num_local_cached_tokens=999,
num_external_cached_tokens=0,
) )
stats.update_from_output(prefill_stats)
assert stats.computed == 1 assert stats.computed == 1
assert stats.local_cache_hit == 999 assert stats.local_cache_hit == 999
assert stats.external_kv_transfer == 0
assert stats.cached_tokens == 999
assert stats.total == 1000
def test_prompt_token_stats_full_external_transfer_recompute(): def test_prompt_token_stats_full_external_transfer_recompute():
"""Test full external transfer triggers last token recomputation.""" """Test full external transfer triggers last token recomputation."""
stats = PromptTokenStats() stats = PromptTokenStats()
# Case 6: Full external transfer (999 cached after reduction, 1 recomputed) # Case 6: Full external transfer (999 from external, 1 recomputed)
stats.update_from_output( prefill_stats = PrefillStats()
num_cached_tokens=999, prefill_stats.set(
num_external_computed_tokens=999, num_prompt_tokens=1000,
prompt_len=1000, num_local_cached_tokens=0,
num_external_cached_tokens=999,
) )
stats.update_from_output(prefill_stats)
assert stats.computed == 1 assert stats.computed == 1
assert stats.local_cache_hit == 0 assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 999 assert stats.external_kv_transfer == 999
assert stats.cached_tokens == 999
assert stats.total == 1000
...@@ -629,7 +629,6 @@ class Scheduler(SchedulerInterface): ...@@ -629,7 +629,6 @@ class Scheduler(SchedulerInterface):
step_skipped_waiting.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue continue
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens
connector_prefix_cache_queries = ( connector_prefix_cache_queries = (
...@@ -642,6 +641,15 @@ class Scheduler(SchedulerInterface): ...@@ -642,6 +641,15 @@ class Scheduler(SchedulerInterface):
num_new_local_computed_tokens + num_external_computed_tokens num_new_local_computed_tokens + num_external_computed_tokens
) )
assert num_computed_tokens <= request.num_tokens assert num_computed_tokens <= request.num_tokens
# Track first scheduled prefill, not post-preemption repeat prefills
if request.prefill_stats is not None:
assert num_computed_tokens <= request.num_prompt_tokens
request.prefill_stats.set(
num_prompt_tokens=request.num_prompt_tokens,
num_local_cached_tokens=num_new_local_computed_tokens,
num_external_cached_tokens=num_external_computed_tokens,
)
else: else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
...@@ -826,9 +834,6 @@ class Scheduler(SchedulerInterface): ...@@ -826,9 +834,6 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
...@@ -1466,10 +1471,9 @@ class Scheduler(SchedulerInterface): ...@@ -1466,10 +1471,9 @@ class Scheduler(SchedulerInterface):
pooling_output=pooler_output, pooling_output=pooler_output,
stop_reason=request.stop_reason, stop_reason=request.stop_reason,
events=request.take_events(), events=request.take_events(),
prefill_stats=request.take_prefill_stats(),
kv_transfer_params=kv_transfer_params, kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts, routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits, num_nans_in_logits=request.num_nans_in_logits,
) )
...@@ -1496,7 +1500,6 @@ class Scheduler(SchedulerInterface): ...@@ -1496,7 +1500,6 @@ class Scheduler(SchedulerInterface):
finish_reason=request.get_finished_reason(), finish_reason=request.get_finished_reason(),
events=request.take_events(), events=request.take_events(),
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
) )
) )
...@@ -2071,10 +2074,6 @@ class Scheduler(SchedulerInterface): ...@@ -2071,10 +2074,6 @@ class Scheduler(SchedulerInterface):
if request.num_computed_tokens == request.num_tokens: if request.num_computed_tokens == request.num_tokens:
request.num_computed_tokens = request.num_tokens - 1 request.num_computed_tokens = request.num_tokens - 1
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
def _try_promote_blocked_waiting_request(self, request: Request) -> bool: def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
...@@ -2221,7 +2220,7 @@ class Scheduler(SchedulerInterface): ...@@ -2221,7 +2220,7 @@ class Scheduler(SchedulerInterface):
req_num_computed_tokens - request.num_computed_tokens req_num_computed_tokens - request.num_computed_tokens
) )
total_affected_tokens += num_affected_tokens total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
# collect invalid block and all downstream dependent blocks # collect invalid block and all downstream dependent blocks
if evict_blocks: if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:]) blocks_to_evict.update(req_block_ids[idx:])
......
...@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest ...@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import PrefillStats, SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult from vllm.v1.serial_utils import UtilityResult
...@@ -171,10 +171,9 @@ class EngineCoreOutput( ...@@ -171,10 +171,9 @@ class EngineCoreOutput(
kv_transfer_params: dict[str, Any] | None = None kv_transfer_params: dict[str, Any] | None = None
trace_headers: Mapping[str, str] | None = None trace_headers: Mapping[str, str] | None = None
# The number of tokens with prefix cache hits (local + external).
num_cached_tokens: int = 0 prefill_stats: PrefillStats | None = None
# The number of tokens computed remotely (original count from connector).
num_external_computed_tokens: int = 0
routed_experts: np.ndarray | None = None routed_experts: np.ndarray | None = None
# The number of NaNs in logits. # The number of NaNs in logits.
# A value greater than 0 indicates that the output is corrupted. # A value greater than 0 indicates that the output is corrupted.
......
...@@ -617,8 +617,13 @@ class OutputProcessor: ...@@ -617,8 +617,13 @@ class OutputProcessor:
stop_reason = engine_core_output.stop_reason stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params kv_transfer_params = engine_core_output.kv_transfer_params
routed_experts = engine_core_output.routed_experts routed_experts = engine_core_output.routed_experts
req_state.num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False if req_state.is_prefilling:
if engine_core_output.prefill_stats is not None:
req_state.num_cached_tokens = (
engine_core_output.prefill_stats.num_cached_tokens
)
req_state.is_prefilling = False
if pooling_output is None: if pooling_output is None:
assert req_state.detokenizer is not None assert req_state.detokenizer is not None
...@@ -776,7 +781,6 @@ class OutputProcessor: ...@@ -776,7 +781,6 @@ class OutputProcessor:
engine_core_output, engine_core_output,
engine_core_timestamp, engine_core_timestamp,
req_state.is_prefilling, req_state.is_prefilling,
req_state.prompt_len,
req_state.stats, req_state.stats,
self.lora_states, self.lora_states,
req_state.lora_name, req_state.lora_name,
......
...@@ -239,6 +239,40 @@ class FinishedRequestStats: ...@@ -239,6 +239,40 @@ class FinishedRequestStats:
num_cached_tokens: int = 0 num_cached_tokens: int = 0
@dataclass
class PrefillStats:
    """Breakdown of a scheduled prefill computation.

    Fields:
        num_prompt_tokens: Total number of tokens to be prefilled.
        num_computed_tokens: Tokens to be prefilled locally (actual compute work).
        num_cached_tokens: Tokens to be prefilled without actual compute work.
        num_local_cached_tokens: Tokens to be prefilled from local prefix cache.
        num_external_cached_tokens: Tokens to be prefilled from external KV transfer.
    """

    num_prompt_tokens: int = 0
    num_computed_tokens: int = 0
    num_cached_tokens: int = 0
    num_local_cached_tokens: int = 0
    num_external_cached_tokens: int = 0

    def set(
        self,
        num_prompt_tokens: int,
        num_local_cached_tokens: int,
        num_external_cached_tokens: int,
    ) -> None:
        """Record the token breakdown and derive the dependent fields.

        ``num_cached_tokens`` is stored as the sum of the local and external
        cached counts, and ``num_computed_tokens`` as the remainder of the
        prompt.

        Args:
            num_prompt_tokens: Total number of prompt tokens in the prefill.
            num_local_cached_tokens: Tokens served from the local prefix cache.
            num_external_cached_tokens: Tokens served by external KV transfer.

        Raises:
            AssertionError: If either cached count is negative, or if their
                sum exceeds ``num_prompt_tokens``. No field is modified in
                that case.
        """
        # These values are later added to cumulative counters, so a negative
        # component must be rejected here even when the total still fits
        # within the prompt length.
        assert num_local_cached_tokens >= 0, (
            f"num_local_cached_tokens must be >= 0, got {num_local_cached_tokens}"
        )
        assert num_external_cached_tokens >= 0, (
            "num_external_cached_tokens must be >= 0, "
            f"got {num_external_cached_tokens}"
        )

        num_cached_tokens = num_local_cached_tokens + num_external_cached_tokens
        assert num_cached_tokens <= num_prompt_tokens

        # All invariants hold; only now mutate the instance.
        self.num_prompt_tokens = num_prompt_tokens
        self.num_computed_tokens = num_prompt_tokens - num_cached_tokens
        self.num_cached_tokens = num_cached_tokens
        self.num_local_cached_tokens = num_local_cached_tokens
        self.num_external_cached_tokens = num_external_cached_tokens
@dataclass @dataclass
class PromptTokenStats: class PromptTokenStats:
"""Breakdown of prompt tokens by source. """Breakdown of prompt tokens by source.
...@@ -267,28 +301,14 @@ class PromptTokenStats: ...@@ -267,28 +301,14 @@ class PromptTokenStats:
cached_tokens: int = 0 cached_tokens: int = 0
total: int = 0 total: int = 0
def update_from_output( def update_from_output(self, prefill_stats: PrefillStats) -> None:
self,
num_cached_tokens: int,
num_external_computed_tokens: int,
prompt_len: int,
) -> None:
"""Update stats from a prefill output.""" """Update stats from a prefill output."""
self.computed += prompt_len - num_cached_tokens self.computed += prefill_stats.num_computed_tokens
self.external_kv_transfer += num_external_computed_tokens self.cached_tokens += prefill_stats.num_cached_tokens
# FIXME(yifan): local_cache_hit can go negative after preemption. self.total += prefill_stats.num_prompt_tokens
# num_cached_tokens is a one-time snapshot from first scheduling and
# is never reset on preemption, while num_external_computed_tokens is self.local_cache_hit += prefill_stats.num_local_cached_tokens
# overwritten on re-scheduling. If CPU offload finds more tokens on self.external_kv_transfer += prefill_stats.num_external_cached_tokens
# the second pass than the original total, the subtraction underflows.
# A fundamental fix is to track the first-time num_external_computed_tokens
# as a separate metric rather than reusing num_external_computed_tokens
# for metric directly.
self.local_cache_hit += max(
0, (num_cached_tokens - num_external_computed_tokens)
)
self.cached_tokens += num_cached_tokens
self.total += prompt_len
def get_by_source(self, source: str) -> int: def get_by_source(self, source: str) -> int:
"""Get token count by source label.""" """Get token count by source label."""
...@@ -335,7 +355,6 @@ class IterationStats: ...@@ -335,7 +355,6 @@ class IterationStats:
output: "EngineCoreOutput", output: "EngineCoreOutput",
engine_core_timestamp: float, engine_core_timestamp: float,
is_prefilling: bool, is_prefilling: bool,
prompt_len: int,
req_stats: RequestStateStats, req_stats: RequestStateStats,
lora_states: "LoRARequestStates", lora_states: "LoRARequestStates",
lora_name: str | None, lora_name: str | None,
...@@ -344,11 +363,8 @@ class IterationStats: ...@@ -344,11 +363,8 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens self.num_generation_tokens += num_new_generation_tokens
if is_prefilling: if is_prefilling:
self.prompt_token_stats.update_from_output( if output.prefill_stats is not None:
num_cached_tokens=output.num_cached_tokens, self.prompt_token_stats.update_from_output(output.prefill_stats)
num_external_computed_tokens=output.num_external_computed_tokens,
prompt_len=prompt_len,
)
first_token_latency = self._time_since(req_stats.arrival_time) first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency) self.time_to_first_tokens_iter.append(first_token_latency)
......
...@@ -20,6 +20,7 @@ from vllm.v1.engine import ( ...@@ -20,6 +20,7 @@ from vllm.v1.engine import (
EngineCoreRequest, EngineCoreRequest,
FinishReason, FinishReason,
) )
from vllm.v1.metrics.stats import PrefillStats
from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
...@@ -145,9 +146,6 @@ class Request: ...@@ -145,9 +146,6 @@ class Request:
self.all_token_ids = ConstantList(self._all_token_ids) self.all_token_ids = ConstantList(self._all_token_ids)
# trace_headers # trace_headers
self.trace_headers = trace_headers self.trace_headers = trace_headers
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
# True if this request is scheduled as a non-final prefill chunk. # True if this request is scheduled as a non-final prefill chunk.
self.is_prefill_chunk = False self.is_prefill_chunk = False
...@@ -159,8 +157,7 @@ class Request: ...@@ -159,8 +157,7 @@ class Request:
# The number of times this request has been preempted by the scheduler. # The number of times this request has been preempted by the scheduler.
self.num_preemptions = 0 self.num_preemptions = 0
# The number of tokens that have been computed remotely. self.prefill_stats: PrefillStats | None = PrefillStats()
self.num_external_computed_tokens = 0
self.block_hashes: list[BlockHash] = [] self.block_hashes: list[BlockHash] = []
# Store the block hasher without binding self to avoid creating a # Store the block hasher without binding self to avoid creating a
...@@ -278,6 +275,13 @@ class Request: ...@@ -278,6 +275,13 @@ class Request:
events, self.events = self.events, [] events, self.events = self.events, []
return events return events
def take_prefill_stats(self) -> PrefillStats | None:
    """Return the pending prefill stats and clear them from the request.

    The stored value is handed over at most once: after this call
    ``self.prefill_stats`` is ``None``, so later calls return ``None``.
    """
    # Same swap idiom as take_events(): read and reset in one statement.
    prefill_stats, self.prefill_stats = self.prefill_stats, None
    return prefill_stats
def __lt__(self, other: "Request") -> bool: def __lt__(self, other: "Request") -> bool:
""" """
Compare two requests based on priority, arrival time, and request ID. Compare two requests based on priority, arrival time, and request ID.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment