Unverified commit 22d9a056, authored by jma99_2333 and committed by GitHub
Browse files

Support clear mm and encoder cache (#33452)


Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 13b842f2
...@@ -172,6 +172,7 @@ These endpoints are **only available when the environment variable `VLLM_SERVER_ ...@@ -172,6 +172,7 @@ These endpoints are **only available when the environment variable `VLLM_SERVER_
- `/server_info` - Get detailed server configuration - `/server_info` - Get detailed server configuration
- `/reset_prefix_cache` - Reset prefix cache (can disrupt service) - `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
- `/reset_mm_cache` - Reset multimodal cache (can disrupt service) - `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
- `/reset_encoder_cache` - Reset encoder cache (can disrupt service)
- `/sleep` - Put engine to sleep (causes denial of service) - `/sleep` - Put engine to sleep (causes denial of service)
- `/wake_up` - Wake engine from sleep - `/wake_up` - Wake engine from sleep
- `/is_sleeping` - Check if engine is sleeping - `/is_sleeping` - Check if engine is sleeping
......
...@@ -4,7 +4,10 @@ import pytest ...@@ -4,7 +4,10 @@ import pytest
import torch import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
)
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
...@@ -247,3 +250,88 @@ def test_encoder_cache_mask_based_retrieval(): ...@@ -247,3 +250,88 @@ def test_encoder_cache_mask_based_retrieval():
assert num_embeds_before == 0 assert num_embeds_before == 0
assert num_embeds_in_range == 2 assert num_embeds_in_range == 2
def test_reset_clears_all_state():
    """reset() must drop every tracked entry and restore the full capacity."""
    capacity = 20
    manager = EncoderCacheManager(cache_size=capacity)

    req1 = MockRequest("req1", ["img1", "img2"], [5, 3])
    req2 = MockRequest("req2", ["img3"], [4])

    # Populate the cache with three encoder inputs across two requests.
    for owner, input_id in ((req1, 0), (req1, 1), (req2, 0)):
        manager.allocate(owner, input_id)

    # Release all of them again, in the same order they were allocated.
    for owner, input_id in ((req1, 0), (req1, 1), (req2, 0)):
        manager.free_encoder_input(owner, input_id)

    # Run a further request through can_allocate()/allocate() so the manager
    # is in a non-trivial state before it gets reset.
    req3 = MockRequest("req3", ["img4"], [10])
    manager.can_allocate(req3, 0, int(1e9), 0)
    manager.allocate(req3, 0)

    # Sanity check: there really is state to clear.
    assert len(manager.cached) > 0
    assert manager.num_free_slots < capacity

    manager.reset()

    # Every bookkeeping container is empty and both counters are back at
    # the configured cache size.
    assert len(manager.cached) == 0
    assert len(manager.freeable) == 0
    assert len(manager.freed) == 0
    assert manager.num_free_slots == capacity
    assert manager.num_freeable_slots == capacity
def test_reset_allows_fresh_allocations():
    """After reset() a previously full manager accepts new allocations."""
    manager = EncoderCacheManager(cache_size=10)

    # Fill the cache completely with a single 10-embed input.
    first_request = MockRequest("req1", ["img1"], [10])
    manager.allocate(first_request, 0)
    assert manager.num_free_slots == 0

    manager.reset()

    # A new 8-embed input must now fit, leaving 2 slots free.
    second_request = MockRequest("req2", ["img2"], [8])
    assert manager.can_allocate(second_request, 0, int(1e9), 0)
    manager.allocate(second_request, 0)
    assert manager.num_free_slots == 2

    # Only the post-reset entry is tracked; the old one is gone.
    assert "img2" in manager.cached
    assert "img1" not in manager.cached
def test_encoder_decoder_cache_manager_reset():
    """EncoderDecoderCacheManager.reset() empties its lists and frees all slots."""
    capacity = 20
    manager = EncoderDecoderCacheManager(cache_size=capacity)

    req1 = MockRequest("req1", ["img1"], [5])
    req2 = MockRequest("req2", ["img2"], [3])
    for owner in (req1, req2):
        manager.allocate(owner, 0)

    # Free one request and drain the freed hashes (return value unused) so
    # the manager is in a non-trivial state before the reset.
    manager.free(req1)
    manager.get_freed_mm_hashes()
    assert manager.num_free_slots < capacity

    manager.reset()

    assert len(manager.allocated) == 0
    assert len(manager.to_free) == 0
    assert manager.num_free_slots == capacity
def test_encoder_decoder_cache_manager_reset_allows_fresh_allocations():
    """A full EncoderDecoderCacheManager accepts new inputs after reset()."""
    manager = EncoderDecoderCacheManager(cache_size=10)

    # Use up all ten slots with a single input.
    first_request = MockRequest("req1", ["img1"], [10])
    manager.allocate(first_request, 0)
    assert manager.num_free_slots == 0

    manager.reset()

    # An 8-embed input fits again and leaves 2 slots free.
    second_request = MockRequest("req2", ["img2"], [8])
    assert manager.can_allocate(second_request, 0, int(1e9), 0)
    manager.allocate(second_request, 0)
    assert manager.num_free_slots == 2
    assert "img2" in manager.allocated
...@@ -113,6 +113,11 @@ class EngineClient(ABC): ...@@ -113,6 +113,11 @@ class EngineClient(ABC):
"""Reset the multi-modal cache""" """Reset the multi-modal cache"""
... ...
@abstractmethod
async def reset_encoder_cache(self) -> None:
    """Reset the encoder cache.

    Abstract coroutine with no behavior of its own; concrete engine
    clients are expected to provide the implementation.
    """
@abstractmethod @abstractmethod
async def reset_prefix_cache( async def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False self, reset_running_requests: bool = False, reset_connector: bool = False
......
...@@ -55,6 +55,17 @@ async def reset_mm_cache(raw_request: Request): ...@@ -55,6 +55,17 @@ async def reset_mm_cache(raw_request: Request):
return Response(status_code=200) return Response(status_code=200)
@router.post("/reset_encoder_cache")
async def reset_encoder_cache(raw_request: Request):
    """
    Handle ``POST /reset_encoder_cache`` by asking the engine client bound
    to this request to reset its encoder cache.

    The handler answers with HTTP 200 as soon as the call returns; the API
    server does not currently verify that the cache was actually reset.
    """
    logger.info("Resetting encoder cache...")
    client = engine_client(raw_request)
    await client.reset_encoder_cache()
    return Response(status_code=200)
def attach_router(app: FastAPI): def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE: if not envs.VLLM_SERVER_DEV_MODE:
return return
......
...@@ -77,6 +77,18 @@ class EncoderCacheManager: ...@@ -77,6 +77,18 @@ class EncoderCacheManager:
self.freeable: OrderedDict[str, int] = OrderedDict() self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = [] self.freed: list[str] = []
def reset(self) -> None:
    """Return the encoder cache manager to its initial, empty state.

    Empties the ``cached``, ``freeable`` and ``freed`` containers in place
    and sets both slot counters back to ``cache_size``. Intended to be
    called when model weights are updated, so encoder outputs computed
    with the old weights are not reused.
    """
    capacity = self.cache_size
    self.num_free_slots = capacity
    self.num_freeable_slots = capacity

    # Clear in place (rather than rebinding) so existing references to
    # these containers observe the reset as well.
    for container in (self.cached, self.freeable, self.freed):
        container.clear()
def check_and_update_cache(self, request: Request, input_id: int) -> bool: def check_and_update_cache(self, request: Request, input_id: int) -> bool:
"""Check if encoder output for a specific multimodal input is cached. """Check if encoder output for a specific multimodal input is cached.
...@@ -360,6 +372,12 @@ class EncoderDecoderCacheManager(EncoderCacheManager): ...@@ -360,6 +372,12 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
self.allocated: list[str] = [] self.allocated: list[str] = []
self.to_free: list[str] = [] self.to_free: list[str] = []
def reset(self) -> None:
    """Return the encoder cache manager to its initial, empty state.

    Empties the ``allocated`` and ``to_free`` lists in place and makes all
    ``cache_size`` slots available again.
    """
    for pending in (self.allocated, self.to_free):
        pending.clear()
    self.num_free_slots = self.cache_size
def check_and_update_cache(self, request: Request, input_id: int) -> bool: def check_and_update_cache(self, request: Request, input_id: int) -> bool:
return False return False
......
...@@ -183,6 +183,15 @@ class SchedulerInterface(ABC): ...@@ -183,6 +183,15 @@ class SchedulerInterface(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def reset_encoder_cache(self) -> None:
    """Invalidate every cached encoder output.

    Intended to be invoked when model weights are updated, so that stale
    vision embeddings are not reused.

    Raises:
        NotImplementedError: always; this is the abstract declaration.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def get_request_counts(self) -> tuple[int, int]: def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs).""" """Returns (num_running_reqs, num_waiting_reqs)."""
......
...@@ -1763,6 +1763,14 @@ class Scheduler(SchedulerInterface): ...@@ -1763,6 +1763,14 @@ class Scheduler(SchedulerInterface):
return True return True
def reset_encoder_cache(self) -> None:
    """Invalidate every cached encoder output tracked by the scheduler.

    Delegates to ``self.encoder_cache_manager.reset()``. Intended to be
    invoked when model weights are updated, so that stale vision
    embeddings are not reused.
    """
    cache_manager = self.encoder_cache_manager
    cache_manager.reset()
def make_stats( def make_stats(
self, self,
spec_decoding_stats: SpecDecodingStats | None = None, spec_decoding_stats: SpecDecodingStats | None = None,
...@@ -1788,6 +1796,7 @@ class Scheduler(SchedulerInterface): ...@@ -1788,6 +1796,7 @@ class Scheduler(SchedulerInterface):
num_running_reqs=len(self.running), num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting),
kv_cache_usage=self.kv_cache_manager.usage, kv_cache_usage=self.kv_cache_manager.usage,
encoder_cache_usage=self._get_encoder_cache_usage(),
prefix_cache_stats=prefix_cache_stats, prefix_cache_stats=prefix_cache_stats,
connector_prefix_cache_stats=connector_prefix_cache_stats, connector_prefix_cache_stats=connector_prefix_cache_stats,
kv_cache_eviction_events=eviction_events, kv_cache_eviction_events=eviction_events,
...@@ -1797,6 +1806,14 @@ class Scheduler(SchedulerInterface): ...@@ -1797,6 +1806,14 @@ class Scheduler(SchedulerInterface):
perf_stats=perf_stats, perf_stats=perf_stats,
) )
def _get_encoder_cache_usage(self) -> float:
"""Get encoder cache usage as a fraction (0.0 to 1.0)."""
ecm = self.encoder_cache_manager
if ecm.cache_size == 0:
return 0.0
used_slots = ecm.cache_size - ecm.num_free_slots
return used_slots / ecm.cache_size
def make_spec_decoding_stats( def make_spec_decoding_stats(
self, self,
spec_decoding_stats: SpecDecodingStats | None, spec_decoding_stats: SpecDecodingStats | None,
......
...@@ -882,6 +882,9 @@ class AsyncLLM(EngineClient): ...@@ -882,6 +882,9 @@ class AsyncLLM(EngineClient):
reset_running_requests, reset_connector reset_running_requests, reset_connector
) )
async def reset_encoder_cache(self) -> None:
    """Reset the encoder cache by awaiting the engine-core client's
    ``reset_encoder_cache_async()``."""
    core_client = self.engine_core
    await core_client.reset_encoder_cache_async()
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache() await self.reset_prefix_cache()
await self.engine_core.sleep_async(level) await self.engine_core.sleep_async(level)
......
...@@ -565,6 +565,26 @@ class EngineCore: ...@@ -565,6 +565,26 @@ class EngineCore:
reset_running_requests, reset_connector reset_running_requests, reset_connector
) )
def reset_encoder_cache(self) -> None:
    """Invalidate every cached encoder output held by this engine core.

    Resets the scheduler's encoder cache first and then the model
    executor's encoder cache. Intended to be invoked when model weights
    are updated, so that stale vision embeddings computed with the old
    weights are not reused.

    If the scheduler still has unfinished requests, a warning is logged
    and the reset proceeds anyway.
    """
    scheduler = self.scheduler

    # NOTE: Since this is mainly for debugging, we don't attempt to
    # re-sync the internal caches (P0 sender, P1 receiver)
    requests_in_flight = scheduler.has_unfinished_requests()
    if requests_in_flight:
        logger.warning(
            "Resetting the encoder cache when requests are "
            "in progress may lead to desynced internal caches."
        )

    # Logical state: the scheduler-side encoder cache.
    scheduler.reset_encoder_cache()
    # Physical storage: the executor-side encoder cache.
    self.model_executor.reset_encoder_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.model_executor.sleep(level) self.model_executor.sleep(level)
......
...@@ -144,6 +144,9 @@ class EngineCoreClient(ABC): ...@@ -144,6 +144,9 @@ class EngineCoreClient(ABC):
) -> bool: ) -> bool:
raise NotImplementedError raise NotImplementedError
def reset_encoder_cache(self) -> None:
    """Reset the encoder cache (synchronous variant).

    Raises:
        NotImplementedError: always in this base definition.
    """
    raise NotImplementedError
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -216,6 +219,9 @@ class EngineCoreClient(ABC): ...@@ -216,6 +219,9 @@ class EngineCoreClient(ABC):
) -> bool: ) -> bool:
raise NotImplementedError raise NotImplementedError
async def reset_encoder_cache_async(self) -> None:
    """Reset the encoder cache (asynchronous variant).

    Raises:
        NotImplementedError: always in this base definition, once awaited.
    """
    raise NotImplementedError
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -300,6 +306,9 @@ class InprocClient(EngineCoreClient): ...@@ -300,6 +306,9 @@ class InprocClient(EngineCoreClient):
reset_running_requests, reset_connector reset_running_requests, reset_connector
) )
def reset_encoder_cache(self) -> None:
    """Reset the encoder cache by calling straight into the in-process
    engine core."""
    core = self.engine_core
    core.reset_encoder_cache()
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level) self.engine_core.sleep(level)
...@@ -765,6 +774,9 @@ class SyncMPClient(MPClient): ...@@ -765,6 +774,9 @@ class SyncMPClient(MPClient):
"reset_prefix_cache", reset_running_requests, reset_connector "reset_prefix_cache", reset_running_requests, reset_connector
) )
def reset_encoder_cache(self) -> None:
    """Reset the encoder cache via the ``"reset_encoder_cache"`` utility
    call; whatever the call returns is discarded."""
    utility_name = "reset_encoder_cache"
    self.call_utility(utility_name)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request) return self.call_utility("add_lora", lora_request)
...@@ -973,6 +985,9 @@ class AsyncMPClient(MPClient): ...@@ -973,6 +985,9 @@ class AsyncMPClient(MPClient):
"reset_prefix_cache", reset_running_requests, reset_connector "reset_prefix_cache", reset_running_requests, reset_connector
) )
async def reset_encoder_cache_async(self) -> None:
    """Reset the encoder cache by awaiting the ``"reset_encoder_cache"``
    utility call; whatever the call returns is discarded."""
    utility_name = "reset_encoder_cache"
    await self.call_utility_async(utility_name)
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
await self.call_utility_async("sleep", level) await self.call_utility_async("sleep", level)
......
...@@ -332,6 +332,14 @@ class LLMEngine: ...@@ -332,6 +332,14 @@ class LLMEngine:
reset_running_requests, reset_connector reset_running_requests, reset_connector
) )
def reset_encoder_cache(self) -> None:
    """Invalidate every cached encoder output.

    Delegates to ``self.engine_core.reset_encoder_cache()``. Intended to
    be invoked when model weights are updated, so that stale vision
    embeddings computed with the old weights are not reused.
    """
    core_client = self.engine_core
    core_client.reset_encoder_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.engine_core.sleep(level) self.engine_core.sleep(level)
......
...@@ -294,6 +294,10 @@ class Executor(ABC): ...@@ -294,6 +294,10 @@ class Executor(ABC):
"""Reset the multi-modal cache in each worker.""" """Reset the multi-modal cache in each worker."""
self.collective_rpc("reset_mm_cache") self.collective_rpc("reset_mm_cache")
def reset_encoder_cache(self) -> None:
    """Clear cached encoder outputs in each worker.

    Issues a ``"reset_encoder_cache"`` collective RPC; the RPC result is
    discarded.
    """
    rpc_method = "reset_encoder_cache"
    self.collective_rpc(rpc_method)
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
if self.is_sleeping: if self.is_sleeping:
logger.warning("Executor is already sleeping.") logger.warning("Executor is already sleeping.")
......
...@@ -173,6 +173,7 @@ class SchedulerStats: ...@@ -173,6 +173,7 @@ class SchedulerStats:
current_wave: int = 0 current_wave: int = 0
kv_cache_usage: float = 0.0 kv_cache_usage: float = 0.0
encoder_cache_usage: float = 0.0
prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
connector_prefix_cache_stats: PrefixCacheStats | None = None connector_prefix_cache_stats: PrefixCacheStats | None = None
......
...@@ -720,6 +720,14 @@ class GPUModelRunner( ...@@ -720,6 +720,14 @@ class GPUModelRunner(
if self.mm_budget: if self.mm_budget:
self.mm_budget.reset_cache() self.mm_budget.reset_cache()
def reset_encoder_cache(self) -> None:
    """Empty the model runner's encoder cache in place.

    Intended to be invoked when model weights are updated, so that stale
    embeddings computed with the old weights are not reused.
    """
    cache = self.encoder_cache
    cache.clear()
@torch.inference_mode() @torch.inference_mode()
def init_fp8_kv_scales(self) -> None: def init_fp8_kv_scales(self) -> None:
""" """
......
...@@ -539,6 +539,9 @@ class Worker(WorkerBase): ...@@ -539,6 +539,9 @@ class Worker(WorkerBase):
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache() self.model_runner.reset_mm_cache()
def reset_encoder_cache(self) -> None:
    """Reset the encoder cache by forwarding the request to this worker's
    model runner."""
    runner = self.model_runner
    runner.reset_encoder_cache()
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment