Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad430a67
Unverified
Commit
ad430a67
authored
Oct 10, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 10, 2025
Browse files
[Metrics] Log multi-modal cache stats and fix reset (#26285)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
6f0f570c
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
80 additions
and
12 deletions
+80
-12
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+3
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+4
-0
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+3
-0
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+4
-0
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+66
-12
No files found.
vllm/v1/worker/gpu_worker.py
View file @
ad430a67
...
@@ -442,6 +442,9 @@ class Worker(WorkerBase):
...
@@ -442,6 +442,9 @@ class Worker(WorkerBase):
# the model initialization and profiling.
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
def
reset_mm_cache
(
self
)
->
None
:
self
.
model_runner
.
reset_mm_cache
()
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
return
self
.
model_runner
.
get_model
()
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
ad430a67
...
@@ -371,6 +371,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -371,6 +371,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
else
:
self
.
sample_from_logits_func
=
self
.
sample_from_logits
self
.
sample_from_logits_func
=
self
.
sample_from_logits
def
reset_mm_cache
(
self
)
->
None
:
if
self
.
mm_budget
:
self
.
mm_budget
.
reset_cache
()
def
_update_num_xla_graphs
(
self
,
case_str
):
def
_update_num_xla_graphs
(
self
,
case_str
):
check_comp
=
self
.
check_recompilation
and
not
self
.
enforce_eager
check_comp
=
self
.
check_recompilation
and
not
self
.
enforce_eager
if
not
check_comp
:
if
not
check_comp
:
...
...
vllm/v1/worker/tpu_worker.py
View file @
ad430a67
...
@@ -293,6 +293,9 @@ class TPUWorker:
...
@@ -293,6 +293,9 @@ class TPUWorker:
# the model initialization and profiling.
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
def
reset_mm_cache
(
self
)
->
None
:
self
.
model_runner
.
reset_mm_cache
()
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
return
self
.
model_runner
.
get_model
()
...
...
vllm/v1/worker/utils.py
View file @
ad430a67
...
@@ -126,6 +126,10 @@ class MultiModalBudget:
...
@@ -126,6 +126,10 @@ class MultiModalBudget:
return
max_items_per_prompt
,
max_items_per_batch
return
max_items_per_prompt
,
max_items_per_batch
def
reset_cache
(
self
)
->
None
:
if
self
.
cache
is
not
None
:
self
.
cache
.
clear_cache
()
@
dataclass
@
dataclass
class
AttentionGroup
:
class
AttentionGroup
:
...
...
vllm/v1/worker/worker_base.py
View file @
ad430a67
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
__future__
import
annotations
from
__future__
import
annotations
import
os
import
os
from
typing
import
Any
,
Callable
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
TypeVar
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -12,7 +12,8 @@ import torch.nn as nn
...
@@ -12,7 +12,8 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
worker_receiver_cache_from_config
from
vllm.utils
import
(
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
resolve_obj_by_qualname
,
...
@@ -21,7 +22,10 @@ from vllm.utils import (
...
@@ -21,7 +22,10 @@ from vllm.utils import (
warn_for_unimplemented_methods
,
warn_for_unimplemented_methods
,
)
)
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
from
vllm.v1.outputs
import
SamplerOutput
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -103,6 +107,11 @@ class WorkerBase:
...
@@ -103,6 +107,11 @@ class WorkerBase:
"""Initialize the KV cache with the given size in blocks."""
"""Initialize the KV cache with the given size in blocks."""
raise
NotImplementedError
raise
NotImplementedError
def
reset_mm_cache
(
self
)
->
None
:
reset_fn
=
getattr
(
self
.
model_runner
,
"reset_mm_cache"
,
None
)
if
callable
(
reset_fn
):
reset_fn
()
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -114,9 +123,7 @@ class WorkerBase:
...
@@ -114,9 +123,7 @@ class WorkerBase:
"""Load model onto target device."""
"""Load model onto target device."""
raise
NotImplementedError
raise
NotImplementedError
def
execute_model
(
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
)
->
ModelRunnerOutput
:
self
,
execute_model_req
:
ExecuteModelRequest
|
None
=
None
)
->
list
[
SamplerOutput
]
|
None
:
raise
NotImplementedError
raise
NotImplementedError
def
start_worker_execution_loop
(
self
)
->
None
:
def
start_worker_execution_loop
(
self
)
->
None
:
...
@@ -125,11 +132,7 @@ class WorkerBase:
...
@@ -125,11 +132,7 @@ class WorkerBase:
You can stop the loop by executing a driver worker with an empty output.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
See `stop_remote_worker_execution_loop` for more details.
"""
"""
with
self
.
current_platform
.
inference_mode
():
raise
NotImplementedError
(
"Dead V0 code"
)
while
True
:
output
=
self
.
execute_model
(
execute_model_req
=
None
)
if
output
is
None
:
return
None
def
determine_num_available_blocks
(
self
)
->
tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
tuple
[
int
,
int
]:
"""Determine the number of available blocks for the GPU KV cache and
"""Determine the number of available blocks for the GPU KV cache and
...
@@ -289,6 +292,28 @@ class WorkerWrapperBase:
...
@@ -289,6 +292,28 @@ class WorkerWrapperBase:
worker_class
,
worker_class
,
extended_calls
,
extended_calls
,
)
)
shared_worker_lock
=
kwargs
.
pop
(
"shared_worker_lock"
,
None
)
if
shared_worker_lock
is
None
:
msg
=
(
"Missing `shared_worker_lock` argument from executor. "
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config
=
self
.
vllm_config
.
model_config
.
multimodal_config
if
mm_config
and
mm_config
.
mm_processor_cache_type
==
"shm"
:
raise
ValueError
(
msg
)
else
:
logger
.
warning_once
(
msg
)
self
.
mm_receiver_cache
=
None
else
:
self
.
mm_receiver_cache
=
worker_receiver_cache_from_config
(
self
.
vllm_config
,
MULTIMODAL_REGISTRY
,
shared_worker_lock
,
)
with
set_current_vllm_config
(
self
.
vllm_config
):
with
set_current_vllm_config
(
self
.
vllm_config
):
# To make vLLM config available during worker initialization
# To make vLLM config available during worker initialization
self
.
worker
=
worker_class
(
**
kwargs
)
self
.
worker
=
worker_class
(
**
kwargs
)
...
@@ -323,5 +348,34 @@ class WorkerWrapperBase:
...
@@ -323,5 +348,34 @@ class WorkerWrapperBase:
logger
.
exception
(
msg
)
logger
.
exception
(
msg
)
raise
e
raise
e
def
__getattr__
(
self
,
attr
):
def
__getattr__
(
self
,
attr
:
str
):
return
getattr
(
self
.
worker
,
attr
)
return
getattr
(
self
.
worker
,
attr
)
def
_apply_mm_cache
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
mm_cache
=
self
.
mm_receiver_cache
if
mm_cache
is
None
:
return
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_data
.
mm_features
=
mm_cache
.
get_and_update_features
(
req_data
.
mm_features
)
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
*
args
,
**
kwargs
,
)
->
ModelRunnerOutput
:
self
.
_apply_mm_cache
(
scheduler_output
)
assert
self
.
worker
is
not
None
return
self
.
worker
.
execute_model
(
scheduler_output
,
*
args
,
**
kwargs
)
def
reset_mm_cache
(
self
)
->
None
:
mm_receiver_cache
=
self
.
mm_receiver_cache
if
mm_receiver_cache
is
not
None
:
mm_receiver_cache
.
clear_cache
()
assert
self
.
worker
is
not
None
self
.
worker
.
reset_mm_cache
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment