Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc7f22a8
Commit
cc7f22a8
authored
Jun 11, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.1' into v0.9.1-ori
parents
b9ea0c09
b6553be1
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
204 additions
and
31 deletions
+204
-31
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+1
-0
vllm/distributed/device_communicators/hpu_communicator.py
vllm/distributed/device_communicators/hpu_communicator.py
+1
-0
vllm/distributed/device_communicators/neuron_communicator.py
vllm/distributed/device_communicators/neuron_communicator.py
+1
-0
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+1
-0
vllm/distributed/device_communicators/pynccl_wrapper.py
vllm/distributed/device_communicators/pynccl_wrapper.py
+1
-0
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+46
-2
vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/device_communicators/tpu_communicator.py
+1
-0
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+1
-0
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+69
-9
vllm/distributed/kv_transfer/__init__.py
vllm/distributed/kv_transfer/__init__.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/base.py
vllm/distributed/kv_transfer/kv_connector/base.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+3
-1
vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
...distributed/kv_transfer/kv_connector/lmcache_connector.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
...uted/kv_transfer/kv_connector/mooncake_store_connector.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
.../distributed/kv_transfer/kv_connector/simple_connector.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+18
-1
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+27
-3
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+27
-15
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ctypes
import
json
...
...
vllm/distributed/device_communicators/hpu_communicator.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.distributed
as
dist
...
...
vllm/distributed/device_communicators/neuron_communicator.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.distributed.device_communicators.base_device_communicator
import
(
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
,
Union
...
...
vllm/distributed/device_communicators/pynccl_wrapper.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pickle
import
time
...
...
@@ -27,6 +28,43 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger
=
init_logger
(
__name__
)
class SpinTimer:
    """Baseline wait strategy for a polling loop: yield on every spin."""

    def record_activity(self):
        """No-op; this timer keeps no activity state."""

    def spin(self):
        """Wait one polling step by calling ``sched_yield``."""
        sched_yield()
class SpinSleepTimer(SpinTimer):
    """
    Spin timer that backs off to sleeping once the caller has been idle.

    With long inactivity periods it is desirable to cut system power use
    while vllm has nothing to do. That leaves more CPU thermal headroom for
    the moment a request eventually arrives, especially with multiple GPUs,
    where each GPU would otherwise pin one thread at 100% CPU usage.

    The approach is deliberately simple: poll less often once no activity
    has been recorded for ``busy_loop_s`` seconds.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        # busy_loop_s: how long after the last activity to keep yielding.
        # wait_sleep_s: sleep duration per spin once that window has passed.
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        self.last_activity = time.monotonic()

    def record_activity(self):
        """Restart the busy-loop window from the current monotonic time."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Sleep if the busy-loop window has elapsed, otherwise yield."""
        now = time.monotonic()
        if now >= self.last_activity + self.busy_loop_s:
            # Idle for at least busy_loop_s: trade latency for lower CPU use.
            time.sleep(self.wait_sleep_s)
            return
        sched_yield()
class
ShmRingBuffer
:
def
__init__
(
self
,
...
...
@@ -41,7 +79,7 @@ class ShmRingBuffer:
of items that can be stored in the buffer are known in advance.
In this case, we don't need to synchronize the access to
the buffer.
Buffer memory layout:
data metadata
| |
...
...
@@ -237,6 +275,7 @@ class MessageQueue:
self
.
local_reader_rank
=
-
1
# rank does not matter for remote readers
self
.
_is_remote_reader
=
False
self
.
_read_spin_timer
=
SpinTimer
()
self
.
handle
=
Handle
(
local_reader_ranks
=
local_reader_ranks
,
...
...
@@ -275,6 +314,9 @@ class MessageQueue:
self
.
local_socket
.
connect
(
socket_addr
)
self
.
remote_socket
=
None
self
.
_read_spin_timer
=
SpinSleepTimer
(
)
if
envs
.
VLLM_SLEEP_WHEN_IDLE
else
SpinTimer
()
else
:
self
.
buffer
=
None
# type: ignore
self
.
current_idx
=
-
1
...
...
@@ -406,7 +448,7 @@ class MessageQueue:
# we need to wait until it is written
# Release the processor to other threads
s
ched_yield
()
s
elf
.
_read_spin_timer
.
spin
()
# if we wait for a long time, log a message
if
(
time
.
monotonic
()
-
start_time
...
...
@@ -437,6 +479,8 @@ class MessageQueue:
metadata_buffer
[
self
.
local_reader_rank
+
1
]
=
1
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
self
.
_read_spin_timer
.
record_activity
()
break
def
enqueue
(
self
,
obj
,
timeout
:
Optional
[
float
]
=
None
):
...
...
vllm/distributed/device_communicators/tpu_communicator.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
Optional
...
...
vllm/distributed/device_communicators/xpu_communicator.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
...
...
vllm/distributed/kv_events.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
queue
import
threading
...
...
@@ -27,6 +28,7 @@ class EventBatch(
):
ts
:
float
events
:
list
[
Any
]
data_parallel_rank
:
Optional
[
int
]
=
None
class
KVCacheEvent
(
...
...
@@ -59,7 +61,22 @@ class KVEventBatch(EventBatch):
class
EventPublisher
(
ABC
):
"""Lightweight publisher for EventBatch batches."""
"""Lightweight publisher for EventBatch batches with data parallelism
support.
In data parallel setups, each DP rank runs its own EventPublisher instance
to avoid duplicate events and ensure proper event attribution:
- Each DP rank creates a separate publisher
- Publishers automatically annotate events with their data_parallel_rank
- This allows consumers to distinguish events from different DP ranks
The publisher is responsible for adding DP metadata since the scheduler
operates independently of DP topology and shouldn't need DP awareness.
"""
def
__init__
(
self
,
data_parallel_rank
:
int
=
0
)
->
None
:
self
.
_data_parallel_rank
=
data_parallel_rank
@
abstractmethod
def
publish
(
self
,
events
:
EventBatch
)
->
None
:
...
...
@@ -112,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
def
__init__
(
self
,
data_parallel_rank
:
int
,
endpoint
:
str
=
"tcp://*:5557"
,
replay_endpoint
:
Optional
[
str
]
=
None
,
buffer_steps
:
int
=
10_000
,
...
...
@@ -120,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
topic
:
str
=
""
,
)
->
None
:
# Storage
super
().
__init__
(
data_parallel_rank
)
self
.
_event_queue
=
Queue
[
Optional
[
EventBatch
]](
maxsize
=
max_queue_size
)
self
.
_buffer
=
deque
[
tuple
[
int
,
bytes
]](
maxlen
=
buffer_steps
)
...
...
@@ -127,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
self
.
_ctx
=
zmq
.
Context
.
instance
()
self
.
_pub
:
Optional
[
zmq
.
Socket
]
=
None
self
.
_replay
:
Optional
[
zmq
.
Socket
]
=
None
self
.
_endpoint
=
endpoint
self
.
_replay_endpoint
=
replay_endpoint
self
.
_dp_rank
=
data_parallel_rank
self
.
_endpoint
=
self
.
offset_endpoint_port
(
endpoint
,
self
.
_dp_rank
)
self
.
_replay_endpoint
=
self
.
offset_endpoint_port
(
replay_endpoint
,
self
.
_dp_rank
)
self
.
_hwm
=
hwm
self
.
_socket_setup
()
...
...
@@ -148,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
def publish(self, events: EventBatch) -> None:
    """Queue an event batch for publication.

    Args:
        events: The batch to enqueue. If its ``data_parallel_rank`` is
            ``None`` it is set, in place, to this publisher's rank.

    Raises:
        RuntimeError: If the publisher is no longer running.
    """
    if self._running:
        # Only fill in the rank when the producer left it unset.
        if events.data_parallel_rank is None:
            events.data_parallel_rank = self._data_parallel_rank
        self._event_queue.put(events)
        return
    raise RuntimeError("Publisher is closed")
def
shutdown
(
self
)
->
None
:
...
...
@@ -190,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
self
.
_pub
.
set_hwm
(
self
.
_hwm
)
# Heuristic: bind if wildcard / * present, else connect.
# bind stable, connect volatile convention
if
(
"*"
in
self
.
_endpoint
or
"::"
in
self
.
_endpoint
or
self
.
_endpoint
.
startswith
(
"ipc://"
)
or
self
.
_endpoint
.
startswith
(
"inproc://"
)):
if
(
self
.
_endpoint
is
not
None
and
(
"*"
in
self
.
_endpoint
or
"::"
in
self
.
_endpoint
or
self
.
_endpoint
.
startswith
(
"ipc://"
)
or
self
.
_endpoint
.
startswith
(
"inproc://"
))):
self
.
_pub
.
bind
(
self
.
_endpoint
)
el
s
e
:
el
if
self
.
_endpoint
is
not
Non
e
:
self
.
_pub
.
connect
(
self
.
_endpoint
)
# Set up replay socket: use ROUTER
...
...
@@ -265,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
# receiving payload is (-1, b"")
self
.
_replay
.
send_multipart
((
client_id
,
b
""
,
self
.
END_SEQ
,
b
""
))
@
staticmethod
def
offset_endpoint_port
(
endpoint
:
Optional
[
str
],
data_parallel_rank
:
int
)
->
Optional
[
str
]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if
not
endpoint
or
data_parallel_rank
==
0
:
return
endpoint
if
"inproc"
in
endpoint
:
return
f
"
{
endpoint
}
_dp
{
data_parallel_rank
}
"
if
"tcp"
in
endpoint
:
if
endpoint
and
":"
in
endpoint
:
# Get everything after the last colon (the port)
last_colon_idx
=
endpoint
.
rfind
(
":"
)
base_addr
=
endpoint
[:
last_colon_idx
]
base_port
=
int
(
endpoint
[
last_colon_idx
+
1
:])
new_port
=
base_port
+
data_parallel_rank
return
f
"
{
base_addr
}
:
{
new_port
}
"
return
endpoint
raise
ValueError
(
"Invalid endpoint: must contain 'inproc' or 'tcp'"
)
class
EventPublisherFactory
:
_registry
:
dict
[
str
,
Callable
[...,
EventPublisher
]]
=
{
...
...
@@ -280,7 +337,9 @@ class EventPublisherFactory:
cls
.
_registry
[
name
]
=
ctor
@
classmethod
def
create
(
cls
,
config
:
Optional
[
KVEventsConfig
])
->
EventPublisher
:
def
create
(
cls
,
config
:
Optional
[
KVEventsConfig
],
data_parallel_rank
:
int
=
0
)
->
EventPublisher
:
"""Create publisher from a config mapping."""
if
not
config
:
return
NullEventPublisher
()
...
...
@@ -293,4 +352,5 @@ class EventPublisherFactory:
constructor
=
cls
.
_registry
[
kind
]
except
KeyError
as
exc
:
raise
ValueError
(
f
"Unknown event publisher '
{
kind
}
'"
)
from
exc
return
constructor
(
**
config_dict
)
return
constructor
(
data_parallel_rank
=
data_parallel_rank
,
**
config_dict
)
vllm/distributed/kv_transfer/__init__.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
KVConnectorBaseType
,
ensure_kv_transfer_initialized
,
get_kv_transfer_group
,
...
...
vllm/distributed/kv_transfer/kv_connector/base.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
...
...
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
from
typing
import
TYPE_CHECKING
,
Callable
...
...
@@ -70,7 +71,8 @@ class KVConnectorFactory:
connector_module
=
importlib
.
import_module
(
connector_module_path
)
connector_cls
=
getattr
(
connector_module
,
connector_name
)
assert
issubclass
(
connector_cls
,
KVConnectorBase_V1
)
logger
.
info
(
"Creating v1 connector with name: %s"
,
connector_name
)
logger
.
info
(
"Creating v1 connector with name: %s and engine_id: %s"
,
connector_name
,
kv_transfer_config
.
engine_id
)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
...
...
vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference
...
...
vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
...
...
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helper for store.
"""
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
...
...
@@ -89,3 +91,18 @@ class model_aware_kv_ops_helper:
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
def get_kv_connector_cache_layout():
    """Pick the KV cache layout string for the current vLLM config.

    Returns:
        "HND" when the current config's KV connector is "NixlConnector"
        and the model does not use MLA; "NHD" in every other case,
        including when the model config or the KV transfer config is
        missing.
    """
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    model_config = vllm_config.model_config

    if model_config is None:
        logger.warning("Unable to detect current VLLM config. "
                       "Defaulting to NHD kv cache layout.")
        return "NHD"

    # A missing KV transfer config means there is no connector to tune
    # for; guard it so the attribute access below cannot fail on None.
    if model_config.use_mla or kv_config is None:
        return "NHD"

    if kv_config.kv_connector == "NixlConnector":
        logger.info("NixlConnector detected. Setting KV cache "
                    "layout to HND for better xfer performance.")
        return "HND"

    return "NHD"
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
...
...
@@ -7,9 +8,15 @@ The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache
that exist in the remote KV cache. Might be called multiple
times for a given request and should be side-effect free.
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
...
...
@@ -18,6 +25,9 @@ The class provides the following primitives:
save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""
import
enum
...
...
@@ -183,7 +193,8 @@ class KVConnectorBase_V1(ABC):
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
...
...
@@ -214,7 +225,8 @@ class KVConnectorBase_V1(ABC):
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
"""
pass
...
...
@@ -224,6 +236,18 @@ class KVConnectorBase_V1(ABC):
num_external_tokens
:
int
):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
pass
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
import
torch
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
cc7f22a8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
...
...
@@ -11,12 +12,12 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -50,8 +51,9 @@ class MultiConnector(KVConnectorBase_V1):
self
.
_connectors
.
append
(
KVConnectorFactory
.
create_connector_v1
(
temp_config
,
role
))
# A mapping from request id to the connector that is assigned to it.
self
.
_requests_to_connector
:
dict
[
str
,
KVConnectorBase_V1
]
=
{}
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
self
.
_requests_to_connector
:
dict
[
str
,
int
]
=
{}
# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
...
...
@@ -135,25 +137,31 @@ class MultiConnector(KVConnectorBase_V1):
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
,
bool
]:
for
c
in
self
.
_connectors
:
to_return
=
(
0
,
False
)
for
i
,
c
in
enumerate
(
self
.
_connectors
):
toks
,
load_async
=
c
.
get_num_new_matched_tokens
(
request
,
num_computed_tokens
)
# The first connector that has new matched tokens will be assigned
# to this request.
if
toks
>
0
:
self
.
_requests_to_connector
[
request
.
request_id
]
=
c
return
toks
,
load_async
return
0
,
False
if
to_return
[
0
]
==
0
and
toks
>
0
:
self
.
_requests_to_connector
[
request
.
request_id
]
=
i
to_
return
=
(
toks
,
load_async
)
return
to_return
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
num_external_tokens
:
int
):
# If the request is not assigned to any connector, we do nothing.
if
request
.
request_id
not
in
self
.
_requests_to_connector
:
return
# We assume that the request is assigned to only one connector.
c
=
self
.
_requests_to_connector
.
pop
(
request
.
request_id
)
c
.
update_state_after_alloc
(
request
,
blocks
,
num_external_tokens
)
chosen_connector
=
self
.
_requests_to_connector
.
get
(
request
.
request_id
,
-
1
)
empty_blocks
=
blocks
.
new_empty
()
for
i
,
c
in
enumerate
(
self
.
_connectors
):
if
i
==
chosen_connector
:
# Forward call to the chosen connector (if any).
c
.
update_state_after_alloc
(
request
,
blocks
,
num_external_tokens
)
else
:
# Call with empty blocks for other connectors.
c
.
update_state_after_alloc
(
request
,
empty_blocks
,
0
)
def
build_connector_meta
(
self
,
...
...
@@ -169,7 +177,7 @@ class MultiConnector(KVConnectorBase_V1):
def
request_finished
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
blocks
:
list
[
int
]
,
)
->
tuple
[
bool
,
Optional
[
dict
[
str
,
Any
]]]:
async_saves
=
0
kv_txfer_params
=
None
...
...
@@ -186,4 +194,8 @@ class MultiConnector(KVConnectorBase_V1):
kv_txfer_params
=
txfer_params
if
async_saves
>
1
:
self
.
_extra_async_saves
[
request
.
request_id
]
=
async_saves
-
1
# Clean up other state for this request.
self
.
_requests_to_connector
.
pop
(
request
.
request_id
,
None
)
return
async_saves
>
0
,
kv_txfer_params
Prev
1
…
42
43
44
45
46
47
48
49
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment