Unverified Commit 3ac50b47 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[MISC] Add prefix cache hit rate to metrics (#7606)

parent df845b2b
......@@ -682,6 +682,32 @@ class TestPrefixCachingBlockAllocator:
assert new_block[0].block_id == last_block_id
# Test case for cache metrics
@staticmethod
def test_metric():
    """Check the reported prefix cache hit rate as identical blocks are
    allocated repeatedly.

    The same full block of token ids is requested every time; the asserted
    rates are 0/0 -> 0.0, 0/1 -> 0.0, 1/2 -> 0.5 and, after many more
    requests, above 0.99.
    """
    tokens_per_block = 16
    full_block_tokens = list(range(tokens_per_block))
    block_allocator = PrefixCachingBlockAllocator(
        num_blocks=4, block_size=tokens_per_block)

    def request_block():
        # Every request carries identical content and no predecessor block.
        block_allocator.allocate_immutable_block(
            prev_block=None, token_ids=full_block_tokens)

    # Nothing queried yet (0/0).
    assert block_allocator.get_prefix_cache_hit_rate() == 0.0

    # First request: 0 hits out of 1 query.
    request_block()
    assert block_allocator.get_prefix_cache_hit_rate() == 0.0

    # Second request: 1 hit out of 2 queries.
    request_block()
    assert block_allocator.get_prefix_cache_hit_rate() == 0.5

    # Push the query count past one metric block (default 1000 queries).
    for _query_index in range(2, 1005):
        request_block()
    assert block_allocator.get_prefix_cache_hit_rate() > 0.99
@staticmethod
def create_immutable_chain(
block_size: int,
......
......@@ -34,6 +34,9 @@ def test_block_allocator(
assert (first_block == second_block)
assert (second_block.ref_count == 2)
# Check metric: 1 hit of 2 queries
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
# Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block)
......@@ -48,6 +51,10 @@ def test_block_allocator(
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)
# Allocate one more time to get 3/4 hit rate for easy checking
block_allocator.allocate(block_hash, 0)
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
......
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator
......@@ -282,6 +283,58 @@ class BlockList:
return self._block_ids
@dataclass
class CacheMetricData:
    """Running cache hit-rate statistics kept at block granularity.

    To avoid overflow, queries are grouped into blocks of ``block_size``
    queries. Only one averaged hit rate is stored for all completed blocks,
    next to the raw counters of the block still being filled. The real-time
    hit rate is derived as follows:

    BS = number of queries per block (``block_size``).
    nB = number of completed blocks.
    HR = hit rate over the (nB x BS) completed queries.
    Q = number of queries in the current block (< BS).
    H = number of hits in the current block (< BS).
    hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
    """
    num_completed_blocks: int = 0
    completed_block_cache_hit_rate: float = 0.0
    num_incompleted_block_queries: int = 0
    num_incompleted_block_hit: int = 0
    block_size: int = 1000

    def query(self, hit: bool):
        """Record one cache lookup.

        Args:
            hit: Truthy when the lookup was a cache hit.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        # Nothing more to do until the current block is full.
        if self.num_incompleted_block_queries != self.block_size:
            return

        # Fold the finished block into the running average and start a
        # fresh block.
        finished_blocks = self.num_completed_blocks
        block_hit_rate = (self.num_incompleted_block_hit /
                          self.num_incompleted_block_queries)
        weighted_sum = (
            self.completed_block_cache_hit_rate * finished_blocks +
            block_hit_rate)
        self.completed_block_cache_hit_rate = weighted_sum / (
            finished_blocks + 1)
        self.num_completed_blocks = finished_blocks + 1
        self.num_incompleted_block_queries = 0
        self.num_incompleted_block_hit = 0

    def get_hit_rate(self):
        """Return the hit rate over all recorded queries.

        Returns:
            0.0 when no query has been recorded, otherwise the average of
            the completed-block hit rate and the current block's hit rate,
            each weighted by the number of blocks it covers.
        """
        pending_queries = self.num_incompleted_block_queries
        finished_blocks = self.num_completed_blocks

        # Fraction of a block covered by the queries not yet folded in.
        partial_block = pending_queries / self.block_size
        block_count = finished_blocks + partial_block
        if block_count == 0:
            return 0.0

        if finished_blocks > 0:
            finished_part = (self.completed_block_cache_hit_rate *
                             finished_blocks)
        else:
            finished_part = 0.0

        if pending_queries > 0:
            pending_hit_rate = self.num_incompleted_block_hit / pending_queries
            pending_part = pending_hit_rate * partial_block
        else:
            pending_part = 0.0

        return (finished_part + pending_part) / block_count
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.
......
......@@ -323,6 +323,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for `device`.

    -1 means not supported or disabled.
    """
    assert device in self._allocators
    # Delegate to the per-device allocator registered for this device.
    device_allocator = self._allocators[device]
    return device_allocator.get_prefix_cache_hit_rate()
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
......
......@@ -186,6 +186,11 @@ class BlockAllocator(ABC):
num_lookahead_slots: int = 0) -> int:
pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Report the prefix cache hit rate of this allocator.

    A value of -1 means not supported or disabled.
    """
    ...
class NoFreeBlocksError(ValueError):
    """ValueError subclass with no extra behavior.

    NOTE(review): presumably raised when an allocator has no free blocks
    left; the raise sites are not visible here, so confirm against callers.
    """
......@@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
There is at most one null block per allocator.
"""
pass
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Report the prefix cache hit rate for the given device.

    A value of -1 means not supported or disabled.
    """
    ...
......@@ -341,6 +341,9 @@ class NaiveBlockAllocator(BlockAllocator):
block.block_id = block_id # Assign block_id
def get_prefix_cache_hit_rate(self) -> float:
    """Always return -1, the "not supported or disabled" value of the
    allocator interface."""
    not_supported = -1
    return not_supported
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
......
"""Token blocks."""
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker,
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
......@@ -107,6 +106,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())
self.metric_data = CacheMetricData()
# Implements Block.Factory.
def _create_block(
self,
......@@ -155,9 +156,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id
self._incr_refcount_cached_block(block)
return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block)
# No cached block => Allocate a new block
......@@ -404,6 +407,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate computed by this allocator's `metric_data`."""
    metric_tracker = self.metric_data
    return metric_tracker.get_hit_rate()
def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
if block.content_hash in self._cached_blocks:
......
......@@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
......@@ -60,6 +61,11 @@ class BlockAllocatorBase(ABC):
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Report the prefix cache hit rate of this allocator.

    A value of -1 means not supported or disabled.
    """
    ...
class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
......@@ -85,6 +91,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
self.default_hash_ctr = count()
self.cache_metric_data = CacheMetricData()
def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
......@@ -105,15 +113,17 @@ class CachedBlockAllocator(BlockAllocatorBase):
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:
if block_hash in self.cached_blocks:
self.cache_metric_data.query(hit=True)
else:
self.cache_metric_data.query(hit=False)
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
......@@ -150,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate computed by this allocator's
    `cache_metric_data`."""
    metric_tracker = self.cache_metric_data
    return metric_tracker.get_hit_rate()
class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
......@@ -209,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
def get_prefix_cache_hit_rate(self) -> float:
    """Always return -1, the "not supported or disabled" value of the
    allocator interface."""
    not_supported = -1
    return not_supported
class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
......@@ -705,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator serving `device`.

    Raises:
        ValueError: If `device` is neither `Device.GPU` nor `Device.CPU`.
    """
    # Pick the allocator lazily so only the requested one is touched.
    if device == Device.GPU:
        selected_allocator = self.gpu_allocator
    elif device == Device.CPU:
        selected_allocator = self.cpu_allocator
    else:
        raise ValueError(f"Invalid device: {device}")
    return selected_allocator.get_prefix_cache_hit_rate()
......@@ -441,6 +441,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Forward the hit-rate query for `device` to `self.block_allocator`."""
    allocator = self.block_allocator
    return allocator.get_prefix_cache_hit_rate(device)
def _can_swap(self,
seq_group: SequenceGroup,
device: Device,
......
......@@ -2,6 +2,7 @@ from typing import List, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
......@@ -81,3 +82,6 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return -1 for every device; `device` is not consulted."""
    not_supported = -1
    return not_supported
......@@ -85,19 +85,21 @@ class LRUEvictor(Evictor):
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
evicted_block, evicted_block_id = None, None
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block is None:
evicted_block, evicted_block_id = block, _id
continue
if evicted_block.last_accessed < block.last_accessed:
break
if (evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block, evicted_block_id = block, _id
assert evicted_block is not None
assert evicted_block_id is not None
self.free_table.pop(evicted_block_id)
return evicted_block_id, evicted_block.content_hash
......@@ -110,7 +112,6 @@ class LRUEvictor(Evictor):
def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
self.free_table.move_to_end(block_id)
def remove(self, block_id: int):
if block_id not in self.free_table:
......
......@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class AllocStatus(enum.Enum):
......@@ -116,3 +117,8 @@ class BlockSpaceManager(ABC):
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    """Abstract hook taking a sequence group; subclasses must override it.

    NOTE(review): the intended effect is only suggested by the method name;
    confirm against the concrete block managers.
    """
    ...
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Report the prefix cache hit rate for the given device.

    A value of -1 means not supported or disabled.
    """
    ...
......@@ -14,7 +14,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import PyObjectCache
from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__)
......@@ -447,6 +447,9 @@ class Scheduler:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Forward the hit-rate query for `device` to `self.block_manager`."""
    manager = self.block_manager
    return manager.get_prefix_cache_hit_rate(device)
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
......
......@@ -47,7 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
from vllm.utils import Counter, Device
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -1390,6 +1390,13 @@ class LLMEngine:
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Prefix Cache Hit Rate. Note that we always use
# the cache hit rate of the first virtual engine.
cpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.CPU)
gpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.GPU)
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
......@@ -1498,6 +1505,9 @@ class LLMEngine:
# KV Cache Usage in %
gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys,
# Prefix Cache Hit Rate
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
......
......@@ -71,6 +71,17 @@ class Metrics:
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames,
multiprocess_mode="sum")
# Prefix caching block hit rate
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:cpu_prefix_cache_hit_rate",
documentation="CPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:gpu_prefix_cache_hit_rate",
documentation="GPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
# Iteration stats
self.counter_num_preemption = self._counter_cls(
......@@ -351,7 +362,13 @@ class LoggingStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
logger.info(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
......@@ -423,6 +440,10 @@ class PrometheusStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate)
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
......
......@@ -32,6 +32,9 @@ class Stats:
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate: float
gpu_prefix_cache_hit_rate: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment