Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7206ce4c
Unverified
Commit
7206ce4c
authored
Jan 22, 2025
by
Cody Yu
Committed by
GitHub
Jan 22, 2025
Browse files
[Core] Support `reset_prefix_cache` (#12284)
parent
96f6a759
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
72 additions
and
5 deletions
+72
-5
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+27
-0
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+3
-0
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+8
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+3
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+9
-2
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+19
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+3
-0
No files found.
vllm/v1/core/kv_cache_manager.py
View file @
7206ce4c
...
@@ -285,6 +285,33 @@ class KVCacheManager:
...
@@ -285,6 +285,33 @@ class KVCacheManager:
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
self
.
free_block_queue
.
append
(
block
)
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
free_block_queue
.
num_free_blocks
)
if
num_used_blocks
>
0
:
logger
.
warning
(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet"
,
num_used_blocks
)
return
False
# Remove all hashes so that no new blocks will hit.
self
.
cached_block_hash_to_block
=
defaultdict
(
dict
)
# Remove all hashes from all blocks.
for
block
in
self
.
block_pool
:
block
.
reset_hash
()
logger
.
info
(
"Successfully reset prefix cache"
)
return
True
def
get_num_common_prefix_blocks
(
def
get_num_common_prefix_blocks
(
self
,
self
,
request
:
Request
,
request
:
Request
,
...
...
vllm/v1/core/scheduler.py
View file @
7206ce4c
...
@@ -529,6 +529,9 @@ class Scheduler:
...
@@ -529,6 +529,9 @@ class Scheduler:
def
has_unfinished_requests
(
self
)
->
bool
:
def
has_unfinished_requests
(
self
)
->
bool
:
return
self
.
get_num_unfinished_requests
()
>
0
return
self
.
get_num_unfinished_requests
()
>
0
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
kv_cache_manager
.
reset_prefix_cache
()
def
make_stats
(
self
)
->
SchedulerStats
:
def
make_stats
(
self
)
->
SchedulerStats
:
return
SchedulerStats
(
return
SchedulerStats
(
num_running_reqs
=
len
(
self
.
running
),
num_running_reqs
=
len
(
self
.
running
),
...
...
vllm/v1/engine/__init__.py
View file @
7206ce4c
...
@@ -66,6 +66,11 @@ class EngineCoreProfile:
...
@@ -66,6 +66,11 @@ class EngineCoreProfile:
is_start
:
bool
is_start
:
bool
@
dataclass
class
EngineCoreResetPrefixCache
:
pass
class
EngineCoreRequestType
(
enum
.
Enum
):
class
EngineCoreRequestType
(
enum
.
Enum
):
"""
"""
Request types defined as hex byte strings, so it can be sent over sockets
Request types defined as hex byte strings, so it can be sent over sockets
...
@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
...
@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
ADD
=
b
'
\x00
'
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
ABORT
=
b
'
\x01
'
PROFILE
=
b
'
\x02
'
PROFILE
=
b
'
\x02
'
RESET_PREFIX_CACHE
=
b
'
\x03
'
EngineCoreRequestUnion
=
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]]
EngineCoreRequestUnion
=
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
EngineCoreResetPrefixCache
,
List
[
str
]]
vllm/v1/engine/async_llm.py
View file @
7206ce4c
...
@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
...
@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile_async
(
False
)
await
self
.
engine_core
.
profile_async
(
False
)
async
def
reset_prefix_cache
(
self
)
->
None
:
await
self
.
engine_core
.
reset_prefix_cache_async
()
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
return
True
return
True
...
...
vllm/v1/engine/core.py
View file @
7206ce4c
...
@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
...
@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreProfile
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestUnion
)
EngineCoreRequestUnion
,
EngineCoreResetPrefixCache
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
...
@@ -135,6 +135,9 @@ class EngineCore:
...
@@ -135,6 +135,9 @@ class EngineCore:
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
model_executor
.
profile
(
is_start
)
self
.
model_executor
.
profile
(
is_start
)
def
reset_prefix_cache
(
self
):
self
.
scheduler
.
reset_prefix_cache
()
class
EngineCoreProc
(
EngineCore
):
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
"""ZMQ-wrapper for running EngineCore in background process."""
...
@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
...
@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
self
.
add_request
(
request
)
self
.
add_request
(
request
)
elif
isinstance
(
request
,
EngineCoreProfile
):
elif
isinstance
(
request
,
EngineCoreProfile
):
self
.
model_executor
.
profile
(
request
.
is_start
)
self
.
model_executor
.
profile
(
request
.
is_start
)
elif
isinstance
(
request
,
EngineCoreResetPrefixCache
):
self
.
reset_prefix_cache
()
else
:
else
:
# TODO: make an EngineCoreAbort wrapper
# TODO: make an EngineCoreAbort wrapper
assert
isinstance
(
request
,
list
)
assert
isinstance
(
request
,
list
)
...
@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
...
@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
request
=
decoder_add_req
.
decode
(
request_data
)
request
=
decoder_add_req
.
decode
(
request_data
)
elif
request_type
==
EngineCoreRequestType
.
ABORT
.
value
:
elif
request_type
==
EngineCoreRequestType
.
ABORT
.
value
:
request
=
decoder_abort_req
.
decode
(
request_data
)
request
=
decoder_abort_req
.
decode
(
request_data
)
elif
request_type
==
EngineCoreRequestType
.
PROFILE
.
value
:
elif
request_type
in
(
EngineCoreRequestType
.
PROFILE
.
value
,
EngineCoreRequestType
.
RESET_PREFIX_CACHE
.
value
):
request
=
pickle
.
loads
(
request_data
)
request
=
pickle
.
loads
(
request_data
)
else
:
else
:
raise
ValueError
(
f
"Unknown RequestType:
{
request_type
}
"
)
raise
ValueError
(
f
"Unknown RequestType:
{
request_type
}
"
)
...
...
vllm/v1/engine/core_client.py
View file @
7206ce4c
...
@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
...
@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket
)
make_zmq_socket
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreProfile
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestUnion
)
EngineCoreRequestUnion
,
EngineCoreResetPrefixCache
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
PickleEncoder
from
vllm.v1.serial_utils
import
PickleEncoder
...
@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
...
@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
reset_prefix_cache
(
self
)
->
None
:
raise
NotImplementedError
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
...
@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
async
def
reset_prefix_cache_async
(
self
)
->
None
:
raise
NotImplementedError
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
...
@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
if
len
(
request_ids
)
>
0
:
if
len
(
request_ids
)
>
0
:
self
.
engine_core
.
abort_requests
(
request_ids
)
self
.
engine_core
.
abort_requests
(
request_ids
)
def
shutdown
(
self
):
def
shutdown
(
self
)
->
None
:
self
.
engine_core
.
shutdown
()
self
.
engine_core
.
shutdown
()
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
engine_core
.
profile
(
is_start
)
self
.
engine_core
.
profile
(
is_start
)
def
reset_prefix_cache
(
self
)
->
None
:
self
.
engine_core
.
reset_prefix_cache
()
class
MPClient
(
EngineCoreClient
):
class
MPClient
(
EngineCoreClient
):
"""
"""
...
@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
...
@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
EngineCoreProfile
(
is_start
))
EngineCoreProfile
(
is_start
))
def
reset_prefix_cache
(
self
)
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
EngineCoreResetPrefixCache
())
class
AsyncMPClient
(
MPClient
):
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
"""Asyncio-compatible client for multi-proc EngineCore."""
...
@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
...
@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
EngineCoreProfile
(
is_start
))
EngineCoreProfile
(
is_start
))
async
def
reset_prefix_cache_async
(
self
)
->
None
:
await
self
.
_send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
EngineCoreResetPrefixCache
())
vllm/v1/engine/llm_engine.py
View file @
7206ce4c
...
@@ -162,6 +162,9 @@ class LLMEngine:
...
@@ -162,6 +162,9 @@ class LLMEngine:
def
stop_profile
(
self
):
def
stop_profile
(
self
):
self
.
engine_core
.
profile
(
False
)
self
.
engine_core
.
profile
(
False
)
def
reset_prefix_cache
(
self
):
self
.
engine_core
.
reset_prefix_cache
()
def
get_tokenizer_group
(
def
get_tokenizer_group
(
self
,
self
,
group_type
:
Type
[
_G
]
=
BaseTokenizerGroup
,
group_type
:
Type
[
_G
]
=
BaseTokenizerGroup
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment