Unverified Commit 41c5dd45 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[V1][Metrics] Add GPU prefix cache hit rate % gauge (#12592)

parent fc6485d2
...@@ -203,6 +203,8 @@ EXPECTED_METRICS_V1 = [ ...@@ -203,6 +203,8 @@ EXPECTED_METRICS_V1 = [
"vllm:num_requests_running", "vllm:num_requests_running",
"vllm:num_requests_waiting", "vllm:num_requests_waiting",
"vllm:gpu_cache_usage_perc", "vllm:gpu_cache_usage_perc",
"vllm:gpu_prefix_cache_queries",
"vllm:gpu_prefix_cache_hits",
"vllm:prompt_tokens_total", "vllm:prompt_tokens_total",
"vllm:generation_tokens_total", "vllm:generation_tokens_total",
"vllm:request_success_total", "vllm:request_success_total",
......
...@@ -5,10 +5,11 @@ import pytest ...@@ -5,10 +5,11 @@ import pytest
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, KVCacheBlock, PrefixCachingMetrics,
generate_block_hash_extra_keys, generate_block_hash_extra_keys,
hash_block_tokens, hash_block_tokens,
hash_request_tokens) hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs(): ...@@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs():
assert block_hashes[0].extra_keys is None assert block_hashes[0].extra_keys is None
assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys is None assert block_hashes[1].extra_keys is None
def test_metrics():
"""
Test the prefix caching metrics.
"""
def stats(requests, queries, hits):
return PrefixCacheStats(requests=requests, queries=queries, hits=hits)
metrics = PrefixCachingMetrics(interval=5)
assert metrics.hit_rate == 0.0
metrics.observe(stats(1, 20, 9))
# 9 / 20 = 0.45
assert metrics.hit_rate == 0.45
metrics.observe(stats(4, 80, 16))
# 25 / 100 = 0.25
assert metrics.hit_rate == 0.25
metrics.observe(stats(1, 10, 2))
# Remove (20, 9) and add (10, 2): 18 / 90 = 0.2
assert metrics.aggregated_requests == 5
assert metrics.aggregated_query_total == 90
assert metrics.aggregated_query_hit == 18
assert metrics.hit_rate == 0.2
metrics.reset()
assert metrics.hit_rate == 0.0
assert metrics.aggregated_requests == 0
assert metrics.aggregated_query_total == 0
assert metrics.aggregated_query_hit == 0
assert not metrics.query_queue
...@@ -10,6 +10,7 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, ...@@ -10,6 +10,7 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
generate_block_hash_extra_keys, generate_block_hash_extra_keys,
hash_block_tokens, hash_block_tokens,
hash_request_tokens) hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -78,11 +79,28 @@ class KVCacheManager: ...@@ -78,11 +79,28 @@ class KVCacheManager:
self.req_to_block_hashes: DefaultDict[ self.req_to_block_hashes: DefaultDict[
str, List[BlockHashType]] = defaultdict(list) str, List[BlockHashType]] = defaultdict(list)
self.prefix_cache_stats = PrefixCacheStats()
@property @property
def usage(self) -> float: def usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.free_block_queue.num_free_blocks / return 1.0 - (self.free_block_queue.num_free_blocks /
self.num_gpu_blocks) self.num_gpu_blocks)
def make_prefix_cache_stats(self) -> PrefixCacheStats:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats.
"""
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats
def get_computed_blocks( def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]: self, request: Request) -> Tuple[List[KVCacheBlock], int]:
"""Get the computed (cached) blocks for the request. """Get the computed (cached) blocks for the request.
...@@ -118,6 +136,10 @@ class KVCacheManager: ...@@ -118,6 +136,10 @@ class KVCacheManager:
else: else:
break break
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
# NOTE(woosuk): Since incomplete blocks are not eligible for # NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of # sharing, `num_computed_tokens` is always a multiple of
# `block_size`. # `block_size`.
...@@ -280,6 +302,8 @@ class KVCacheManager: ...@@ -280,6 +302,8 @@ class KVCacheManager:
for block in self.block_pool: for block in self.block_pool:
block.reset_hash() block.reset_hash()
self.prefix_cache_stats.reset = True
logger.info("Successfully reset prefix cache") logger.info("Successfully reset prefix cache")
return True return True
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
...@@ -8,6 +9,7 @@ from vllm.config import VllmConfig ...@@ -8,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor) KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,6 +30,68 @@ class BlockHashType(NamedTuple): ...@@ -28,6 +30,68 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None extra_keys: Optional[Any] = None
class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the most recent N requests.
Args:
interval: The number of the most recent requests to aggregate.
Defaults to 1000.
"""
def __init__(self, interval: int = 1000):
self.interval = interval
# The current aggregated values.
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
# A deque of (requests, queries, hits) for the most recent requests.
self.query_queue: deque[Tuple[int, int, int]] = deque()
def observe(self, stats: PrefixCacheStats):
"""Observe the prefix caching for a set of requests.
This function is called with information gathered when new requests
are being scheduled and are looking for computed blocks.
When there are more than `interval` requests, the oldest set of
requestsare removed from the metrics.
Args:
stats: The prefix cache stats.
"""
# reset_prefix_cache was invoked before the current update.
# Reset the metrics before aggregating the current stats.
if stats.reset:
self.reset()
# Update the metrics.
self.query_queue.append((stats.requests, stats.queries, stats.hits))
self.aggregated_requests += stats.requests
self.aggregated_query_total += stats.queries
self.aggregated_query_hit += stats.hits
# Remove the oldest stats if the number of requests exceeds.
if self.aggregated_requests > self.interval:
old_requests, old_queries, old_hits = self.query_queue.popleft()
self.aggregated_requests -= old_requests
self.aggregated_query_total -= old_queries
self.aggregated_query_hit -= old_hits
def reset(self):
"""Reset the metrics."""
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
self.query_queue.clear()
@property
def hit_rate(self) -> float:
"""Calculate the hit rate for the past N requests."""
if self.aggregated_query_total == 0:
return 0.0
return self.aggregated_query_hit / self.aggregated_query_total
@dataclass @dataclass
class KVCacheBlock: class KVCacheBlock:
"""KV-cache block metadata.""" """KV-cache block metadata."""
......
...@@ -593,4 +593,5 @@ class Scheduler: ...@@ -593,4 +593,5 @@ class Scheduler:
num_running_reqs=len(self.running), num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage, gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
) )
...@@ -9,6 +9,7 @@ import prometheus_client ...@@ -9,6 +9,7 @@ import prometheus_client
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
...@@ -37,6 +38,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -37,6 +38,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens: List[int] = [] self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = [] self.num_generation_tokens: List[int] = []
# Prefix cache metrics. TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
def _local_interval_elapsed(self, now: float) -> bool: def _local_interval_elapsed(self, now: float) -> bool:
# Log every _LOCAL_LOGGING_INTERVAL_SEC. # Log every _LOCAL_LOGGING_INTERVAL_SEC.
elapsed_time = now - self.last_log_time elapsed_time = now - self.last_log_time
...@@ -58,6 +62,8 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -58,6 +62,8 @@ class LoggingStatLogger(StatLoggerBase):
self._track_iteration_stats(iteration_stats) self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
now = time.monotonic() now = time.monotonic()
if not self._local_interval_elapsed(now): if not self._local_interval_elapsed(now):
return return
...@@ -72,13 +78,15 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -72,13 +78,15 @@ class LoggingStatLogger(StatLoggerBase):
logger.info( logger.info(
"Avg prompt throughput: %.1f tokens/s, " "Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs " "Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%.", "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
scheduler_stats.num_running_reqs, scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs, scheduler_stats.num_waiting_reqs,
scheduler_stats.gpu_cache_usage * 100, scheduler_stats.gpu_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
) )
...@@ -107,6 +115,18 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -107,6 +115,18 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_queries",
documentation=
"GPU prefix cache queries, in terms of number of queried blocks.",
labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_hits",
documentation=
"GPU prefix cache hits, in terms of number of cached blocks.",
labelnames=labelnames).labels(*labelvalues)
self.counter_prompt_tokens = prometheus_client.Counter( self.counter_prompt_tokens = prometheus_client.Counter(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
...@@ -170,6 +190,11 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -170,6 +190,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc( self.counter_generation_tokens.inc(
iteration_stats.num_generation_tokens) iteration_stats.num_generation_tokens)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -9,6 +9,20 @@ if TYPE_CHECKING: ...@@ -9,6 +9,20 @@ if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreOutput, FinishReason
@dataclass
class PrefixCacheStats:
"""Stores prefix cache hit statistics."""
# Whether reset_prefix_cache was invoked.
reset: bool = False
# The number of requests in this update.
requests: int = 0
# The number of queries in these requests. Note that "queries" here
# means the number of blocks that were queried from the cache.
queries: int = 0
# The number of hits in these requests.
hits: int = 0
@dataclass @dataclass
class SchedulerStats: class SchedulerStats:
"""Stats associated with the scheduler.""" """Stats associated with the scheduler."""
...@@ -17,7 +31,9 @@ class SchedulerStats: ...@@ -17,7 +31,9 @@ class SchedulerStats:
num_waiting_reqs: int = 0 num_waiting_reqs: int = 0
gpu_cache_usage: float = 0.0 gpu_cache_usage: float = 0.0
# gpu_prefix_cache_hit_rate: float = 0.0
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment