Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad430a67
Unverified
Commit
ad430a67
authored
Oct 10, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 10, 2025
Browse files
[Metrics] Log multi-modal cache stats and fix reset (#26285)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
6f0f570c
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
508 additions
and
225 deletions
+508
-225
tests/entrypoints/llm/test_mm_cache_stats.py
tests/entrypoints/llm/test_mm_cache_stats.py
+74
-0
tests/entrypoints/openai/test_metrics.py
tests/entrypoints/openai/test_metrics.py
+115
-65
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+3
-4
tests/v1/distributed/test_async_llm_dp.py
tests/v1/distributed/test_async_llm_dp.py
+2
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+4
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+4
-0
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+2
-10
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+39
-3
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+46
-1
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+1
-74
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+3
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+5
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+3
-1
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+6
-2
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+2
-14
vllm/v1/executor/utils.py
vllm/v1/executor/utils.py
+0
-24
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+82
-13
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+109
-10
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-0
No files found.
tests/entrypoints/llm/test_mm_cache_stats.py
0 → 100644
View file @
ad430a67
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
vllm
import
LLM
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.v1.metrics.reader
import
Counter
,
Metric
from
..openai.test_vision
import
TEST_IMAGE_ASSETS
def
_make_messages
(
image_url
:
str
)
->
list
[
ChatCompletionMessageParam
]:
return
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
},
},
],
}
]
def
_get_counter_value
(
metrics
:
list
[
Metric
],
name
:
str
):
metric
=
next
(
m
for
m
in
metrics
if
m
.
name
==
name
)
assert
isinstance
(
metric
,
Counter
)
return
metric
.
value
def _get_mm_cache_stats(metrics: list[Metric]):
    """Return ``(queries, hits)`` read from the multi-modal cache counters."""
    counter_names = ("vllm:mm_cache_queries", "vllm:mm_cache_hits")
    queries, hits = (_get_counter_value(metrics, n) for n in counter_names)
    return queries, hits
@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True)
@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"])
def test_mm_cache_stats(
    num_gpus_available,
    image_urls,
    mm_processor_cache_type,
):
    """Check the multi-modal cache query/hit counters across chats and a reset.

    Each chat sends one image. After every chat the cumulative
    ``(queries, hits)`` pair reported by ``llm.get_metrics()`` is compared
    against the expected value.
    """
    engine_kwargs = dict(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        mm_processor_cache_type=mm_processor_cache_type,
        disable_log_stats=False,
        limit_mm_per_prompt={"image": 2},
    )
    llm = LLM(**engine_kwargs)

    def chat_and_read_stats(url_index):
        # Send one image, then read back the cumulative (queries, hits).
        llm.chat(_make_messages(image_urls[url_index]))
        return _get_mm_cache_stats(llm.get_metrics())

    # Two distinct images: two queries, no hits yet.
    assert chat_and_read_stats(0) == (1, 0)
    assert chat_and_read_stats(1) == (2, 0)

    # Repeating the first image is expected to register a hit.
    assert chat_and_read_stats(0) == (3, 1)

    # NOTE: This only resets hit rate stats in CachingMetrics
    # The raw queries and hits counts remain unaffected
    llm.reset_mm_cache()

    # After the reset both images are queried again without adding hits.
    assert chat_and_read_stats(0) == (4, 1)
    assert chat_and_read_stats(1) == (5, 1)
tests/entrypoints/openai/test_metrics.py
View file @
ad430a67
...
...
@@ -18,10 +18,18 @@ from vllm import version
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODELS
=
{
"text"
:
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
,
"multimodal"
:
"HuggingFaceTB/SmolVLM-256M-Instruct"
,
}
PREV_MINOR_VERSION
=
version
.
_prev_minor_version
()
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
list
(
MODELS
.
keys
()))
def
model_key
(
request
):
yield
request
.
param
@
pytest
.
fixture
(
scope
=
"module"
)
def
default_server_args
():
return
[
...
...
@@ -45,11 +53,12 @@ def default_server_args():
f
"--show-hidden-metrics-for-version=
{
PREV_MINOR_VERSION
}
"
,
],
)
def
server
(
default_server_args
,
request
):
def
server
(
model_key
,
default_server_args
,
request
):
if
request
.
param
:
default_server_args
.
append
(
request
.
param
)
with
RemoteOpenAIServer
(
MODEL_NAME
,
default_server_args
)
as
remote_server
:
model_name
=
MODELS
[
model_key
]
with
RemoteOpenAIServer
(
model_name
,
default_server_args
)
as
remote_server
:
yield
remote_server
...
...
@@ -60,64 +69,70 @@ async def client(server):
_PROMPT
=
"Hello my name is Robert and I love magic"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
_TOKENIZED_PROMPT
=
tokenizer
(
_PROMPT
)[
"input_ids"
]
_IMAGE_URL
=
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
_NUM_REQUESTS
=
10
_NUM_PROMPT_TOKENS_PER_REQUEST
=
len
(
_TOKENIZED_PROMPT
)
_NUM_GENERATION_TOKENS_PER_REQUEST
=
10
# {metric_family: [(suffix, expected_value)]}
EXPECTED_VALUES
=
{
"vllm:time_to_first_token_seconds"
:
[(
"_count"
,
_NUM_REQUESTS
)],
def
_get_expected_values
(
num_requests
:
int
,
prompt_ids
:
list
[
int
],
max_tokens
:
int
):
num_prompt_tokens
=
len
(
prompt_ids
)
# {metric_family: [(suffix, expected_value)]}
return
{
"vllm:time_to_first_token_seconds"
:
[(
"_count"
,
num_requests
)],
"vllm:time_per_output_token_seconds"
:
[
(
"_count"
,
_NUM_REQUESTS
*
(
_NUM_GENERATION_TOKENS_PER_REQUEST
-
1
))
(
"_count"
,
num_requests
*
(
max_tokens
-
1
))
],
"vllm:e2e_request_latency_seconds"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_queue_time_seconds"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_inference_time_seconds"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_prefill_time_seconds"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_decode_time_seconds"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:e2e_request_latency_seconds"
:
[(
"_count"
,
num_requests
)],
"vllm:request_queue_time_seconds"
:
[(
"_count"
,
num_requests
)],
"vllm:request_inference_time_seconds"
:
[(
"_count"
,
num_requests
)],
"vllm:request_prefill_time_seconds"
:
[(
"_count"
,
num_requests
)],
"vllm:request_decode_time_seconds"
:
[(
"_count"
,
num_requests
)],
"vllm:request_prompt_tokens"
:
[
(
"_sum"
,
_NUM_REQUESTS
*
_NUM_PROMPT_TOKENS_PER_REQUEST
),
(
"_count"
,
_NUM_REQUESTS
),
(
"_sum"
,
num_requests
*
num_prompt_tokens
),
(
"_count"
,
num_requests
),
],
"vllm:request_generation_tokens"
:
[
(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
(
"_count"
,
_NUM_REQUESTS
),
(
"_sum"
,
num_requests
*
max_tokens
),
(
"_count"
,
num_requests
),
],
"vllm:request_params_n"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_params_n"
:
[(
"_count"
,
num_requests
)],
"vllm:request_params_max_tokens"
:
[
(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
(
"_count"
,
_NUM_REQUESTS
),
(
"_sum"
,
num_requests
*
max_tokens
),
(
"_count"
,
num_requests
),
],
"vllm:iteration_tokens_total"
:
[
(
"_sum"
,
_NUM_REQUESTS
*
(
_NUM_PROMPT_TOKENS_PER_REQUEST
+
_NUM_GENERATION_TOKENS_PER_REQUEST
),
num_requests
*
(
num_prompt_tokens
+
max_tokens
),
),
(
"_count"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
],
"vllm:prompt_tokens"
:
[(
"_total"
,
_NUM_REQUESTS
*
_NUM_PROMPT_TOKENS_PER_REQUEST
)],
"vllm:generation_tokens"
:
[
(
"_total"
,
_NUM_REQUESTS
*
_NUM_PROMPT_TOKENS_PER_REQUEST
)
(
"_count"
,
num_requests
*
max_tokens
),
],
"vllm:request_success"
:
[(
"_total"
,
_NUM_REQUESTS
)],
}
"vllm:prompt_tokens"
:
[(
"_total"
,
num_requests
*
num_prompt_tokens
)],
"vllm:generation_tokens"
:
[(
"_total"
,
num_requests
*
max_tokens
)],
"vllm:request_success"
:
[(
"_total"
,
num_requests
)],
}
@
pytest
.
mark
.
asyncio
async
def
test_metrics_counts
(
server
:
RemoteOpenAIServer
,
client
:
openai
.
AsyncClient
,
model_key
:
str
,
):
for
_
in
range
(
_NUM_REQUESTS
):
if
model_key
==
"multimodal"
:
pytest
.
skip
(
"Unnecessary test"
)
model_name
=
MODELS
[
model_key
]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
prompt_ids
=
tokenizer
.
encode
(
_PROMPT
)
num_requests
=
10
max_tokens
=
10
for
_
in
range
(
num_requests
):
# sending a request triggers the metrics to be logged.
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
_TOKENIZED_PROMPT
,
max_tokens
=
_NUM_GENERATION_TOKENS_PER_REQUEST
,
model
=
model_name
,
prompt
=
prompt_ids
,
max_tokens
=
max_tokens
,
)
response
=
requests
.
get
(
server
.
url_for
(
"metrics"
))
...
...
@@ -125,8 +140,9 @@ async def test_metrics_counts(
assert
response
.
status_code
==
HTTPStatus
.
OK
# Loop over all expected metric_families
for
metric_family
,
suffix_values_list
in
EXPECTED_VALUES
.
items
():
if
(
metric_family
not
in
EXPECTED_METRICS_V1
)
or
(
expected_values
=
_get_expected_values
(
num_requests
,
prompt_ids
,
max_tokens
)
for
metric_family
,
suffix_values_list
in
expected_values
.
items
():
if
metric_family
not
in
EXPECTED_METRICS_V1
or
(
not
server
.
show_hidden_metrics
and
metric_family
in
HIDDEN_DEPRECATED_METRICS
):
...
...
@@ -217,6 +233,11 @@ EXPECTED_METRICS_V1 = [
"vllm:request_decode_time_seconds_count"
,
]
EXPECTED_METRICS_MM
=
[
"vllm:mm_cache_queries"
,
"vllm:mm_cache_hits"
,
]
HIDDEN_DEPRECATED_METRICS
:
list
[
str
]
=
[
"vllm:gpu_cache_usage_perc"
,
"vllm:gpu_prefix_cache_queries"
,
...
...
@@ -231,19 +252,43 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
async
def
test_metrics_exist
(
server
:
RemoteOpenAIServer
,
client
:
openai
.
AsyncClient
,
model_key
:
str
,
):
model_name
=
MODELS
[
model_key
]
# sending a request triggers the metrics to be logged.
if
model_key
==
"text"
:
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
model_name
,
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
temperature
=
0.0
,
)
else
:
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
_IMAGE_URL
}},
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}
],
max_tokens
=
5
,
temperature
=
0.0
,
)
response
=
requests
.
get
(
server
.
url_for
(
"metrics"
))
assert
response
.
status_code
==
HTTPStatus
.
OK
for
metric
in
EXPECTED_METRICS_V1
:
expected_metrics
=
EXPECTED_METRICS_V1
if
model_key
==
"multimodal"
:
# NOTE: Don't use in-place assignment
expected_metrics
=
expected_metrics
+
EXPECTED_METRICS_MM
for
metric
in
expected_metrics
:
if
metric
in
HIDDEN_DEPRECATED_METRICS
and
not
server
.
show_hidden_metrics
:
continue
assert
metric
in
response
.
text
...
...
@@ -253,9 +298,14 @@ async def test_metrics_exist(
async
def
test_abort_metrics_reset
(
server
:
RemoteOpenAIServer
,
client
:
openai
.
AsyncClient
,
model_key
:
str
,
):
model_name
=
MODELS
[
model_key
]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
prompt_ids
=
tokenizer
.
encode
(
_PROMPT
)
running_requests
,
waiting_requests
,
kv_cache_usage
=
_get_running_metrics_from_api
(
server
server
,
)
# Expect no running requests or kvcache usage
...
...
@@ -268,8 +318,8 @@ async def test_abort_metrics_reset(
for
_
in
range
(
3
):
task
=
asyncio
.
create_task
(
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
_TOKENIZED_PROMPT
,
model
=
model_name
,
prompt
=
prompt_ids
,
max_tokens
=
100
,
# Long generation to give time to abort
temperature
=
0.0
,
)
...
...
@@ -281,7 +331,7 @@ async def test_abort_metrics_reset(
# Check that we have running requests
running_requests
,
waiting_requests
,
kv_cache_usage
=
_get_running_metrics_from_api
(
server
server
,
)
# Expect running requests and kvcache usage
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
ad430a67
...
...
@@ -20,7 +20,6 @@ from vllm.v1.core.kv_cache_utils import (
BlockHash
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
generate_scheduler_kv_cache_config
,
...
...
@@ -42,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.metrics.stats
import
CachingMetrics
,
PrefixCacheStats
from
vllm.v1.request
import
Request
pytestmark
=
pytest
.
mark
.
cpu_test
...
...
@@ -536,7 +535,7 @@ def test_metrics():
"""
Test the prefix caching metrics.
"""
metrics
=
Prefix
CachingMetrics
(
max_recent_requests
=
5
)
metrics
=
CachingMetrics
(
max_recent_requests
=
5
)
assert
metrics
.
hit_rate
==
0.0
metrics
.
observe
(
_stats
(
1
,
20
,
9
))
...
...
@@ -568,7 +567,7 @@ def test_metrics_empty_stats():
"""
Test the prefix caching metrics with empty stats.
"""
metrics
=
Prefix
CachingMetrics
(
max_recent_requests
=
5
)
metrics
=
CachingMetrics
(
max_recent_requests
=
5
)
metrics
.
observe
(
_stats
(
0
,
0
,
0
))
metrics
.
observe
(
_stats
(
1
,
20
,
9
))
metrics
.
observe
(
_stats
(
0
,
0
,
0
))
...
...
tests/v1/distributed/test_async_llm_dp.py
View file @
ad430a67
...
...
@@ -17,7 +17,7 @@ from vllm.sampling_params import RequestOutputKind
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.core_client
import
DPAsyncMPClient
from
vllm.v1.metrics.loggers
import
StatLoggerBase
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
IterationStats
,
MultiModalCacheStats
,
SchedulerStats
DP_SIZE
=
int
(
os
.
getenv
(
"DP_SIZE"
,
2
))
...
...
@@ -93,6 +93,7 @@ async def test_load(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
mm_cache_stats
:
Optional
[
MultiModalCacheStats
]
=
None
,
engine_idx
:
int
=
0
,
):
if
iteration_stats
:
...
...
vllm/entrypoints/llm.py
View file @
ad430a67
...
...
@@ -354,6 +354,10 @@ class LLM:
else
:
self
.
llm_engine
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
def
reset_mm_cache
(
self
)
->
None
:
self
.
processor
.
clear_mm_cache
()
self
.
llm_engine
.
reset_mm_cache
()
def
get_default_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
ad430a67
...
...
@@ -274,6 +274,10 @@ class OpenAIServing:
self
.
model_config
=
self
.
models
.
model_config
self
.
max_model_len
=
self
.
model_config
.
max_model_len
async
def
reset_mm_cache
(
self
)
->
None
:
self
.
processor
.
clear_mm_cache
()
await
self
.
engine_client
.
reset_mm_cache
()
async
def
beam_search
(
self
,
prompt
:
PromptType
,
...
...
vllm/executor/executor_base.py
View file @
ad430a67
...
...
@@ -169,6 +169,10 @@ class ExecutorBase(ABC):
assert
s
==
sets
[
0
],
"All workers should have the same LORAs."
return
sets
[
0
]
def
reset_mm_cache
(
self
)
->
None
:
"""Reset the multi-modal cache in each worker."""
self
.
collective_rpc
(
"reset_mm_cache"
)
def
start_profile
(
self
)
->
None
:
self
.
collective_rpc
(
"start_profile"
)
...
...
vllm/executor/uniproc_executor.py
View file @
ad430a67
...
...
@@ -12,11 +12,8 @@ import torch.distributed as dist
import
vllm.envs
as
envs
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
worker_receiver_cache_from_config
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
,
run_method
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.executor.utils
import
get_and_update_mm_cache
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -30,16 +27,13 @@ class UniProcExecutor(ExecutorBase):
"""Initialize the worker and load the model."""
self
.
driver_worker
=
WorkerWrapperBase
(
vllm_config
=
self
.
vllm_config
,
rpc_rank
=
0
)
distributed_init_method
,
rank
,
local_rank
=
self
.
_distributed_args
()
is_driver_worker
=
True
kwargs
=
dict
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
is_driver_worker
,
)
self
.
mm_receiver_cache
=
worker_receiver_cache_from_config
(
self
.
vllm_config
,
MULTIMODAL_REGISTRY
,
Lock
()
is_driver_worker
=
True
,
shared_worker_lock
=
Lock
(),
)
self
.
async_output_thread
:
Optional
[
ThreadPoolExecutor
]
=
None
...
...
@@ -74,8 +68,6 @@ class UniProcExecutor(ExecutorBase):
)
->
list
[
Any
]:
if
kwargs
is
None
:
kwargs
=
{}
if
self
.
mm_receiver_cache
is
not
None
and
method
==
"execute_model"
:
get_and_update_mm_cache
(
self
.
mm_receiver_cache
,
args
)
if
not
non_block
:
return
[
run_method
(
self
.
driver_worker
,
method
,
args
,
kwargs
)]
...
...
vllm/inputs/preprocess.py
View file @
ad430a67
...
...
@@ -19,6 +19,7 @@ from vllm.multimodal.inputs import (
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils.jsontree
import
json_iter_leaves
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
.data
import
(
DecoderOnlyInputs
,
...
...
@@ -56,6 +57,8 @@ class InputPreprocessor:
self
.
mm_registry
=
mm_registry
self
.
mm_processor_cache
=
mm_processor_cache
self
.
mm_cache_stats
=
MultiModalCacheStats
()
if
mm_processor_cache
else
None
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
...
...
@@ -664,14 +667,13 @@ class InputPreprocessor:
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
def
preprocess
(
def
_
preprocess
(
self
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
ProcessorInputs
:
"""Preprocess the input prompt."""
if
self
.
model_config
.
is_encoder_decoder
:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
...
...
@@ -694,6 +696,40 @@ class InputPreprocessor:
mm_uuids
=
mm_uuids
,
)
def
clear_cache
(
self
)
->
None
:
def
preprocess
(
self
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
ProcessorInputs
:
"""Preprocess the input prompt."""
res
=
self
.
_preprocess
(
prompt
,
tokenization_kwargs
,
mm_uuids
=
mm_uuids
,
)
if
self
.
mm_processor_cache
and
self
.
mm_cache_stats
is
not
None
:
delta
=
self
.
mm_processor_cache
.
make_stats
(
delta
=
True
)
self
.
mm_cache_stats
.
requests
+=
1
self
.
mm_cache_stats
.
queries
+=
delta
.
total
self
.
mm_cache_stats
.
hits
+=
delta
.
hits
return
res
def
stat_mm_cache
(
self
)
->
Optional
[
MultiModalCacheStats
]:
mm_cache_stats
=
self
.
mm_cache_stats
if
mm_cache_stats
is
None
:
return
None
self
.
mm_cache_stats
=
MultiModalCacheStats
()
return
mm_cache_stats
def
clear_mm_cache
(
self
)
->
None
:
if
self
.
mm_processor_cache
is
not
None
:
self
.
mm_processor_cache
.
clear_cache
()
if
self
.
mm_cache_stats
is
not
None
:
self
.
mm_cache_stats
.
reset
=
True
vllm/multimodal/cache.py
View file @
ad430a67
...
...
@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.shm_object_storage import (
from
vllm.envs
import
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from
vllm.logger
import
init_logger
from
vllm.utils
import
GiB_bytes
,
MiB_bytes
from
vllm.utils.cache
import
LRUCache
from
vllm.utils.cache
import
CacheInfo
,
LRUCache
from
vllm.utils.jsontree
import
json_count_leaves
,
json_map_leaves
,
json_reduce_leaves
from
.inputs
import
(
...
...
@@ -302,6 +302,16 @@ class BaseMultiModalProcessorCache(
"""
return
[
self
.
is_cached_item
(
mm_hash
)
for
mm_hash
in
mm_hashes
]
@
abstractmethod
def
make_stats
(
self
,
*
,
delta
:
bool
=
False
)
->
CacheInfo
:
"""
Get (and reset) the multi-modal cache stats.
Returns:
The current multi-modal caching stats.
"""
raise
NotImplementedError
class
MultiModalProcessorOnlyCache
(
BaseMultiModalProcessorCache
):
"""
...
...
@@ -347,6 +357,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
def
clear_cache
(
self
)
->
None
:
self
.
_cache
.
clear
()
@
override
def
make_stats
(
self
,
*
,
delta
:
bool
=
False
)
->
CacheInfo
:
return
self
.
_cache
.
stat
(
delta
=
delta
)
class
MultiModalProcessorSenderCache
(
BaseMultiModalProcessorCache
):
"""
...
...
@@ -397,6 +411,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
def
clear_cache
(
self
)
->
None
:
self
.
_cache
.
clear
()
@
override
def
make_stats
(
self
,
*
,
delta
:
bool
=
False
)
->
CacheInfo
:
return
self
.
_cache
.
stat
(
delta
=
delta
)
class
ShmObjectStoreSenderCache
(
BaseMultiModalProcessorCache
):
"""
...
...
@@ -430,6 +448,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
# cache (prompt_updates, modality) for P0 only
self
.
_p0_cache
:
dict
[
str
,
tuple
[
Sequence
[
ResolvedPromptUpdate
],
str
]]
=
{}
self
.
_hits
=
0
self
.
_total
=
0
self
.
_last_info
=
CacheInfo
(
hits
=
0
,
total
=
0
)
def
_stat
(
self
,
*
,
delta
:
bool
=
False
)
->
CacheInfo
:
info
=
CacheInfo
(
hits
=
self
.
_hits
,
total
=
self
.
_total
)
if
delta
:
info_delta
=
info
-
self
.
_last_info
self
.
_last_info
=
info
info
=
info_delta
return
info
@
override
def
is_cached_item
(
self
,
mm_hash
:
str
)
->
bool
:
return
self
.
_shm_cache
.
is_cached
(
mm_hash
)
...
...
@@ -441,12 +473,17 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
mm_hash
:
str
,
)
->
MultiModalProcessorCacheOutItem
:
if
self
.
_shm_cache
.
is_cached
(
mm_hash
):
self
.
_hits
+=
1
self
.
_total
+=
1
address
,
monotonic_id
=
self
.
_shm_cache
.
get_cached
(
mm_hash
)
prompt_updates
,
modality
=
self
.
_p0_cache
[
mm_hash
]
return
self
.
address_as_item
(
address
,
monotonic_id
,
modality
),
prompt_updates
assert
mm_item
is
not
None
,
f
"Expected a cached item for
{
mm_hash
=
}
"
self
.
_total
+=
1
try
:
address
,
monotonic_id
=
self
.
_shm_cache
.
put
(
mm_hash
,
mm_item
[
0
])
# Try to remove dangling items if p0 cache is too large.
...
...
@@ -469,6 +506,14 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self
.
_shm_cache
.
clear
()
self
.
_p0_cache
.
clear
()
self
.
_hits
=
0
self
.
_total
=
0
self
.
_last_info
=
CacheInfo
(
hits
=
0
,
total
=
0
)
@
override
def
make_stats
(
self
,
*
,
delta
:
bool
=
False
)
->
CacheInfo
:
return
self
.
_stat
(
delta
=
delta
)
def
remove_dangling_items
(
self
)
->
None
:
"""Remove items that are no longer in the shared memory cache."""
cached_hashes
=
self
.
_shm_cache
.
key_index
.
keys
()
...
...
vllm/v1/core/kv_cache_utils.py
View file @
ad430a67
...
...
@@ -4,7 +4,7 @@
import
copy
import
os
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Sequence
from
dataclasses
import
dataclass
from
typing
import
Any
,
Callable
,
NewType
,
Optional
,
Union
...
...
@@ -23,7 +23,6 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
# BlockHash represents the hash of a single KV-cache block used for
...
...
@@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
NONE_HASH
=
BlockHash
(
hash_fn
(
hash_seed
))
class PrefixCachingMetrics:
    """Prefix-cache hit rate aggregated over the most recent N requests.

    Args:
        max_recent_requests: Upper bound on the number of recent requests
            kept in the aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000):
        self.max_recent_requests = max_recent_requests
        # One (requests, queries, hits) entry per observed stats object,
        # oldest first.
        self.query_queue: deque[tuple[int, int, int]] = deque()
        # Running totals over the entries currently held in the queue.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

    def observe(self, stats: PrefixCacheStats):
        """Fold one set of prefix-cache stats into the sliding window.

        Called with the information gathered while new requests are being
        scheduled and look for computed blocks. Once the window holds more
        than `max_recent_requests` requests, the oldest entries are dropped.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache ran before this update, so start from a clean
        # slate before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Skip empty stats so they cannot push useful entries out of the
        # sliding window.
        if stats.requests == 0:
            return

        window = self.query_queue
        window.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Evict from the old end until the request budget is respected.
        # NOTE: The entry that was just added is always kept.
        while (
            self.aggregated_requests > self.max_recent_requests
            and len(window) > 1
        ):
            evicted_requests, evicted_queries, evicted_hits = window.popleft()
            self.aggregated_requests -= evicted_requests
            self.aggregated_query_total -= evicted_queries
            self.aggregated_query_hit -= evicted_hits

    def reset(self):
        """Drop every recorded entry and zero the running totals."""
        self.query_queue.clear()
        self.aggregated_query_hit = 0
        self.aggregated_query_total = 0
        self.aggregated_requests = 0

    @property
    def hit_rate(self) -> float:
        """Hit rate over the requests currently in the window (0.0 if none)."""
        total_queries = self.aggregated_query_total
        if not total_queries:
            return 0.0
        return self.aggregated_query_hit / total_queries
@
dataclass
class
KVCacheBlock
:
"""KV-cache block metadata."""
...
...
vllm/v1/engine/async_llm.py
View file @
ad430a67
...
...
@@ -463,6 +463,7 @@ class AsyncLLM(EngineClient):
output_processor
=
self
.
output_processor
log_stats
=
self
.
log_stats
logger_manager
=
self
.
logger_manager
processor
=
self
.
processor
async
def
output_handler
():
try
:
...
...
@@ -511,6 +512,7 @@ class AsyncLLM(EngineClient):
engine_idx
=
outputs
.
engine_index
,
scheduler_stats
=
outputs
.
scheduler_stats
,
iteration_stats
=
iteration_stats
,
mm_cache_stats
=
processor
.
stat_mm_cache
(),
)
except
Exception
as
e
:
logger
.
exception
(
"AsyncLLM output_handler failed."
)
...
...
@@ -660,7 +662,7 @@ class AsyncLLM(EngineClient):
await
asyncio
.
gather
(
*
coros
)
async
def
reset_mm_cache
(
self
)
->
None
:
self
.
processor
.
clear_cache
()
self
.
processor
.
clear_
mm_
cache
()
await
self
.
engine_core
.
reset_mm_cache_async
()
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
...
...
vllm/v1/engine/core.py
View file @
ad430a67
...
...
@@ -319,7 +319,7 @@ class EngineCore:
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# type: ignore
)
return
(
engine_core_outputs
,
scheduler_output
.
total_num_scheduled_tokens
>
0
)
...
...
@@ -400,16 +400,19 @@ class EngineCore:
def
reset_mm_cache
(
self
):
# NOTE: Since this is mainly for debugging, we don't attempt to
# re-sync the internal caches (P0
processor, P0 mirror, P1 mirro
r)
# re-sync the internal caches (P0
sender, P1 receive
r)
if
self
.
scheduler
.
has_unfinished_requests
():
logger
.
warning
(
"Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches."
)
# The cache either exists in EngineCore or WorkerWrapperBase
if
self
.
mm_receiver_cache
is
not
None
:
self
.
mm_receiver_cache
.
clear_cache
()
self
.
model_executor
.
reset_mm_cache
()
def
reset_prefix_cache
(
self
):
self
.
scheduler
.
reset_prefix_cache
()
...
...
vllm/v1/engine/llm_engine.py
View file @
ad430a67
...
...
@@ -306,9 +306,11 @@ class LLMEngine:
# 4) Record stats
if
self
.
logger_manager
is
not
None
:
assert
outputs
.
scheduler_stats
is
not
None
self
.
logger_manager
.
record
(
scheduler_stats
=
outputs
.
scheduler_stats
,
iteration_stats
=
iteration_stats
,
mm_cache_stats
=
self
.
processor
.
stat_mm_cache
(),
)
self
.
do_log_stats_with_interval
()
...
...
@@ -321,7 +323,7 @@ class LLMEngine:
self
.
engine_core
.
profile
(
False
)
def
reset_mm_cache
(
self
):
self
.
processor
.
clear_cache
()
self
.
processor
.
clear_
mm_
cache
()
self
.
engine_core
.
reset_mm_cache
()
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
):
...
...
vllm/v1/engine/processor.py
View file @
ad430a67
...
...
@@ -21,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
vllm.v1.structured_output.backend_guidance
import
validate_guidance_grammar
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
validate_structured_output_request_lm_format_enforcer
,
...
...
@@ -573,5 +574,8 @@ class Processor:
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def
clear_cache
(
self
)
->
None
:
self
.
input_preprocessor
.
clear_cache
()
def
stat_mm_cache
(
self
)
->
Optional
[
MultiModalCacheStats
]:
return
self
.
input_preprocessor
.
stat_mm_cache
()
def
clear_mm_cache
(
self
)
->
None
:
self
.
input_preprocessor
.
clear_mm_cache
()
vllm/v1/executor/multiproc_executor.py
View file @
ad430a67
...
...
@@ -33,8 +33,6 @@ from vllm.distributed.parallel_state import (
get_tp_group
,
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
worker_receiver_cache_from_config
from
vllm.utils
import
(
_maybe_force_spawn
,
decorate_logs
,
...
...
@@ -46,7 +44,6 @@ from vllm.utils import (
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.executor.abstract
import
Executor
,
FailureCallback
from
vllm.v1.executor.utils
import
get_and_update_mm_cache
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
,
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -422,6 +419,7 @@ class WorkerProc:
"rank"
:
rank
,
"distributed_init_method"
:
distributed_init_method
,
"is_driver_worker"
:
is_driver_worker
,
"shared_worker_lock"
:
shared_worker_lock
,
}
wrapper
.
init_worker
(
all_kwargs
)
self
.
worker
=
wrapper
...
...
@@ -445,11 +443,6 @@ class WorkerProc:
)
self
.
async_output_copy_thread
.
start
()
# Initialize multimodal receiver cache if needed
self
.
mm_receiver_cache
=
worker_receiver_cache_from_config
(
vllm_config
,
MULTIMODAL_REGISTRY
,
shared_worker_lock
)
# Initialize device
self
.
worker
.
init_device
()
...
...
@@ -692,12 +685,7 @@ class WorkerProc:
func
=
getattr
(
self
.
worker
,
method
)
elif
isinstance
(
method
,
bytes
):
func
=
partial
(
cloudpickle
.
loads
(
method
),
self
.
worker
)
# retrieve from shm cache if available
if
(
self
.
mm_receiver_cache
is
not
None
and
func
.
__name__
==
"execute_model"
):
get_and_update_mm_cache
(
self
.
mm_receiver_cache
,
args
)
output
=
func
(
*
args
,
**
kwargs
)
except
Exception
as
e
:
# Notes have been introduced in python 3.11
...
...
vllm/v1/executor/utils.py
deleted
100644 → 0
View file @
6f0f570c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.multimodal.cache
import
ShmObjectStoreReceiverCache
from
vllm.v1.core.sched.output
import
SchedulerOutput
def get_and_update_mm_cache(
    receiver_cache: ShmObjectStoreReceiverCache,
    args: tuple[SchedulerOutput],
) -> None:
    """
    Replace the multi-modal features of every newly scheduled request with
    the result of passing them through the receiver cache.

    Args:
        receiver_cache: The receiver cache to update.
        args: Positional arguments of the executor's `execute_model`
            collective_rpc call; only the first element, the
            SchedulerOutput, is read.
    """
    new_requests = args[0].scheduled_new_reqs
    fetch_features = receiver_cache.get_and_update_features
    for new_request in new_requests:
        # Each request's features are swapped for what the cache returns.
        new_request.mm_features = fetch_features(new_request.mm_features)
vllm/v1/metrics/loggers.py
View file @
ad430a67
...
...
@@ -11,10 +11,14 @@ import prometheus_client
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorLogging
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
(
CachingMetrics
,
IterationStats
,
MultiModalCacheStats
,
SchedulerStats
,
)
from
vllm.v1.spec_decode.metrics
import
SpecDecodingLogging
,
SpecDecodingProm
logger
=
init_logger
(
__name__
)
...
...
@@ -38,6 +42,7 @@ class StatLoggerBase(ABC):
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
mm_cache_stats
:
Optional
[
MultiModalCacheStats
]
=
None
,
engine_idx
:
int
=
0
,
):
...
...
...
@@ -53,10 +58,15 @@ class LoggingStatLogger(StatLoggerBase):
self
.
engine_index
=
engine_index
self
.
vllm_config
=
vllm_config
self
.
_reset
(
time
.
monotonic
())
self
.
last_scheduler_stats
=
SchedulerStats
()
# Prefix cache metrics. This cannot be reset.
self
.
last_mm_cache_stats
:
Optional
[
MultiModalCacheStats
]
=
None
# Caching metrics. This cannot be reset.
# TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
prefix_caching_metrics
=
CachingMetrics
()
self
.
mm_caching_metrics
=
CachingMetrics
()
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
kv_tranfer_config
=
self
.
vllm_config
.
kv_transfer_config
self
.
kv_connector_logging
=
KVConnectorLogging
(
kv_tranfer_config
)
...
...
@@ -86,6 +96,7 @@ class LoggingStatLogger(StatLoggerBase):
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
mm_cache_stats
:
Optional
[
MultiModalCacheStats
]
=
None
,
engine_idx
:
int
=
0
,
):
"""Log Stats to standard output."""
...
...
@@ -101,6 +112,11 @@ class LoggingStatLogger(StatLoggerBase):
self
.
kv_connector_logging
.
observe
(
kv_connector_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
if
mm_cache_stats
:
self
.
mm_caching_metrics
.
observe
(
mm_cache_stats
)
self
.
last_mm_cache_stats
=
mm_cache_stats
def
log
(
self
):
now
=
time
.
monotonic
()
prompt_throughput
=
self
.
_get_throughput
(
self
.
num_prompt_tokens
,
now
)
...
...
@@ -125,21 +141,32 @@ class LoggingStatLogger(StatLoggerBase):
self
.
last_prompt_throughput
=
prompt_throughput
# Format and print output.
log_
fn
(
"
Engine %03d: "
"Avg
prompt
throughput: %.1f tokens/s
,
"
"
Avg generation throughput: %.1f tokens/s, "
"
Running: %d reqs,
Waiting: %d reqs
,
"
"GPU KV cache usage: %.1f%%
,
"
log_
parts
=
[
"
Avg prompt throughput: %.1f tokens/s"
,
"Avg
generation
throughput: %.1f tokens/s"
,
"
Running: %d reqs"
,
"Waiting: %d reqs"
,
"GPU KV cache usage: %.1f%%"
,
"Prefix cache hit rate: %.1f%%"
,
self
.
engine_index
,
]
log_args
=
[
prompt_throughput
,
generation_throughput
,
scheduler_stats
.
num_running_reqs
,
scheduler_stats
.
num_waiting_reqs
,
scheduler_stats
.
kv_cache_usage
*
100
,
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
]
if
self
.
last_mm_cache_stats
:
log_parts
.
append
(
"MM cache hit rate: %.1f%%"
)
log_args
.
append
(
self
.
mm_caching_metrics
.
hit_rate
*
100
)
log_fn
(
"Engine %03d: "
+
", "
.
join
(
log_parts
),
self
.
engine_index
,
*
log_args
,
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_connector_logging
.
log
(
log_fn
=
log_fn
)
...
...
@@ -288,6 +315,32 @@ class PrometheusStatLogger(StatLoggerBase):
counter_prefix_cache_hits
,
engine_indexes
,
model_name
)
#
# Multi-modal cache
#
counter_mm_cache_queries
=
self
.
_counter_cls
(
name
=
"vllm:mm_cache_queries"
,
documentation
=
(
"Multi-modal cache queries, in terms of number of queried items."
),
labelnames
=
labelnames
,
)
self
.
counter_mm_cache_queries
=
make_per_engine
(
counter_mm_cache_queries
,
engine_indexes
,
model_name
)
counter_mm_cache_hits
=
self
.
_counter_cls
(
name
=
"vllm:mm_cache_hits"
,
documentation
=
(
"Multi-modal cache hits, in terms of number of cached items."
),
labelnames
=
labelnames
,
)
self
.
counter_mm_cache_hits
=
make_per_engine
(
counter_mm_cache_hits
,
engine_indexes
,
model_name
)
#
# Counters
#
...
...
@@ -657,6 +710,7 @@ class PrometheusStatLogger(StatLoggerBase):
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
mm_cache_stats
:
Optional
[
MultiModalCacheStats
]
=
None
,
engine_idx
:
int
=
0
,
):
"""Log to prometheus."""
...
...
@@ -694,6 +748,10 @@ class PrometheusStatLogger(StatLoggerBase):
scheduler_stats
.
spec_decoding_stats
,
engine_idx
)
if
mm_cache_stats
is
not
None
:
self
.
counter_mm_cache_queries
[
engine_idx
].
inc
(
mm_cache_stats
.
queries
)
self
.
counter_mm_cache_hits
[
engine_idx
].
inc
(
mm_cache_stats
.
hits
)
if
iteration_stats
is
None
:
return
...
...
@@ -871,6 +929,7 @@ class StatLoggerManager:
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
],
mm_cache_stats
:
Optional
[
MultiModalCacheStats
]
=
None
,
engine_idx
:
Optional
[
int
]
=
None
,
):
if
engine_idx
is
None
:
...
...
@@ -878,9 +937,19 @@ class StatLoggerManager:
per_engine_loggers
=
self
.
per_engine_logger_dict
[
engine_idx
]
for
logger
in
per_engine_loggers
:
logger
.
record
(
scheduler_stats
,
iteration_stats
,
engine_idx
)
logger
.
record
(
scheduler_stats
,
iteration_stats
,
mm_cache_stats
=
mm_cache_stats
,
engine_idx
=
engine_idx
,
)
self
.
prometheus_logger
.
record
(
scheduler_stats
,
iteration_stats
,
engine_idx
)
self
.
prometheus_logger
.
record
(
scheduler_stats
,
iteration_stats
,
mm_cache_stats
=
mm_cache_stats
,
engine_idx
=
engine_idx
,
)
def
log
(
self
):
for
per_engine_loggers
in
self
.
per_engine_logger_dict
.
values
():
...
...
vllm/v1/metrics/stats.py
View file @
ad430a67
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
...
...
@@ -13,24 +14,122 @@ if TYPE_CHECKING:
@dataclass
class BaseCacheStats:
    """Hit statistics shared by every cache type."""

    reset: bool = False
    """Set when the cache was reset."""

    requests: int = 0
    """How many requests this update covers."""

    queries: int = 0
    """How many cache queries were made by these requests."""

    hits: int = 0
    """How many of those queries were cache hits."""
class CachingMetrics:
    """Sliding-window hit-rate metrics for a cache.

    Keeps the `(requests, queries, hits)` counts of the most recent updates
    together with their running totals, so that `hit_rate` covers roughly
    the last `max_recent_requests` requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: "BaseCacheStats") -> None:
        """Fold the cache stats of one update into the sliding window.

        When the window holds more than `max_recent_requests` requests, the
        oldest entries are dropped until it fits again; the entry that was
        just added is always kept.

        Args:
            stats: The cache stats of the update. Only its `reset`,
                `requests`, `queries` and `hits` attributes are read.
        """
        # The cache was reset before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Do not append empty stats, so that useful info is not pushed out
        # of the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until the number of requests does not
        # exceed the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self) -> None:
        """Clear the sliding window and zero all aggregated counters."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Hit rate over the current window; 0.0 when nothing was queried."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Hit statistics for the prefix cache.

    Interpretation of the inherited fields:

    - `reset`: Set when `reset_prefix_cache` was invoked.
    - `queries`: Counts the tokens that were queried.
    """

    preempted_requests: int = 0
    """How many requests in this update had been preempted before."""

    preempted_queries: int = 0
    """The `queries` count attributable to preempted requests."""

    preempted_hits: int = 0
    """The `hits` count attributable to preempted requests."""
@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Hit statistics for the multi-modal cache.

    Interpretation of the inherited fields:

    - `reset`: Set when `reset_mm_cache` was invoked.
    - `queries`: Counts the multi-modal data items that were queried.
    """
@
dataclass
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ad430a67
...
...
@@ -508,6 +508,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory
=
self
.
pin_memory
,
)
def reset_mm_cache(self) -> None:
    """Call `reset_cache()` on `self.mm_budget` when it is set.

    Does nothing when `self.mm_budget` is falsy (e.g. `None`).
    """
    if not self.mm_budget:
        return
    self.mm_budget.reset_cache()
def
_get_positions
(
self
,
num_tokens
:
Any
):
if
isinstance
(
num_tokens
,
int
):
if
self
.
uses_mrope
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment