Unverified Commit ad430a67 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Metrics] Log multi-modal cache stats and fix reset (#26285)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6f0f570c
......@@ -442,6 +442,9 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
......
......@@ -371,6 +371,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.sample_from_logits_func = self.sample_from_logits
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
......
......@@ -293,6 +293,9 @@ class TPUWorker:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
......
......@@ -126,6 +126,10 @@ class MultiModalBudget:
return max_items_per_prompt, max_items_per_batch
def reset_cache(self) -> None:
if self.cache is not None:
self.cache.clear_cache()
@dataclass
class AttentionGroup:
......
......@@ -4,7 +4,7 @@
from __future__ import annotations
import os
from typing import Any, Callable, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
import torch
import torch.nn as nn
......@@ -12,7 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
enable_trace_function_call_for_thread,
resolve_obj_by_qualname,
......@@ -21,7 +22,10 @@ from vllm.utils import (
warn_for_unimplemented_methods,
)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.outputs import SamplerOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
......@@ -103,6 +107,11 @@ class WorkerBase:
"""Initialize the KV cache with the given size in blocks."""
raise NotImplementedError
def reset_mm_cache(self) -> None:
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
if callable(reset_fn):
reset_fn()
def get_model(self) -> nn.Module:
raise NotImplementedError
......@@ -114,9 +123,7 @@ class WorkerBase:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self, execute_model_req: ExecuteModelRequest | None = None
) -> list[SamplerOutput] | None:
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
......@@ -125,11 +132,7 @@ class WorkerBase:
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
raise NotImplementedError("Dead V0 code")
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
......@@ -289,6 +292,28 @@ class WorkerWrapperBase:
worker_class,
extended_calls,
)
shared_worker_lock = kwargs.pop("shared_worker_lock", None)
if shared_worker_lock is None:
msg = (
"Missing `shared_worker_lock` argument from executor. "
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
logger.warning_once(msg)
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
......@@ -323,5 +348,34 @@ class WorkerWrapperBase:
logger.exception(msg)
raise e
def __getattr__(self, attr):
def __getattr__(self, attr: str):
return getattr(self.worker, attr)
def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
mm_cache = self.mm_receiver_cache
if mm_cache is None:
return
for req_data in scheduler_output.scheduled_new_reqs:
req_data.mm_features = mm_cache.get_and_update_features(
req_data.mm_features
)
def execute_model(
self,
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None
return self.worker.execute_model(scheduler_output, *args, **kwargs)
def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()
assert self.worker is not None
self.worker.reset_mm_cache()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment