Unverified commit 61e0a506, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Avoid repeatedly creating dummy data during engine startup (#17935)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1df491c5
...@@ -1232,6 +1232,9 @@ class AsyncLLMEngine(EngineClient): ...@@ -1232,6 +1232,9 @@ class AsyncLLMEngine(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
self.engine.stop_profile() self.engine.stop_profile()
async def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache by delegating to the wrapped engine.

    The value returned by ``self.engine.reset_mm_cache()`` is discarded;
    this coroutine always resolves to ``None``.
    """
    self.engine.reset_mm_cache()
async def reset_prefix_cache(self, async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None: device: Optional[Device] = None) -> None:
self.engine.reset_prefix_cache(device) self.engine.reset_prefix_cache(device)
......
...@@ -409,6 +409,9 @@ class LLMEngine: ...@@ -409,6 +409,9 @@ class LLMEngine:
# the next step without re-scheduling. # the next step without re-scheduling.
self._skip_scheduling_next_step = False self._skip_scheduling_next_step = False
# Don't keep the dummy data in memory
self.reset_mm_cache()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -913,6 +916,10 @@ class LLMEngine: ...@@ -913,6 +916,10 @@ class LLMEngine:
""" """
return self.scheduler[virtual_engine].has_unfinished_seqs() return self.scheduler[virtual_engine].has_unfinished_seqs()
def reset_mm_cache(self) -> bool:
    """Reset the multi-modal cache.

    Returns:
        The result of ``reset_processor_cache()`` on the input
        preprocessor's multi-modal registry.
    """
    return self.input_preprocessor.mm_registry.reset_processor_cache()
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for all devices."""
......
...@@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum): ...@@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2 STOP_PROFILE = 2
class RPCResetMultiModalCacheRequest(Enum):
    """RPC request asking the engine to reset its multi-modal cache."""

    RESET = 1
@dataclass @dataclass
class RPCResetPrefixCacheRequest: class RPCResetPrefixCacheRequest:
device: Device device: Device
...@@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse: ...@@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse:
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest, RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest, RPCSleepRequest, RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest, RPCIsSleepingRequest] RPCWakeUpRequest, RPCIsSleepingRequest]
......
...@@ -31,6 +31,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -31,6 +31,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCIsSleepingResponse, RPCIsSleepingResponse,
RPCLoadAdapterRequest, RPCLoadAdapterRequest,
RPCProcessRequest, RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest, RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest, RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse, RPCStartupResponse,
...@@ -687,6 +688,13 @@ class MQLLMEngineClient(EngineClient): ...@@ -687,6 +688,13 @@ class MQLLMEngineClient(EngineClient):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache"""
    # Sent as a one-way RPC on the input socket, the same way the
    # STOP_PROFILE request is sent.
    await self._send_one_way_rpc_request(
        request=RPCResetMultiModalCacheRequest.RESET,
        socket=self.input_socket)
async def reset_prefix_cache(self, async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None: device: Optional[Device] = None) -> None:
"""Reset the prefix cache""" """Reset the prefix cache"""
......
...@@ -22,6 +22,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -22,6 +22,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCIsSleepingResponse, RPCIsSleepingResponse,
RPCLoadAdapterRequest, RPCLoadAdapterRequest,
RPCProcessRequest, RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest, RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest, RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse, RPCStartupResponse,
...@@ -269,6 +270,8 @@ class MQLLMEngine: ...@@ -269,6 +270,8 @@ class MQLLMEngine:
self.stop_profile() self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest): elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request) self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetMultiModalCacheRequest):
self.reset_mm_cache()
elif isinstance(request, RPCResetPrefixCacheRequest): elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache() self.reset_prefix_cache()
elif isinstance(request, RPCSleepRequest): elif isinstance(request, RPCSleepRequest):
...@@ -409,6 +412,9 @@ class MQLLMEngine: ...@@ -409,6 +412,9 @@ class MQLLMEngine:
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.engine.stop_profile() self.engine.stop_profile()
def reset_mm_cache(self) -> bool:
    """Reset the multi-modal cache of the wrapped engine.

    Returns:
        Whatever ``self.engine.reset_mm_cache()`` returns.
    """
    return self.engine.reset_mm_cache()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache() return self.engine.reset_prefix_cache()
......
...@@ -278,6 +278,11 @@ class EngineClient(ABC): ...@@ -278,6 +278,11 @@ class EngineClient(ABC):
"""Start profiling the engine""" """Start profiling the engine"""
... ...
@abstractmethod
async def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache"""
    ...
@abstractmethod @abstractmethod
async def reset_prefix_cache(self, async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None: device: Optional[Device] = None) -> None:
......
...@@ -150,6 +150,10 @@ async def build_async_engine_client( ...@@ -150,6 +150,10 @@ async def build_async_engine_client(
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine: engine_args, args.disable_frontend_multiprocessing) as engine:
# Don't keep the dummy data in memory
await engine.reset_mm_cache()
yield engine yield engine
......
...@@ -1026,6 +1026,11 @@ class ProcessingCache: ...@@ -1026,6 +1026,11 @@ class ProcessingCache:
def put_item(self, item: ProcessingCacheItem) -> None: def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value self._cache[item.key] = item.value
def reset(self) -> bool:
    """Remove every entry from the cache.

    Returns:
        Always ``True``.
    """
    self._cache.clear()
    return True
class BaseProcessingInfo: class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing.""" """Base class to provide the information necessary for data processing."""
......
...@@ -88,6 +88,12 @@ class MultiModalRegistry: ...@@ -88,6 +88,12 @@ class MultiModalRegistry:
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
def reset_processor_cache(self) -> bool:
    """Reset the multi-modal processing cache.

    Returns:
        Always ``True``; the result of the underlying ``reset()`` call is
        not inspected.
    """
    self._processing_cache.reset()
    return True  # Success
@deprecated("Legacy input processor/mapper pipeline has been removed. " @deprecated("Legacy input processor/mapper pipeline has been removed. "
"Please update your model runner to use " "Please update your model runner to use "
"`seq_group_metadata.multi_modal_data` directly without " "`seq_group_metadata.multi_modal_data` directly without "
...@@ -106,7 +112,7 @@ class MultiModalRegistry: ...@@ -106,7 +112,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
...@@ -190,7 +196,7 @@ class MultiModalRegistry: ...@@ -190,7 +196,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits() return profiler.get_mm_limits()
...@@ -286,7 +292,7 @@ class MultiModalRegistry: ...@@ -286,7 +292,7 @@ class MultiModalRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
""" """
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
...@@ -310,7 +316,7 @@ class MultiModalRegistry: ...@@ -310,7 +316,7 @@ class MultiModalRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
""" """
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
......
...@@ -476,6 +476,11 @@ class AsyncLLM(EngineClient): ...@@ -476,6 +476,11 @@ class AsyncLLM(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
await self.engine_core.profile_async(False) await self.engine_core.profile_async(False)
async def reset_mm_cache(self) -> None:
    """Reset every multi-modal cache reachable from this client.

    In order: the processor's registry cache, the processor's input cache
    client, and finally the engine core's cache (awaited).
    """
    self.processor.mm_registry.reset_processor_cache()
    self.processor.mm_input_cache_client.reset()
    await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None: device: Optional[Device] = None) -> None:
if device == Device.CPU: if device == Device.CPU:
......
...@@ -286,6 +286,15 @@ class EngineCore: ...@@ -286,6 +286,15 @@ class EngineCore:
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
self.model_executor.profile(is_start) self.model_executor.profile(is_start)
def reset_mm_cache(self) -> None:
    """Reset the multi-modal input cache server.

    A warning is logged if the scheduler still reports unfinished
    requests; the cache is reset either way.
    """
    # NOTE: Since this is mainly for debugging, we don't attempt to
    # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
    if self.scheduler.get_num_unfinished_requests():
        logger.warning("Resetting the multi-modal cache when requests are "
                       "in progress may lead to desynced internal caches.")

    self.mm_input_cache_server.reset()
def reset_prefix_cache(self): def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache() self.scheduler.reset_prefix_cache()
......
...@@ -88,6 +88,9 @@ class EngineCoreClient(ABC): ...@@ -88,6 +88,9 @@ class EngineCoreClient(ABC):
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
raise NotImplementedError raise NotImplementedError
def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache; not implemented on this base client.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -143,6 +146,9 @@ class EngineCoreClient(ABC): ...@@ -143,6 +146,9 @@ class EngineCoreClient(ABC):
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
raise NotImplementedError raise NotImplementedError
async def reset_mm_cache_async(self) -> None:
    """Async multi-modal cache reset; not implemented on this base client.

    Raises:
        NotImplementedError: Always, once the coroutine is awaited.
    """
    raise NotImplementedError
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -214,6 +220,9 @@ class InprocClient(EngineCoreClient): ...@@ -214,6 +220,9 @@ class InprocClient(EngineCoreClient):
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start) self.engine_core.profile(is_start)
def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache by calling the engine core directly."""
    self.engine_core.reset_mm_cache()
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
...@@ -600,6 +609,9 @@ class SyncMPClient(MPClient): ...@@ -600,6 +609,9 @@ class SyncMPClient(MPClient):
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self.call_utility("profile", is_start) self.call_utility("profile", is_start)
def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache via the ``"reset_mm_cache"`` utility call."""
    self.call_utility("reset_mm_cache")
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
self.call_utility("reset_prefix_cache") self.call_utility("reset_prefix_cache")
...@@ -787,6 +799,9 @@ class AsyncMPClient(MPClient): ...@@ -787,6 +799,9 @@ class AsyncMPClient(MPClient):
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
await self.call_utility_async("profile", is_start) await self.call_utility_async("profile", is_start)
async def reset_mm_cache_async(self) -> None:
    """Reset the multi-modal cache via the async ``"reset_mm_cache"`` utility call."""
    await self.call_utility_async("reset_mm_cache")
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
await self.call_utility_async("reset_prefix_cache") await self.call_utility_async("reset_prefix_cache")
......
...@@ -101,6 +101,9 @@ class LLMEngine: ...@@ -101,6 +101,9 @@ class LLMEngine:
# for v0 compatibility # for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
# Don't keep the dummy data in memory
self.reset_mm_cache()
@classmethod @classmethod
def from_vllm_config( def from_vllm_config(
cls, cls,
...@@ -240,6 +243,11 @@ class LLMEngine: ...@@ -240,6 +243,11 @@ class LLMEngine:
def stop_profile(self): def stop_profile(self):
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_mm_cache(self) -> None:
    """Reset every multi-modal cache reachable from this engine.

    In order: the processor's registry cache, the processor's input cache
    client, and finally the engine core's cache.
    """
    self.processor.mm_registry.reset_processor_cache()
    self.processor.mm_input_cache_client.reset()
    self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None): def reset_prefix_cache(self, device: Optional[Device] = None):
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
......
...@@ -83,3 +83,8 @@ class MirroredProcessingCache: ...@@ -83,3 +83,8 @@ class MirroredProcessingCache:
full_mm_inputs.append(mm_input) full_mm_inputs.append(mm_input)
return full_mm_inputs return full_mm_inputs
def reset(self) -> bool:
    """Remove every entry from ``mm_cache``.

    Returns:
        Always ``True``.
    """
    self.mm_cache.clear()
    return True
...@@ -54,6 +54,10 @@ class Processor: ...@@ -54,6 +54,10 @@ class Processor:
self.use_hash = self.mm_input_cache_client.use_cache or \ self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching self.cache_config.enable_prefix_caching
@property
def mm_registry(self):
    """The multi-modal registry held by this processor's input preprocessor."""
    return self.input_preprocessor.mm_registry
def _validate_logprobs( def _validate_logprobs(
self, self,
params: SamplingParams, params: SamplingParams,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment