Unverified commit dd6ac1c2, authored by Zhuohan Li, committed by GitHub
Browse files

[RL] [V1] Remove unused device argument from reset_kv_cache (#28766)


Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
parent 98b4d389
...@@ -125,7 +125,7 @@ class EngineClient(ABC): ...@@ -125,7 +125,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def reset_prefix_cache(self, device: Device | None = None) -> None: async def reset_prefix_cache(self) -> None:
"""Reset the prefix cache""" """Reset the prefix cache"""
... ...
......
...@@ -32,7 +32,6 @@ from vllm.config.model import ( ...@@ -32,7 +32,6 @@ from vllm.config.model import (
TokenizerMode, TokenizerMode,
) )
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.protocol import Device
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
...@@ -1499,8 +1498,8 @@ class LLM: ...@@ -1499,8 +1498,8 @@ class LLM:
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.llm_engine.stop_profile() self.llm_engine.stop_profile()
def reset_prefix_cache(self, device: Device | None = None) -> None: def reset_prefix_cache(self) -> None:
self.llm_engine.reset_prefix_cache(device) self.llm_engine.reset_prefix_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
""" """
......
...@@ -39,7 +39,7 @@ from typing_extensions import assert_never ...@@ -39,7 +39,7 @@ from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import Device, EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import ( from vllm.entrypoints.anthropic.protocol import (
AnthropicError, AnthropicError,
AnthropicErrorResponse, AnthropicErrorResponse,
...@@ -1069,12 +1069,8 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -1069,12 +1069,8 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server. prefix cache is successfully reset in the API server.
""" """
device = None logger.info("Resetting prefix cache...")
device_str = raw_request.query_params.get("device") await engine_client(raw_request).reset_prefix_cache()
if device_str is not None:
device = Device[device_str.upper()]
logger.info("Resetting prefix cache with specific %s...", str(device))
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200) return Response(status_code=200)
@router.post("/reset_mm_cache") @router.post("/reset_mm_cache")
......
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import Device, EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -672,9 +672,7 @@ class AsyncLLM(EngineClient): ...@@ -672,9 +672,7 @@ class AsyncLLM(EngineClient):
self.processor.clear_mm_cache() self.processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async() await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, device: Device | None = None) -> None: async def reset_prefix_cache(self) -> None:
if device == Device.CPU:
raise ValueError("Not supported on CPU.")
await self.engine_core.reset_prefix_cache_async() await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
......
...@@ -14,7 +14,6 @@ from vllm.config import ParallelConfig, VllmConfig ...@@ -14,7 +14,6 @@ from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.protocol import Device
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -321,7 +320,7 @@ class LLMEngine: ...@@ -321,7 +320,7 @@ class LLMEngine:
self.processor.clear_mm_cache() self.processor.clear_mm_cache()
self.engine_core.reset_mm_cache() self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Device | None = None): def reset_prefix_cache(self):
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment