Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
26dd972a
Unverified
Commit
26dd972a
authored
Mar 20, 2025
by
maobaolong
Committed by
GitHub
Mar 19, 2025
Browse files
[FEAT]Support reset prefix cache by specified device (#15003)
parent
61c7a1b8
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
49 additions
and
34 deletions
+49
-34
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+4
-2
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+1
-1
vllm/core/block_manager.py
vllm/core/block_manager.py
+2
-2
vllm/core/interfaces.py
vllm/core/interfaces.py
+3
-3
vllm/core/placeholder_block_space_manager.py
vllm/core/placeholder_block_space_manager.py
+2
-2
vllm/core/scheduler.py
vllm/core/scheduler.py
+2
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+4
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-2
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+4
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+4
-3
vllm/engine/protocol.py
vllm/engine/protocol.py
+3
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-3
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+7
-3
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+5
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+2
-1
No files found.
vllm/core/block/cpu_gpu_block_allocator.py
View file @
26dd972a
...
@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
assert
device
in
self
.
_allocators
assert
device
in
self
.
_allocators
return
self
.
_allocators
[
device
].
get_prefix_cache_hit_rate
()
return
self
.
_allocators
[
device
].
get_prefix_cache_hit_rate
()
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
"""Reset prefix cache for all devices."""
"""Reset prefix cache for specified or all devices."""
if
device
:
return
self
.
_allocators
[
device
].
reset_prefix_cache
()
success
=
True
success
=
True
for
allocator
in
self
.
_allocators
.
values
():
for
allocator
in
self
.
_allocators
.
values
():
success
=
success
and
allocator
.
reset_prefix_cache
()
success
=
success
and
allocator
.
reset_prefix_cache
()
...
...
vllm/core/block/interfaces.py
View file @
26dd972a
...
@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
"""Reset prefix cache."""
"""Reset prefix cache."""
pass
pass
...
...
vllm/core/block_manager.py
View file @
26dd972a
...
@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
return
self
.
block_allocator
.
get_prefix_cache_hit_rate
(
device
)
return
self
.
block_allocator
.
get_prefix_cache_hit_rate
(
device
)
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
return
self
.
block_allocator
.
reset_prefix_cache
()
return
self
.
block_allocator
.
reset_prefix_cache
(
device
)
def
_can_swap
(
self
,
def
_can_swap
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
...
...
vllm/core/interfaces.py
View file @
26dd972a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
from
typing
import
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
typing
import
Tuple
...
@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC):
...
@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
"""Reset prefix cache for all devices."""
"""Reset prefix cache for
specified or
all devices."""
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/placeholder_block_space_manager.py
View file @
26dd972a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.sequence
import
Sequence
,
SequenceGroup
from
vllm.sequence
import
Sequence
,
SequenceGroup
...
@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
...
@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
return
-
1
return
-
1
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
return
True
return
True
def
get_num_cached_tokens
(
self
,
seq
:
Sequence
)
->
int
:
def
get_num_cached_tokens
(
self
,
seq
:
Sequence
)
->
int
:
...
...
vllm/core/scheduler.py
View file @
26dd972a
...
@@ -634,8 +634,8 @@ class Scheduler:
...
@@ -634,8 +634,8 @@ class Scheduler:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
return
self
.
block_manager
.
get_prefix_cache_hit_rate
(
device
)
return
self
.
block_manager
.
get_prefix_cache_hit_rate
(
device
)
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
return
self
.
block_manager
.
reset_prefix_cache
()
return
self
.
block_manager
.
reset_prefix_cache
(
device
)
def
get_num_unfinished_seq_groups
(
self
)
->
int
:
def
get_num_unfinished_seq_groups
(
self
)
->
int
:
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
+
len
(
self
.
swapped
)
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
+
len
(
self
.
swapped
)
...
...
vllm/engine/async_llm_engine.py
View file @
26dd972a
...
@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
Device
,
deprecate_kwargs
,
weak_bind
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
@@ -1216,8 +1216,9 @@ class AsyncLLMEngine(EngineClient):
...
@@ -1216,8 +1216,9 @@ class AsyncLLMEngine(EngineClient):
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
self
.
engine
.
stop_profile
()
self
.
engine
.
stop_profile
()
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
self
.
engine
.
reset_prefix_cache
()
device
:
Optional
[
Device
]
=
None
)
->
None
:
self
.
engine
.
reset_prefix_cache
(
device
)
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
engine
.
sleep
(
level
)
self
.
engine
.
sleep
(
level
)
...
...
vllm/engine/llm_engine.py
View file @
26dd972a
...
@@ -955,12 +955,12 @@ class LLMEngine:
...
@@ -955,12 +955,12 @@ class LLMEngine:
"""
"""
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
"""Reset prefix cache for all devices."""
"""Reset prefix cache for all devices."""
success
=
True
success
=
True
for
scheduler
in
self
.
scheduler
:
for
scheduler
in
self
.
scheduler
:
success
=
success
and
scheduler
.
reset_prefix_cache
()
success
=
success
and
scheduler
.
reset_prefix_cache
(
device
)
return
success
return
success
@
staticmethod
@
staticmethod
...
...
vllm/engine/multiprocessing/__init__.py
View file @
26dd972a
...
@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
...
@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
deprecate_kwargs
from
vllm.utils
import
Device
,
deprecate_kwargs
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
...
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
...
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE
=
2
STOP_PROFILE
=
2
class
RPCResetPrefixCacheRequest
(
Enum
):
@
dataclass
RESET_PREFIX_CACHE
=
1
class
RPCResetPrefixCacheRequest
:
device
:
Device
class
RPCSleepRequest
(
Enum
):
class
RPCSleepRequest
(
Enum
):
...
...
vllm/engine/multiprocessing/client.py
View file @
26dd972a
...
@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
...
@@ -47,7 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
from
vllm.utils
import
Device
,
deprecate_kwargs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
...
@@ -684,11 +684,12 @@ class MQLLMEngineClient(EngineClient):
await
self
.
_send_one_way_rpc_request
(
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
request
=
RPCUProfileRequest
.
STOP_PROFILE
,
socket
=
self
.
input_socket
)
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
"""Reset the prefix cache"""
"""Reset the prefix cache"""
await
self
.
_send_one_way_rpc_request
(
await
self
.
_send_one_way_rpc_request
(
request
=
RPCResetPrefixCacheRequest
.
RESET_PREFIX_CACHE
,
request
=
RPCResetPrefixCacheRequest
(
device
)
,
socket
=
self
.
input_socket
)
socket
=
self
.
input_socket
)
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
...
...
vllm/engine/protocol.py
View file @
26dd972a
...
@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
...
@@ -18,7 +18,7 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
collect_from_async_generator
,
random_uuid
from
vllm.utils
import
Device
,
collect_from_async_generator
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -274,7 +274,8 @@ class EngineClient(ABC):
...
@@ -274,7 +274,8 @@ class EngineClient(ABC):
...
...
@
abstractmethod
@
abstractmethod
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
"""Reset the prefix cache"""
"""Reset the prefix cache"""
...
...
...
...
vllm/entrypoints/llm.py
View file @
26dd972a
...
@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
...
@@ -42,7 +42,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer
)
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -1187,8 +1188,8 @@ class LLM:
...
@@ -1187,8 +1188,8 @@ class LLM:
def
stop_profile
(
self
)
->
None
:
def
stop_profile
(
self
)
->
None
:
self
.
llm_engine
.
stop_profile
()
self
.
llm_engine
.
stop_profile
()
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
bool
:
return
self
.
llm_engine
.
reset_prefix_cache
()
return
self
.
llm_engine
.
reset_prefix_cache
(
device
)
def
sleep
(
self
,
level
:
int
=
1
):
def
sleep
(
self
,
level
:
int
=
1
):
"""
"""
...
...
vllm/entrypoints/openai/api_server.py
View file @
26dd972a
...
@@ -85,7 +85,7 @@ from vllm.logger import init_logger
...
@@ -85,7 +85,7 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
maybe_register_config_serialize_by_value
)
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
Device
,
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
,
set_ulimit
)
is_valid_ipv6_address
,
set_ulimit
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
...
@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
...
@@ -677,8 +677,12 @@ if envs.VLLM_SERVER_DEV_MODE:
Reset the prefix cache. Note that we currently do not check if the
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
prefix cache is successfully reset in the API server.
"""
"""
logger
.
info
(
"Resetting prefix cache..."
)
device
=
None
await
engine_client
(
raw_request
).
reset_prefix_cache
()
device_str
=
raw_request
.
query_params
.
get
(
"device"
)
if
device_str
is
not
None
:
device
=
Device
[
device_str
.
upper
()]
logger
.
info
(
"Resetting prefix cache with specific %s..."
,
str
(
device
))
await
engine_client
(
raw_request
).
reset_prefix_cache
(
device
)
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/sleep"
)
@
router
.
post
(
"/sleep"
)
...
...
vllm/v1/engine/async_llm.py
View file @
26dd972a
...
@@ -24,7 +24,7 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
...
@@ -24,7 +24,7 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
cdiv
,
kill_process_tree
from
vllm.utils
import
Device
,
cdiv
,
kill_process_tree
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
...
@@ -398,7 +398,10 @@ class AsyncLLM(EngineClient):
...
@@ -398,7 +398,10 @@ class AsyncLLM(EngineClient):
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile_async
(
False
)
await
self
.
engine_core
.
profile_async
(
False
)
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
if
device
==
Device
.
CPU
:
raise
ValueError
(
"Not supported on CPU."
)
await
self
.
engine_core
.
reset_prefix_cache_async
()
await
self
.
engine_core
.
reset_prefix_cache_async
()
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
...
...
vllm/v1/engine/llm_engine.py
View file @
26dd972a
...
@@ -20,6 +20,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -20,6 +20,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
...
@@ -226,7 +227,7 @@ class LLMEngine:
...
@@ -226,7 +227,7 @@ class LLMEngine:
def
stop_profile
(
self
):
def
stop_profile
(
self
):
self
.
engine_core
.
profile
(
False
)
self
.
engine_core
.
profile
(
False
)
def
reset_prefix_cache
(
self
):
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
):
self
.
engine_core
.
reset_prefix_cache
()
self
.
engine_core
.
reset_prefix_cache
()
def
sleep
(
self
,
level
:
int
=
1
):
def
sleep
(
self
,
level
:
int
=
1
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment