Unverified Commit 79b6ec6a authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix inconsistent handling of cache reset (#33481)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d6416fdd
...@@ -82,7 +82,7 @@ vllm bench sweep serve \ ...@@ -82,7 +82,7 @@ vllm bench sweep serve \
You can use `--dry-run` to preview the commands to be run. You can use `--dry-run` to preview the commands to be run.
We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`. We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`.
Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run. Between each benchmark run, we call all `/reset_*_cache` endpoints to get a clean slate for the next run.
In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`. In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`.
!!! note !!! note
......
...@@ -12,6 +12,12 @@ from typing_extensions import Self ...@@ -12,6 +12,12 @@ from typing_extensions import Self
class ServerProcess: class ServerProcess:
VLLM_RESET_CACHE_ENDPOINTS = [
"/reset_prefix_cache",
"/reset_mm_cache",
"/reset_encoder_cache",
]
def __init__( def __init__(
self, self,
server_cmd: list[str], server_cmd: list[str],
...@@ -120,11 +126,9 @@ class ServerProcess: ...@@ -120,11 +126,9 @@ class ServerProcess:
server_address = self._get_vllm_server_address() server_address = self._get_vllm_server_address()
print(f"Resetting caches at {server_address}") print(f"Resetting caches at {server_address}")
res = requests.post(f"{server_address}/reset_prefix_cache") for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
res.raise_for_status() res = requests.post(server_address + endpoint)
res.raise_for_status()
res = requests.post(f"{server_address}/reset_mm_cache")
res.raise_for_status()
elif server_cmd[0].endswith("infinity_emb"): elif server_cmd[0].endswith("infinity_emb"):
if "--vector-disk-cache" in server_cmd: if "--vector-disk-cache" in server_cmd:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -286,10 +286,6 @@ class OpenAIServing: ...@@ -286,10 +286,6 @@ class OpenAIServing:
raise TypeError(f"{reasoning_parser_name=} has not been registered") from e raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
return parser return parser
async def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache()
await self.engine_client.reset_mm_cache()
async def beam_search( async def beam_search(
self, self,
prompt: PromptType, prompt: PromptType,
......
...@@ -741,6 +741,7 @@ class AsyncLLM(EngineClient): ...@@ -741,6 +741,7 @@ class AsyncLLM(EngineClient):
if clear_cache: if clear_cache:
await self.reset_prefix_cache() await self.reset_prefix_cache()
await self.reset_mm_cache() await self.reset_mm_cache()
await self.reset_encoder_cache()
async def resume_generation(self) -> None: async def resume_generation(self) -> None:
"""Resume generation after :meth:`pause_generation`.""" """Resume generation after :meth:`pause_generation`."""
......
...@@ -31,6 +31,22 @@ class EncoderRunner: ...@@ -31,6 +31,22 @@ class EncoderRunner:
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {} self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {} self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def reset_encoder_cache(self) -> None:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self.encoder_cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]): def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features self.req_id_to_mm_features[req_id] = mm_features
......
...@@ -339,7 +339,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -339,7 +339,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
gc.collect() gc.collect()
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
pass self.encoder_runner.reset_mm_cache()
def reset_encoder_cache(self) -> None:
self.encoder_runner.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet. # SP is not supported yet.
......
...@@ -717,6 +717,10 @@ class GPUModelRunner( ...@@ -717,6 +717,10 @@ class GPUModelRunner(
self.effective_drafter_max_model_len = self.max_model_len self.effective_drafter_max_model_len = self.max_model_len
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
if self.mm_budget: if self.mm_budget:
self.mm_budget.reset_cache() self.mm_budget.reset_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment