"docs/vscode:/vscode.git/clone" did not exist on "42bb201fd6f79d6ed2e28e0263ffa891cd993c4c"
Unverified Commit 2abd9759, authored by Mark McLoughlin, committed by GitHub
Browse files

[KV Connector][Metrics] Do not count local prefix cache hits in connector queries (#30522)


Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent 6abb0454
...@@ -1136,7 +1136,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): ...@@ -1136,7 +1136,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[], pooler_output=[],
) )
scheduler.update_from_output(output, EMPTY_OUTPUT) initial_ecos = scheduler.update_from_output(output, EMPTY_OUTPUT)
# Simulate KV transfer completion using KVConnectorOutput.finished_recving # Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule() output = scheduler.schedule()
...@@ -1156,6 +1156,8 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): ...@@ -1156,6 +1156,8 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
for req_id in req_ids: for req_id in req_ids:
assert req_id in scheduler.finished_recving_kv_req_ids assert req_id in scheduler.finished_recving_kv_req_ids
return initial_ecos
@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.parametrize("is_async", [False, True])
def test_kv_connector_basic(is_async: bool): def test_kv_connector_basic(is_async: bool):
...@@ -1286,29 +1288,72 @@ def test_kv_connector_basic(is_async: bool): ...@@ -1286,29 +1288,72 @@ def test_kv_connector_basic(is_async: bool):
@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.parametrize("is_async", [False, True])
def test_external_prefix_cache_metrics(is_async: bool): @pytest.mark.parametrize("local_cache_hits", [False, True])
def test_external_prefix_cache_metrics(is_async: bool, local_cache_hits: bool):
""" """
Verify connector prefix cache metrics are updated Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits. correctly when the scheduler processes requests with KV connector hits.
""" """
# Setup Scheduler. BLOCK_SIZE = 16
if local_cache_hits:
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 # 32 tokens
NUM_LOCAL_HITS = NUM_MATCHED_NEW_TOKENS * 2 # 64 tokens
NUM_REQUESTS = 1
NUM_TOKENS = NUM_LOCAL_HITS * 2 # 128 tokens
else:
NUM_MATCHED_NEW_TOKENS = 4 NUM_MATCHED_NEW_TOKENS = 4
NUM_LOCAL_HITS = 0
NUM_REQUESTS = 2
NUM_TOKENS = 8 # 8 tokens
# Setup Scheduler.
scheduler = create_scheduler( scheduler = create_scheduler(
enable_prefix_caching=False, enable_prefix_caching=local_cache_hits,
use_kv_connector=mock_kv( use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
), ),
block_size=BLOCK_SIZE,
) )
# --- Prepare simple requests --- if local_cache_hits:
NUM_REQUESTS = 2 # First, establish local cache by running a request to completion
NUM_TOKENS = 8 requests = create_requests(
num_requests=1,
num_tokens=NUM_LOCAL_HITS,
max_tokens=2,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
# Run first request to completion to establish local cache
output = scheduler.schedule()
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# --- Prepare test requests ---
MAX_TOKENS = 2 MAX_TOKENS = 2
requests = create_requests( requests = create_requests(
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
) )
req_ids = [] req_ids = []
req_to_index = {} req_to_index = {}
...@@ -1317,8 +1362,9 @@ def test_external_prefix_cache_metrics(is_async: bool): ...@@ -1317,8 +1362,9 @@ def test_external_prefix_cache_metrics(is_async: bool):
req_ids.append(request.request_id) req_ids.append(request.request_id)
req_to_index[request.request_id] = i req_to_index[request.request_id] = i
initial_ecos = None
if is_async: if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids) initial_ecos = _step_until_kv_transfer_finished(scheduler, req_ids)
# --- Trigger scheduling and simulate model output --- # --- Trigger scheduling and simulate model output ---
output = scheduler.schedule() output = scheduler.schedule()
...@@ -1338,10 +1384,23 @@ def test_external_prefix_cache_metrics(is_async: bool): ...@@ -1338,10 +1384,23 @@ def test_external_prefix_cache_metrics(is_async: bool):
assert ecos is not None and len(ecos) > 0 assert ecos is not None and len(ecos) > 0
assert ecos[0].scheduler_stats is not None assert ecos[0].scheduler_stats is not None
if local_cache_hits:
# For async, local cache stats come from the first step
if initial_ecos:
local_stats = initial_ecos[0].scheduler_stats.prefix_cache_stats
else:
local_stats = ecos[0].scheduler_stats.prefix_cache_stats
assert local_stats is not None
assert local_stats.queries == NUM_TOKENS * NUM_REQUESTS
assert local_stats.hits == NUM_LOCAL_HITS * NUM_REQUESTS
if initial_ecos:
external_stats = initial_ecos[0].scheduler_stats.connector_prefix_cache_stats
else:
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
assert external_stats is not None assert external_stats is not None
assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS assert external_stats.queries == (NUM_TOKENS - NUM_LOCAL_HITS) * NUM_REQUESTS
assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
assert external_stats.requests == NUM_REQUESTS assert external_stats.requests == NUM_REQUESTS
assert external_stats.preempted_requests == 0 assert external_stats.preempted_requests == 0
......
...@@ -281,6 +281,17 @@ def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler): ...@@ -281,6 +281,17 @@ def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
f"(hash should be None), but hash is still {block.block_hash}" f"(hash should be None), but hash is still {block.block_hash}"
) )
# Verify connector prefix cache stats:
# - queries = num_prompt_tokens (total tokens not in local cache)
# - hits = num_external_computed_tokens (tokens loaded externally)
assert engine_outputs.scheduler_stats is not None
stats = engine_outputs.scheduler_stats
assert stats.connector_prefix_cache_stats is not None
conn_stats = stats.connector_prefix_cache_stats
assert conn_stats.requests == 1
assert conn_stats.queries == num_prompt_tokens
assert conn_stats.hits == num_external_computed_tokens
def test_async_recompute_blocks_not_cached_when_invalid( def test_async_recompute_blocks_not_cached_when_invalid(
recompute_scheduler: Scheduler, recompute_scheduler: Scheduler,
...@@ -364,7 +375,9 @@ def test_async_recompute_blocks_not_cached_when_invalid( ...@@ -364,7 +375,9 @@ def test_async_recompute_blocks_not_cached_when_invalid(
with patch.object( with patch.object(
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
): ):
recompute_scheduler.update_from_output(scheduler_output, model_runner_output) outputs = recompute_scheduler.update_from_output(
scheduler_output, model_runner_output
)
# verify evict_blocks was NOT called (async blocks excluded from eviction) # verify evict_blocks was NOT called (async blocks excluded from eviction)
assert len(evict_blocks_calls) == 0, ( assert len(evict_blocks_calls) == 0, (
...@@ -386,6 +399,19 @@ def test_async_recompute_blocks_not_cached_when_invalid( ...@@ -386,6 +399,19 @@ def test_async_recompute_blocks_not_cached_when_invalid(
f"Block {invalid_block_id} hash should be None but is {block.block_hash}" f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
) )
# Verify connector prefix cache stats:
# - queries = num_prompt_tokens (total tokens not in local cache)
# - hits = num_external_computed_tokens (tokens loaded externally)
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert engine_outputs.scheduler_stats is not None
stats = engine_outputs.scheduler_stats
assert stats.connector_prefix_cache_stats is not None
conn_stats = stats.connector_prefix_cache_stats
assert conn_stats.requests == 1
assert conn_stats.queries == num_prompt_tokens
assert conn_stats.hits == num_external_computed_tokens
# now simulate async transfer completing # now simulate async transfer completing
model_runner_output_2 = create_model_runner_output( model_runner_output_2 = create_model_runner_output(
reqs=[], reqs=[],
......
...@@ -586,6 +586,7 @@ class Scheduler(SchedulerInterface): ...@@ -586,6 +586,7 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens = 0 num_external_computed_tokens = 0
load_kv_async = False load_kv_async = False
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
# Get already-cached tokens. # Get already-cached tokens.
if request.num_computed_tokens == 0: if request.num_computed_tokens == 0:
...@@ -613,6 +614,11 @@ class Scheduler(SchedulerInterface): ...@@ -613,6 +614,11 @@ class Scheduler(SchedulerInterface):
request.num_external_computed_tokens = ext_tokens request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens
connector_prefix_cache_queries = (
request.num_tokens - num_new_local_computed_tokens
)
connector_prefix_cache_hits = num_external_computed_tokens
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = ( num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens num_new_local_computed_tokens + num_external_computed_tokens
...@@ -728,6 +734,15 @@ class Scheduler(SchedulerInterface): ...@@ -728,6 +734,15 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.get_blocks(request_id), self.kv_cache_manager.get_blocks(request_id),
num_external_computed_tokens, num_external_computed_tokens,
) )
if (
self.connector_prefix_cache_stats is not None
and connector_prefix_cache_queries != 0
):
self.connector_prefix_cache_stats.record(
num_tokens=connector_prefix_cache_queries,
num_hits=connector_prefix_cache_hits,
preempted=request.num_preemptions > 0,
)
# Request was already popped from self.waiting # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None. # unless it was re-added above due to new_blocks being None.
...@@ -739,8 +754,6 @@ class Scheduler(SchedulerInterface): ...@@ -739,8 +754,6 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue continue
self._update_connector_prefix_cache_stats(request)
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
request.record_event( request.record_event(
...@@ -1805,7 +1818,10 @@ class Scheduler(SchedulerInterface): ...@@ -1805,7 +1818,10 @@ class Scheduler(SchedulerInterface):
return None return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None assert prefix_cache_stats is not None
connector_prefix_cache_stats = self._make_connector_prefix_cache_stats() connector_prefix_cache_stats: PrefixCacheStats | None = None
if self.connector_prefix_cache_stats is not None:
connector_prefix_cache_stats = self.connector_prefix_cache_stats
self.connector_prefix_cache_stats = PrefixCacheStats()
eviction_events = ( eviction_events = (
self.kv_metrics_collector.drain_events() self.kv_metrics_collector.drain_events()
if self.kv_metrics_collector is not None if self.kv_metrics_collector is not None
...@@ -1866,23 +1882,6 @@ class Scheduler(SchedulerInterface): ...@@ -1866,23 +1882,6 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods # KV Connector Related Methods
######################################################################## ########################################################################
def _update_connector_prefix_cache_stats(self, request: Request) -> None:
if self.connector_prefix_cache_stats is None:
return
self.connector_prefix_cache_stats.record(
num_tokens=request.num_tokens,
num_hits=request.num_external_computed_tokens,
preempted=request.num_preemptions > 0,
)
def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None:
if self.connector_prefix_cache_stats is None:
return None
stats = self.connector_prefix_cache_stats
self.connector_prefix_cache_stats = PrefixCacheStats()
return stats
def get_kv_connector(self) -> KVConnectorBase_V1 | None: def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector return self.connector
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment