Unverified commit ad430a67, authored by Cyrus Leung, committed by GitHub
Browse files

[Metrics] Log multi-modal cache stats and fix reset (#26285)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6f0f570c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.v1.metrics.reader import Counter, Metric
from ..openai.test_vision import TEST_IMAGE_ASSETS
def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]:
return [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_url},
},
],
}
]
def _get_counter_value(metrics: list[Metric], name: str):
metric = next(m for m in metrics if m.name == name)
assert isinstance(metric, Counter)
return metric.value
def _get_mm_cache_stats(metrics: list[Metric]):
    """Return ``(queries, hits)`` read from the multi-modal cache counters."""
    return tuple(
        _get_counter_value(metrics, counter_name)
        for counter_name in ("vllm:mm_cache_queries", "vllm:mm_cache_hits")
    )
@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True)
@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"])
def test_mm_cache_stats(
    num_gpus_available,
    image_urls,
    mm_processor_cache_type,
):
    """Check the mm cache query/hit counters across chats and a cache reset."""
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        mm_processor_cache_type=mm_processor_cache_type,
        disable_log_stats=False,
        limit_mm_per_prompt={"image": 2},
    )

    def chat_then_stats(url_index: int):
        # Send one single-image chat, then read back (queries, hits).
        llm.chat(_make_messages(image_urls[url_index]))
        return _get_mm_cache_stats(llm.get_metrics())

    # The first request for each image is a miss; repeating image 0 is a hit.
    for url_index, expected in ((0, (1, 0)), (1, (2, 0)), (0, (3, 1))):
        assert chat_then_stats(url_index) == expected

    # NOTE: This only resets hit rate stats in CachingMetrics
    # The raw queries and hits counts remain unaffected
    llm.reset_mm_cache()

    # With the cache emptied, both images are queried again without a new hit.
    for url_index, expected in ((0, (4, 1)), (1, (5, 1))):
        assert chat_then_stats(url_index) == expected
......@@ -18,10 +18,18 @@ from vllm import version
from ...utils import RemoteOpenAIServer
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODELS = {
"text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
}
PREV_MINOR_VERSION = version._prev_minor_version()
@pytest.fixture(scope="module", params=list(MODELS.keys()))
def model_key(request):
    """Yield the parameter this fixture was parametrized with (a ``MODELS`` key)."""
    selected_key = request.param
    yield selected_key
@pytest.fixture(scope="module")
def default_server_args():
return [
......@@ -45,11 +53,12 @@ def default_server_args():
f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
],
)
def server(default_server_args, request):
def server(model_key, default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
model_name = MODELS[model_key]
with RemoteOpenAIServer(model_name, default_server_args) as remote_server:
yield remote_server
......@@ -60,64 +69,70 @@ async def client(server):
_PROMPT = "Hello my name is Robert and I love magic"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
_NUM_REQUESTS = 10
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
_NUM_GENERATION_TOKENS_PER_REQUEST = 10
# {metric_family: [(suffix, expected_value)]}
EXPECTED_VALUES = {
    "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:time_per_output_token_seconds": [
        ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))
    ],
    "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_prompt_tokens": [
        ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
        ("_count", _NUM_REQUESTS),
    ],
    "vllm:request_generation_tokens": [
        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
        ("_count", _NUM_REQUESTS),
    ],
    "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
    "vllm:request_params_max_tokens": [
        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
        ("_count", _NUM_REQUESTS),
    ],
    "vllm:iteration_tokens_total": [
        (
            "_sum",
            _NUM_REQUESTS
            * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST),
        ),
        ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
    ],
    "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
    # Generated tokens are counted per *generation* token; using the prompt
    # length here only matched while both happened to be the same size.
    "vllm:generation_tokens": [
        ("_total", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)
    ],
    "vllm:request_success": [("_total", _NUM_REQUESTS)],
}
_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
num_prompt_tokens = len(prompt_ids)
# {metric_family: [(suffix, expected_value)]}
return {
"vllm:time_to_first_token_seconds": [("_count", num_requests)],
"vllm:time_per_output_token_seconds": [
("_count", num_requests * (max_tokens - 1))
],
"vllm:e2e_request_latency_seconds": [("_count", num_requests)],
"vllm:request_queue_time_seconds": [("_count", num_requests)],
"vllm:request_inference_time_seconds": [("_count", num_requests)],
"vllm:request_prefill_time_seconds": [("_count", num_requests)],
"vllm:request_decode_time_seconds": [("_count", num_requests)],
"vllm:request_prompt_tokens": [
("_sum", num_requests * num_prompt_tokens),
("_count", num_requests),
],
"vllm:request_generation_tokens": [
("_sum", num_requests * max_tokens),
("_count", num_requests),
],
"vllm:request_params_n": [("_count", num_requests)],
"vllm:request_params_max_tokens": [
("_sum", num_requests * max_tokens),
("_count", num_requests),
],
"vllm:iteration_tokens_total": [
(
"_sum",
num_requests * (num_prompt_tokens + max_tokens),
),
("_count", num_requests * max_tokens),
],
"vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)],
"vllm:generation_tokens": [("_total", num_requests * max_tokens)],
"vllm:request_success": [("_total", num_requests)],
}
@pytest.mark.asyncio
async def test_metrics_counts(
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
):
for _ in range(_NUM_REQUESTS):
if model_key == "multimodal":
pytest.skip("Unnecessary test")
model_name = MODELS[model_key]
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tokenizer.encode(_PROMPT)
num_requests = 10
max_tokens = 10
for _ in range(num_requests):
# sending a request triggers the metrics to be logged.
await client.completions.create(
model=MODEL_NAME,
prompt=_TOKENIZED_PROMPT,
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST,
model=model_name,
prompt=prompt_ids,
max_tokens=max_tokens,
)
response = requests.get(server.url_for("metrics"))
......@@ -125,8 +140,9 @@ async def test_metrics_counts(
assert response.status_code == HTTPStatus.OK
# Loop over all expected metric_families
for metric_family, suffix_values_list in EXPECTED_VALUES.items():
if (metric_family not in EXPECTED_METRICS_V1) or (
expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens)
for metric_family, suffix_values_list in expected_values.items():
if metric_family not in EXPECTED_METRICS_V1 or (
not server.show_hidden_metrics
and metric_family in HIDDEN_DEPRECATED_METRICS
):
......@@ -217,6 +233,11 @@ EXPECTED_METRICS_V1 = [
"vllm:request_decode_time_seconds_count",
]
EXPECTED_METRICS_MM = [
"vllm:mm_cache_queries",
"vllm:mm_cache_hits",
]
HIDDEN_DEPRECATED_METRICS: list[str] = [
"vllm:gpu_cache_usage_perc",
"vllm:gpu_prefix_cache_queries",
......@@ -231,19 +252,43 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
async def test_metrics_exist(
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
):
model_name = MODELS[model_key]
# sending a request triggers the metrics to be logged.
await client.completions.create(
model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0,
)
if model_key == "text":
await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0,
)
else:
await client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": _IMAGE_URL}},
{"type": "text", "text": "What's in this image?"},
],
}
],
max_tokens=5,
temperature=0.0,
)
response = requests.get(server.url_for("metrics"))
assert response.status_code == HTTPStatus.OK
for metric in EXPECTED_METRICS_V1:
expected_metrics = EXPECTED_METRICS_V1
if model_key == "multimodal":
# NOTE: Don't use in-place assignment
expected_metrics = expected_metrics + EXPECTED_METRICS_MM
for metric in expected_metrics:
if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics:
continue
assert metric in response.text
......@@ -253,9 +298,14 @@ async def test_metrics_exist(
async def test_abort_metrics_reset(
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
):
model_name = MODELS[model_key]
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tokenizer.encode(_PROMPT)
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
server
server,
)
# Expect no running requests or kvcache usage
......@@ -268,8 +318,8 @@ async def test_abort_metrics_reset(
for _ in range(3):
task = asyncio.create_task(
client.completions.create(
model=MODEL_NAME,
prompt=_TOKENIZED_PROMPT,
model=model_name,
prompt=prompt_ids,
max_tokens=100, # Long generation to give time to abort
temperature=0.0,
)
......@@ -281,7 +331,7 @@ async def test_abort_metrics_reset(
# Check that we have running requests
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
server
server,
)
# Expect running requests and kvcache usage
......
......@@ -20,7 +20,6 @@ from vllm.v1.core.kv_cache_utils import (
BlockHash,
FreeKVCacheBlockQueue,
KVCacheBlock,
PrefixCachingMetrics,
estimate_max_model_len,
generate_block_hash_extra_keys,
generate_scheduler_kv_cache_config,
......@@ -42,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request
pytestmark = pytest.mark.cpu_test
......@@ -536,7 +535,7 @@ def test_metrics():
"""
Test the prefix caching metrics.
"""
metrics = PrefixCachingMetrics(max_recent_requests=5)
metrics = CachingMetrics(max_recent_requests=5)
assert metrics.hit_rate == 0.0
metrics.observe(_stats(1, 20, 9))
......@@ -568,7 +567,7 @@ def test_metrics_empty_stats():
"""
Test the prefix caching metrics with empty stats.
"""
metrics = PrefixCachingMetrics(max_recent_requests=5)
metrics = CachingMetrics(max_recent_requests=5)
metrics.observe(_stats(0, 0, 0))
metrics.observe(_stats(1, 20, 9))
metrics.observe(_stats(0, 0, 0))
......
......@@ -17,7 +17,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats
DP_SIZE = int(os.getenv("DP_SIZE", 2))
......@@ -93,6 +93,7 @@ async def test_load(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
):
if iteration_stats:
......
......@@ -354,6 +354,10 @@ class LLM:
else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def reset_mm_cache(self) -> None:
    """Clear the processor's multi-modal cache, then reset the engine's."""
    self.processor.clear_mm_cache()
    engine = self.llm_engine
    engine.reset_mm_cache()
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
self.default_sampling_params = self.model_config.get_diff_sampling_param()
......
......@@ -274,6 +274,10 @@ class OpenAIServing:
self.model_config = self.models.model_config
self.max_model_len = self.model_config.max_model_len
async def reset_mm_cache(self) -> None:
    """Clear the processor's multi-modal cache, then await the engine client's reset."""
    self.processor.clear_mm_cache()
    client = self.engine_client
    await client.reset_mm_cache()
async def beam_search(
self,
prompt: PromptType,
......
......@@ -169,6 +169,10 @@ class ExecutorBase(ABC):
assert s == sets[0], "All workers should have the same LORAs."
return sets[0]
def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache in each worker.

    Dispatches the ``reset_mm_cache`` method through ``collective_rpc``.
    """
    method_name = "reset_mm_cache"
    self.collective_rpc(method_name)
def start_profile(self) -> None:
self.collective_rpc("start_profile")
......
......@@ -12,11 +12,8 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
......@@ -30,16 +27,13 @@ class UniProcExecutor(ExecutorBase):
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
distributed_init_method, rank, local_rank = self._distributed_args()
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock()
is_driver_worker=True,
shared_worker_lock=Lock(),
)
self.async_output_thread: Optional[ThreadPoolExecutor] = None
......@@ -74,8 +68,6 @@ class UniProcExecutor(ExecutorBase):
) -> list[Any]:
if kwargs is None:
kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
get_and_update_mm_cache(self.mm_receiver_cache, args)
if not non_block:
return [run_method(self.driver_worker, method, args, kwargs)]
......
......@@ -19,6 +19,7 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import (
DecoderOnlyInputs,
......@@ -56,6 +57,8 @@ class InputPreprocessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError(
......@@ -664,14 +667,13 @@ class InputPreprocessor:
return self._build_decoder_only_llm_inputs(prompt_comps)
def preprocess(
def _preprocess(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
......@@ -694,6 +696,40 @@ class InputPreprocessor:
mm_uuids=mm_uuids,
)
def clear_cache(self) -> None:
def preprocess(
    self,
    prompt: PromptType,
    tokenization_kwargs: Optional[dict[str, Any]] = None,
    *,
    mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
    """Preprocess the input prompt and record multi-modal cache usage.

    Args:
        prompt: The prompt to preprocess.
        tokenization_kwargs: Passed through to ``_preprocess``.
        mm_uuids: Passed through to ``_preprocess``.

    Returns:
        Whatever ``_preprocess`` returns for the same arguments.
    """
    res = self._preprocess(
        prompt,
        tokenization_kwargs,
        mm_uuids=mm_uuids,
    )

    cache = self.mm_processor_cache
    stats = self.mm_cache_stats
    # Compare against None explicitly (as ``clear_mm_cache`` does) so the
    # bookkeeping never depends on the cache object's truthiness.
    if cache is not None and stats is not None:
        # Only the change since the previous call is added to the totals.
        delta = cache.make_stats(delta=True)
        stats.requests += 1
        stats.queries += delta.total
        stats.hits += delta.hits

    return res
def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
    """Hand over the accumulated multi-modal cache stats and start afresh.

    Returns:
        The stats object accumulated so far (a new empty one takes its
        place), or ``None`` when ``mm_cache_stats`` is ``None``.
    """
    collected = self.mm_cache_stats
    if collected is not None:
        self.mm_cache_stats = MultiModalCacheStats()
    return collected
def clear_mm_cache(self) -> None:
    """Empty the processor cache and flag the pending stats as reset.

    Each step is skipped when the corresponding attribute is ``None``.
    """
    cache = self.mm_processor_cache
    if cache is not None:
        cache.clear_cache()

    stats = self.mm_cache_stats
    if stats is not None:
        stats.reset = True
......@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.shm_object_storage import (
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, MiB_bytes
from vllm.utils.cache import LRUCache
from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
from .inputs import (
......@@ -302,6 +302,16 @@ class BaseMultiModalProcessorCache(
"""
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
@abstractmethod
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    """
    Get (and reset) the multi-modal cache stats.

    Args:
        delta: Keyword-only flag that implementations pass on to their
            stat collection. NOTE(review): in the shared-memory sender
            cache it selects "change since the previous delta call"
            instead of running totals; confirm the same holds for the
            LRU-backed caches.

    Returns:
        The current multi-modal caching stats.
    """
    raise NotImplementedError
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
"""
......@@ -347,6 +357,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
def clear_cache(self) -> None:
self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    """Return the stats reported by the underlying ``_cache.stat``."""
    cache_info = self._cache.stat(delta=delta)
    return cache_info
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
"""
......@@ -397,6 +411,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
def clear_cache(self) -> None:
self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    """Return the stats reported by the underlying ``_cache.stat``."""
    cache_info = self._cache.stat(delta=delta)
    return cache_info
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
"""
......@@ -430,6 +448,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
# cache (prompt_updates, modality) for P0 only
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
def _stat(self, *, delta: bool = False) -> CacheInfo:
    """Snapshot the hit/total counters, optionally as a change since last time.

    Args:
        delta: When true, return the difference from the snapshot taken by
            the previous ``delta=True`` call and remember the new snapshot.

    Returns:
        The running totals, or their change when ``delta`` is set.
    """
    current = CacheInfo(hits=self._hits, total=self._total)
    if not delta:
        return current

    change = current - self._last_info
    self._last_info = current
    return change
@override
def is_cached_item(self, mm_hash: str) -> bool:
return self._shm_cache.is_cached(mm_hash)
......@@ -441,12 +473,17 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if self._shm_cache.is_cached(mm_hash):
self._hits += 1
self._total += 1
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates, modality = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id, modality), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._total += 1
try:
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
# Try to remove dangling items if p0 cache is too large.
......@@ -469,6 +506,14 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self._shm_cache.clear()
self._p0_cache.clear()
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    """Return the stats computed by this cache's own ``_stat`` helper."""
    cache_info = self._stat(delta=delta)
    return cache_info
def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys()
......
......@@ -4,7 +4,7 @@
import copy
import os
from collections import defaultdict, deque
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NewType, Optional, Union
......@@ -23,7 +23,6 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
# BlockHash represents the hash of a single KV-cache block used for
......@@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
NONE_HASH = BlockHash(hash_fn(hash_seed))
class PrefixCachingMetrics:
    """Hit-rate tracker for prefix caching over the most recent N requests.

    Args:
        max_recent_requests: Upper bound on the number of recent requests
            that are aggregated. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000):
        self.max_recent_requests = max_recent_requests
        # Running totals over everything currently held in ``query_queue``.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # One (requests, queries, hits) entry per observed batch of stats.
        self.query_queue: deque[tuple[int, int, int]] = deque()

    def observe(self, stats: PrefixCacheStats):
        """Fold one batch of prefix cache stats into the sliding window.

        Called with the information gathered while new requests are being
        scheduled and are looking for computed blocks. Once the window holds
        more than ``max_recent_requests`` requests, the oldest batches are
        dropped from the aggregate.

        Args:
            stats: The prefix cache stats.
        """
        # A reset_prefix_cache call happened before this update, so start
        # from a clean slate before aggregating the new numbers.
        if stats.reset:
            self.reset()

        # Empty batches are not recorded, so that they cannot push useful
        # entries out of the sliding window.
        if stats.requests == 0:
            return

        entry = (stats.requests, stats.queries, stats.hits)
        self.query_queue.append(entry)
        self.aggregated_requests += entry[0]
        self.aggregated_query_total += entry[1]
        self.aggregated_query_hit += entry[2]

        # Evict from the oldest end while over the limit.
        # NOTE: The most recently added batch is always kept.
        while self.aggregated_requests > self.max_recent_requests:
            if len(self.query_queue) <= 1:
                break
            dropped_requests, dropped_queries, dropped_hits = (
                self.query_queue.popleft()
            )
            self.aggregated_requests -= dropped_requests
            self.aggregated_query_total -= dropped_queries
            self.aggregated_query_hit -= dropped_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Hit rate over the requests currently in the window (0.0 if none)."""
        total = self.aggregated_query_total
        return self.aggregated_query_hit / total if total != 0 else 0.0
@dataclass
class KVCacheBlock:
"""KV-cache block metadata."""
......
......@@ -463,6 +463,7 @@ class AsyncLLM(EngineClient):
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
processor = self.processor
async def output_handler():
try:
......@@ -511,6 +512,7 @@ class AsyncLLM(EngineClient):
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=processor.stat_mm_cache(),
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
......@@ -660,7 +662,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None:
    """Clear the processor's multi-modal cache, then reset the engine core's.

    The stale ``self.processor.clear_cache()`` call is dropped: only
    ``clear_mm_cache`` is invoked, so the processor cache is cleared exactly
    once and the legacy method name is no longer required.
    """
    self.processor.clear_mm_cache()
    await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:
......
......@@ -319,7 +319,7 @@ class EngineCore:
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
) # type: ignore
)
return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)
......@@ -400,16 +400,19 @@ class EngineCore:
def reset_mm_cache(self):
    """Clear the receiver-side multi-modal cache and reset the executor's.

    Logs a warning when the scheduler still has unfinished requests.
    """
    # NOTE: Since this is mainly for debugging, we don't attempt to
    # re-sync the internal caches (P0 sender, P1 receiver)
    requests_in_flight = self.scheduler.has_unfinished_requests()
    if requests_in_flight:
        logger.warning(
            "Resetting the multi-modal cache when requests are "
            "in progress may lead to desynced internal caches."
        )

    # The cache either exists in EngineCore or WorkerWrapperBase
    receiver_cache = self.mm_receiver_cache
    if receiver_cache is not None:
        receiver_cache.clear_cache()

    self.model_executor.reset_mm_cache()
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
......
......@@ -306,9 +306,11 @@ class LLMEngine:
# 4) Record stats
if self.logger_manager is not None:
assert outputs.scheduler_stats is not None
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
......@@ -321,7 +323,7 @@ class LLMEngine:
self.engine_core.profile(False)
def reset_mm_cache(self):
    """Clear the processor's multi-modal cache, then reset the engine core's.

    The stale ``self.processor.clear_cache()`` call is dropped: only
    ``clear_mm_cache`` is invoked, so the processor cache is cleared exactly
    once and the legacy method name is no longer required.
    """
    self.processor.clear_mm_cache()
    self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None):
......
......@@ -21,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer,
......@@ -573,5 +574,8 @@ class Processor:
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def clear_cache(self) -> None:
self.input_preprocessor.clear_cache()
def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
    """Return whatever the input preprocessor's ``stat_mm_cache`` reports."""
    mm_cache_stats = self.input_preprocessor.stat_mm_cache()
    return mm_cache_stats
def clear_mm_cache(self) -> None:
    """Delegate to the input preprocessor's ``clear_mm_cache``."""
    preprocessor = self.input_preprocessor
    preprocessor.clear_mm_cache()
......@@ -33,8 +33,6 @@ from vllm.distributed.parallel_state import (
get_tp_group,
)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
_maybe_force_spawn,
decorate_logs,
......@@ -46,7 +44,6 @@ from vllm.utils import (
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
......@@ -422,6 +419,7 @@ class WorkerProc:
"rank": rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": is_driver_worker,
"shared_worker_lock": shared_worker_lock,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper
......@@ -445,11 +443,6 @@ class WorkerProc:
)
self.async_output_copy_thread.start()
# Initialize multimodal receiver cache if needed
self.mm_receiver_cache = worker_receiver_cache_from_config(
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock
)
# Initialize device
self.worker.init_device()
......@@ -692,12 +685,7 @@ class WorkerProc:
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
# retrieve from shm cache if available
if (
self.mm_receiver_cache is not None
and func.__name__ == "execute_model"
):
get_and_update_mm_cache(self.mm_receiver_cache, args)
output = func(*args, **kwargs)
except Exception as e:
# Notes have been introduced in python 3.11
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.cache import ShmObjectStoreReceiverCache
from vllm.v1.core.sched.output import SchedulerOutput
def get_and_update_mm_cache(
    receiver_cache: ShmObjectStoreReceiverCache,
    args: tuple[SchedulerOutput],
) -> None:
    """
    Refresh the multi-modal features of every newly scheduled request by
    passing them through the receiver cache (fetching from the shared memory
    cache as needed).

    Args:
        receiver_cache: The receiver cache to update.
        args: According to the collective_rpc call of execute_model method in
            executor, args is a tuple of only one SchedulerOutput element.
    """
    new_requests = args[0].scheduled_new_reqs
    for new_req in new_requests:
        refreshed = receiver_cache.get_and_update_features(new_req.mm_features)
        new_req.mm_features = refreshed
......@@ -11,10 +11,14 @@ import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import (
CachingMetrics,
IterationStats,
MultiModalCacheStats,
SchedulerStats,
)
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
......@@ -38,6 +42,7 @@ class StatLoggerBase(ABC):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
): ...
......@@ -53,10 +58,15 @@ class LoggingStatLogger(StatLoggerBase):
self.engine_index = engine_index
self.vllm_config = vllm_config
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset.
self.last_mm_cache_stats: Optional[MultiModalCacheStats] = None
# Caching metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.prefix_caching_metrics = CachingMetrics()
self.mm_caching_metrics = CachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging()
kv_tranfer_config = self.vllm_config.kv_transfer_config
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
......@@ -86,6 +96,7 @@ class LoggingStatLogger(StatLoggerBase):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
):
"""Log Stats to standard output."""
......@@ -101,6 +112,11 @@ class LoggingStatLogger(StatLoggerBase):
self.kv_connector_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats
if mm_cache_stats:
self.mm_caching_metrics.observe(mm_cache_stats)
self.last_mm_cache_stats = mm_cache_stats
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
......@@ -125,21 +141,32 @@ class LoggingStatLogger(StatLoggerBase):
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
log_parts = [
"Avg prompt throughput: %.1f tokens/s",
"Avg generation throughput: %.1f tokens/s",
"Running: %d reqs",
"Waiting: %d reqs",
"GPU KV cache usage: %.1f%%",
"Prefix cache hit rate: %.1f%%",
self.engine_index,
]
log_args = [
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
]
if self.last_mm_cache_stats:
log_parts.append("MM cache hit rate: %.1f%%")
log_args.append(self.mm_caching_metrics.hit_rate * 100)
log_fn(
"Engine %03d: " + ", ".join(log_parts),
self.engine_index,
*log_args,
)
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_connector_logging.log(log_fn=log_fn)
......@@ -288,6 +315,32 @@ class PrometheusStatLogger(StatLoggerBase):
counter_prefix_cache_hits, engine_indexes, model_name
)
#
# Multi-modal cache
#
counter_mm_cache_queries = self._counter_cls(
name="vllm:mm_cache_queries",
documentation=(
"Multi-modal cache queries, in terms of number of queried items."
),
labelnames=labelnames,
)
self.counter_mm_cache_queries = make_per_engine(
counter_mm_cache_queries, engine_indexes, model_name
)
counter_mm_cache_hits = self._counter_cls(
name="vllm:mm_cache_hits",
documentation=(
"Multi-modal cache hits, in terms of number of cached items."
),
labelnames=labelnames,
)
self.counter_mm_cache_hits = make_per_engine(
counter_mm_cache_hits, engine_indexes, model_name
)
#
# Counters
#
......@@ -657,6 +710,7 @@ class PrometheusStatLogger(StatLoggerBase):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
):
"""Log to prometheus."""
......@@ -694,6 +748,10 @@ class PrometheusStatLogger(StatLoggerBase):
scheduler_stats.spec_decoding_stats, engine_idx
)
if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
if iteration_stats is None:
return
......@@ -871,6 +929,7 @@ class StatLoggerManager:
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: Optional[int] = None,
):
if engine_idx is None:
......@@ -878,9 +937,19 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx)
self.prometheus_logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
......@@ -13,24 +14,122 @@ if TYPE_CHECKING:
@dataclass
class BaseCacheStats:
    """Stores cache hit statistics.

    Base record shared by the per-cache stats classes; every field
    defaults to an "empty update" (no reset, zero counts).
    """

    reset: bool = False
    """Whether the cache was reset."""

    requests: int = 0
    """The number of requests in this update."""

    queries: int = 0
    """The number of queries in these requests."""

    hits: int = 0
    """The number of hits in these requests."""
class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: BaseCacheStats):
        """Fold the cache stats of one update into the sliding window.

        If `stats.reset` is set, the window is cleared first. Updates with
        zero requests are then ignored. Otherwise the update is appended
        and, while the window holds more than `max_recent_requests`
        requests, the oldest updates are dropped (the newest update is
        always kept).

        Args:
            stats: The cache stats of the current update.
        """
        # The cache was reset before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Do NOT append empty stats, so that useful entries are not pushed
        # out of the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until the number of requests does not
        # exceed the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics: zero the aggregates and empty the window."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests.

        Returns 0.0 when no queries have been aggregated.
        """
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Stores prefix cache hit statistics.

    - `reset`: Whether `reset_prefix_cache` was invoked.
    - `queries`: Refers to the number of tokens that were queried.
    """

    preempted_requests: int = 0
    """The number of previously preempted requests in this update."""

    preempted_queries: int = 0
    """The `queries` number for preempted requests."""

    preempted_hits: int = 0
    """The `hits` number for preempted requests."""
@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Cache hit statistics for the multi-modal cache.

    Adds no fields of its own; the inherited ones are read as follows:

    - `reset`: Whether `reset_mm_cache` was invoked.
    - `queries`: The number of multi-modal data items that were queried.
    """
@dataclass
......
......@@ -508,6 +508,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
)
def reset_mm_cache(self) -> None:
    """Reset the cache held by `self.mm_budget`, if one is set."""
    budget = self.mm_budget
    # Nothing to reset when no (truthy) budget object is present.
    if not budget:
        return
    budget.reset_cache()
def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int):
if self.uses_mrope:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment