Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
225 additions
and
84 deletions
+225
-84
vllm/distributed/device_communicators/all_reduce_utils.py
vllm/distributed/device_communicators/all_reduce_utils.py
+2
-2
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+4
-1
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+8
-5
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+5
-4
vllm/distributed/device_communicators/neuron_communicator.py
vllm/distributed/device_communicators/neuron_communicator.py
+0
-20
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+1
-1
vllm/distributed/device_communicators/quick_all_reduce.py
vllm/distributed/device_communicators/quick_all_reduce.py
+1
-1
vllm/distributed/device_communicators/ray_communicator.py
vllm/distributed/device_communicators/ray_communicator.py
+1
-1
vllm/distributed/device_communicators/symm_mem.py
vllm/distributed/device_communicators/symm_mem.py
+31
-6
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+5
-4
vllm/distributed/kv_transfer/__init__.py
vllm/distributed/kv_transfer/__init__.py
+4
-3
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+2
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+49
-1
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+14
-3
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+19
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+74
-24
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+2
-2
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
vllm/distributed/device_communicators/all_reduce_utils.py
View file @
38d80967
...
...
@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
"10.0"
:
{
2
:
2
*
MiB
,
# 2 MB
4
:
2
*
MiB
,
# 2 MB
6
:
2
*
MiB
,
#
2
MB
8
:
2
*
MiB
,
#
2
MB
6
:
1
*
MiB
,
#
1
MB
8
:
1
*
MiB
,
#
1
MB
}
}
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
38d80967
...
...
@@ -252,7 +252,10 @@ class DeviceCommunicatorBase:
moe_modules
=
[
module
for
module
in
model
.
modules
()
if
module
.
__class__
.
__name__
==
"FusedMoE"
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if
(
module
.
__class__
.
__name__
==
"FusedMoE"
or
module
.
__class__
.
__name__
==
"SharedFusedMoE"
)
]
for
module
in
moe_modules
:
module
.
quant_method
.
init_prepare_finalize
(
module
)
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
38d80967
...
...
@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
=
None
self
.
qr_comm
:
Optional
[
QuickAllReduce
]
=
None
self
.
symm_mem_comm
:
Optional
[
SymmMemCommunicator
]
=
None
if
envs
.
VLLM_ALLREDUCE_USE_SYMM_MEM
and
current_platform
.
is_cuda
():
self
.
symm_mem_comm
=
SymmMemCommunicator
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
if
use_custom_allreduce
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
self
.
ca_comm
=
CustomAllreduce
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
symm_mem_enabled
=
(
self
.
symm_mem_comm
is
not
None
and
not
self
.
symm_mem_comm
.
disabled
),
)
if
current_platform
.
is_rocm
():
...
...
@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self
.
qr_comm
=
QuickAllReduce
(
group
=
self
.
cpu_group
,
device
=
self
.
device
)
if
envs
.
VLLM_ALLREDUCE_USE_SYMM_MEM
and
current_platform
.
is_cuda
():
self
.
symm_mem_comm
=
SymmMemCommunicator
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
if
self
.
use_all2all
:
all2all_backend
=
envs
.
VLLM_ALL2ALL_BACKEND
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
38d80967
...
...
@@ -54,13 +54,14 @@ class CustomAllreduce:
def
__init__
(
self
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
max_size
=
8192
*
1024
)
->
None
:
max_size
=
8192
*
1024
,
symm_mem_enabled
=
False
)
->
None
:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be b
i
nd to f"cuda:{local_rank}".
it will be b
ou
nd to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
...
...
@@ -111,7 +112,7 @@ class CustomAllreduce:
self
.
device
=
device
device_capability
=
current_platform
.
get_device_capability
(
).
as_version_str
()
if
(
current_platform
.
is_cuda
()
and
envs
.
VLLM_ALLREDUCE_USE_SYMM_MEM
if
(
current_platform
.
is_cuda
()
and
symm_mem_enabled
and
device_capability
in
CUSTOM_ALL_REDUCE_MAX_SIZES
):
max_size
=
min
(
CUSTOM_ALL_REDUCE_MAX_SIZES
[
device_capability
][
world_size
],
...
...
@@ -159,7 +160,7 @@ class CustomAllreduce:
self
.
disabled
=
False
# Buffers memory are owned by this Python class and passed to C++.
# Meta
data composes of two parts: meta
data for synchronization and a
# Metadata composes of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self
.
meta_ptrs
=
self
.
create_shared_buffer
(
ops
.
meta_size
()
+
max_size
,
group
=
group
,
...
...
vllm/distributed/device_communicators/neuron_communicator.py
deleted
100644 → 0
View file @
33650733
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
)
from
vllm.platforms
import
current_platform
if
current_platform
.
is_neuron
():
import
torch_xla.core.xla_model
as
xm
class
NeuronCommunicator
(
DeviceCommunicatorBase
):
def
all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
x
)
def
all_gather
(
self
,
x
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
assert
dim
==
-
1
,
"Neuron only supports dim=-1 for all-gather."
return
xm
.
all_gather
(
x
,
dim
=
dim
)
vllm/distributed/device_communicators/pynccl.py
View file @
38d80967
...
...
@@ -31,7 +31,7 @@ class PyNcclCommunicator:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be b
i
nd to f"cuda:{local_rank}".
it will be b
ou
nd to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
...
...
vllm/distributed/device_communicators/quick_all_reduce.py
View file @
38d80967
...
...
@@ -78,7 +78,7 @@ class QuickAllReduce:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be b
i
nd to f"cuda:{local_rank}".
it will be b
ou
nd to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
...
...
vllm/distributed/device_communicators/ray_communicator.py
View file @
38d80967
...
...
@@ -186,7 +186,7 @@ class RayPPCommunicator(Communicator):
"""
Receive a torch.Tensor from a peer and synchronize the current stream.
After this call returns, the receive buffer is safe to read from
from
After this call returns, the receive buffer is safe to read from
any stream. An RayChannelError will be raised if an error occurred
(e.g., remote actor died), and the buffer is not safe to read.
...
...
vllm/distributed/device_communicators/symm_mem.py
View file @
38d80967
...
...
@@ -27,8 +27,13 @@ class SymmMemCommunicator:
"10.0"
:
[
6
,
8
],
}
def
__init__
(
self
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
]):
def
__init__
(
self
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
# add options for testing
force_multimem
:
Optional
[
bool
]
=
None
,
max_size_override
:
Optional
[
int
]
=
None
):
self
.
disabled
=
True
if
not
symm_mem_available
:
...
...
@@ -64,8 +69,17 @@ class SymmMemCommunicator:
self
.
world_size
,
)
return
self
.
max_size
=
SYMM_MEM_ALL_REDUCE_MAX_SIZES
[
self
.
device_capability
][
self
.
world_size
]
# Use override max_size if provided, otherwise use default
if
max_size_override
is
not
None
:
self
.
max_size
=
max_size_override
logger
.
info
(
"SymmMemCommunicator: Using override max_size: %s bytes"
,
self
.
max_size
,
)
else
:
self
.
max_size
=
SYMM_MEM_ALL_REDUCE_MAX_SIZES
[
self
.
device_capability
][
self
.
world_size
]
self
.
buffer
=
torch_symm_mem
.
empty
(
self
.
max_size
//
self
.
dtype
.
itemsize
,
device
=
self
.
device
,
...
...
@@ -76,6 +90,7 @@ class SymmMemCommunicator:
logger
.
warning
(
"SymmMemCommunicator: symmetric memory "
"multicast operations are not supported."
)
return
self
.
force_multimem
=
force_multimem
self
.
disabled
=
False
def
should_use_symm_mem
(
self
,
inp
:
torch
.
Tensor
):
...
...
@@ -98,8 +113,18 @@ class SymmMemCommunicator:
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
self
.
buffer
[:
inp
.
numel
()].
copy_
(
inp
.
view
(
-
1
))
if
self
.
world_size
in
self
.
_WORLD_SIZES_MULTIMEM
[
self
.
device_capability
]:
# Determine which algorithm to use
use_multimem
=
False
if
self
.
force_multimem
is
not
None
:
# Test override: use forced setting
use_multimem
=
self
.
force_multimem
else
:
# Normal logic: use multimem for supported world sizes
use_multimem
=
self
.
world_size
in
self
.
_WORLD_SIZES_MULTIMEM
[
self
.
device_capability
]
if
use_multimem
:
torch
.
ops
.
symm_mem
.
multimem_all_reduce_
(
self
.
buffer
[:
inp
.
numel
()],
"sum"
,
self
.
group
.
group_name
)
...
...
vllm/distributed/kv_events.py
View file @
38d80967
...
...
@@ -14,8 +14,9 @@ from typing import Any, Callable, Optional, Union
import
msgspec
import
zmq
from
vllm.config
import
KVEventsConfig
from
vllm.config
.kv_events
import
KVEventsConfig
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
ExternalBlockHash
logger
=
init_logger
(
__name__
)
...
...
@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
class
BlockStored
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
parent_block_hash
:
Optional
[
int
]
block_hashes
:
list
[
ExternalBlockHash
]
parent_block_hash
:
Optional
[
ExternalBlockHash
]
token_ids
:
list
[
int
]
block_size
:
int
lora_id
:
Optional
[
int
]
...
...
@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
block_hashes
:
list
[
ExternalBlockHash
]
medium
:
Optional
[
str
]
...
...
vllm/distributed/kv_transfer/__init__.py
View file @
38d80967
...
...
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
KVConnectorBaseType
,
ensure_kv_transfer_initialized
,
get_kv_transfer_group
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
KVConnectorBaseType
,
ensure_kv_transfer_initialized
,
ensure_kv_transfer_shutdown
,
get_kv_transfer_group
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
__all__
=
[
"get_kv_transfer_group"
,
"has_kv_transfer_group"
,
"is_v1_kv_transfer_group"
,
"ensure_kv_transfer_initialized"
,
"KVConnectorBaseType"
"ensure_kv_transfer_shutdown"
,
"KVConnectorBaseType"
]
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
38d80967
...
...
@@ -14,7 +14,8 @@ from vllm.logger import init_logger
# yapf: enable
if
TYPE_CHECKING
:
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
logger
=
init_logger
(
__name__
)
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
38d80967
...
...
@@ -6,7 +6,7 @@ KV cache helper for store.
from
collections
import
defaultdict
from
collections.abc
import
Sequence
from
concurrent.futures
import
CancelledError
,
Future
from
typing
import
Optional
,
cast
from
typing
import
Literal
,
Optional
,
Union
,
cast
import
torch
...
...
@@ -196,3 +196,51 @@ class KVOutputAggregator:
output_future
.
add_done_callback
(
make_callback
(
i
))
return
result_future
def
_make_src_and_dst_indices
(
src_block_ids
:
list
[
int
],
dst_block_ids
:
list
[
int
],
src_device
:
Union
[
torch
.
device
,
str
],
dst_device
:
Union
[
torch
.
device
,
str
],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
src_indices
=
torch
.
tensor
(
src_block_ids
,
device
=
src_device
,
dtype
=
torch
.
int64
)
dst_indices
=
torch
.
tensor
(
dst_block_ids
,
device
=
dst_device
,
dtype
=
torch
.
int64
)
return
src_indices
,
dst_indices
def
copy_kv_blocks
(
src_kv_caches
:
dict
[
str
,
torch
.
Tensor
],
dst_kv_caches
:
dict
[
str
,
torch
.
Tensor
],
src_block_ids
:
list
[
int
],
dst_block_ids
:
list
[
int
],
direction
:
Literal
[
"h2d"
,
"d2h"
],
)
->
None
:
"""Copy kv blocks between different buffers."""
if
not
src_kv_caches
or
not
dst_kv_caches
or
\
not
src_block_ids
or
not
dst_block_ids
or
\
len
(
src_block_ids
)
!=
len
(
dst_block_ids
):
return
src_device
=
next
(
iter
(
src_kv_caches
.
values
())).
device
dst_device
=
next
(
iter
(
dst_kv_caches
.
values
())).
device
src_indices
,
dst_indices
=
_make_src_and_dst_indices
(
src_block_ids
=
src_block_ids
,
dst_block_ids
=
dst_block_ids
,
src_device
=
src_device
,
dst_device
=
dst_device
)
from
vllm.platforms
import
current_platform
if
direction
==
"h2d"
:
copy_fn
=
current_platform
.
insert_blocks_to_device
else
:
copy_fn
=
current_platform
.
swap_out_blocks_to_host
for
layer_name
in
src_kv_caches
:
src_tensor
=
src_kv_caches
[
layer_name
]
dst_tensor
=
dst_kv_caches
[
layer_name
]
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
38d80967
...
...
@@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
"""
return
None
,
None
def
shutdown
(
self
):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
return
None
# ==============================
# Scheduler-side methods
# ==============================
...
...
@@ -235,7 +243,7 @@ class KVConnectorBase_V1(ABC):
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
Optional
[
int
]
,
bool
]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
...
...
@@ -247,8 +255,11 @@ class KVConnectorBase_V1(ABC):
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
38d80967
...
...
@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
Optional
[
int
]
,
bool
]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
38d80967
...
...
@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional
import
torch
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
...
...
@@ -87,6 +88,18 @@ class MultiConnector(KVConnectorBase_V1):
for
c
in
self
.
_connectors
:
c
.
clear_connector_metadata
()
def
shutdown
(
self
):
exception
:
Optional
[
Exception
]
=
None
for
c
in
self
.
_connectors
:
try
:
c
.
shutdown
()
except
Exception
as
e
:
logger
.
exception
(
"Exception during connector %s shutdown."
,
c
.
__class__
.
__name__
)
exception
=
e
if
exception
:
raise
exception
# ==============================
# Worker-side methods
# ==============================
...
...
@@ -142,11 +155,15 @@ class MultiConnector(KVConnectorBase_V1):
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
Optional
[
int
]
,
bool
]:
to_return
=
(
0
,
False
)
for
i
,
c
in
enumerate
(
self
.
_connectors
):
toks
,
load_async
=
c
.
get_num_new_matched_tokens
(
request
,
num_computed_tokens
)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if
toks
is
None
:
return
(
None
,
False
)
# The first connector that has new matched tokens will be assigned
# to this request.
if
to_return
[
0
]
==
0
and
toks
>
0
:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
38d80967
...
...
@@ -14,6 +14,7 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
import
numpy
as
np
import
torch
import
zmq
...
...
@@ -60,6 +61,7 @@ except ImportError:
_NIXL_SUPPORTED_XPUS
=
{
"cuda"
:
(
"cuda"
,
),
"tpu"
:
(
"cpu"
,
),
"xpu"
:
(
"cpu"
,
),
}
...
...
@@ -160,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1):
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
)
->
tuple
[
int
,
bool
]:
num_computed_tokens
:
int
)
->
tuple
[
Optional
[
int
]
,
bool
]:
assert
self
.
connector_scheduler
is
not
None
return
self
.
connector_scheduler
.
get_num_new_matched_tokens
(
request
,
num_computed_tokens
)
...
...
@@ -715,7 +717,7 @@ class NixlConnectorWorker:
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are
transfer
red in the same
tensor
# Conversely for FlashInfer, K and V are
registe
red in the same
region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas_v1
or
self
.
_use_flashinfer
)
...
...
@@ -758,12 +760,21 @@ class NixlConnectorWorker:
assert
tensor_size_bytes
%
self
.
num_blocks
==
0
self
.
block_len
=
tensor_size_bytes
//
self
.
num_blocks
self
.
slot_size_bytes
=
self
.
block_len
//
self
.
block_size
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
if
self
.
_use_flashinfer
:
assert
self
.
slot_size_bytes
%
2
==
0
self
.
slot_size_bytes
/=
2
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence we double the number
# of 'virtual' regions here and halve `block_len` below.
self
.
num_regions
*=
2
kv_block_len
=
self
.
get_backend_aware_kv_block_len
()
# Register local/src descr for NIXL xfer.
blocks_data
=
[]
for
base_addr
in
seen_base_addresses
:
...
...
@@ -776,8 +787,18 @@ class NixlConnectorWorker:
block_offset
=
block_id
*
self
.
block_len
addr
=
base_addr
+
block_offset
# (addr, len, device id)
# TODO: does device_id matter to DRAM?
blocks_data
.
append
((
addr
,
self
.
block_len
,
self
.
tp_rank
))
blocks_data
.
append
((
addr
,
kv_block_len
,
self
.
tp_rank
))
if
self
.
_use_flashinfer
:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for
block_id
in
range
(
self
.
num_blocks
):
block_offset
=
block_id
*
self
.
block_len
addr
=
base_addr
+
block_offset
# Register addresses for V cache (K registered first).
v_addr
=
addr
+
kv_block_len
blocks_data
.
append
((
v_addr
,
kv_block_len
,
self
.
tp_rank
))
logger
.
debug
(
"Created %s blocks for src engine %s and rank %s"
,
len
(
blocks_data
),
self
.
engine_id
,
self
.
tp_rank
)
...
...
@@ -787,7 +808,7 @@ class NixlConnectorWorker:
self
.
src_xfer_side_handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
# TODO(mgoin): Hybrid memory allocator is currently diabled for
# TODO(mgoin): Hybrid memory allocator is currently di
s
abled for
# models with local attention (Llama 4). Can remove this once enabled.
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
...
...
@@ -903,7 +924,7 @@ class NixlConnectorWorker:
remote_block_size
=
nixl_agent_meta
.
block_len
//
(
self
.
slot_size_bytes
*
tp_ratio
)
if
self
.
_use_flashinfer
:
#
Account for joint KV in FlashInfer
.
#
With flashinfer, KV are sent in the same message
.
remote_block_size
//=
2
if
tp_ratio
>
1
:
# Heterogeneous TP expects same kv_cache_layout.
...
...
@@ -929,10 +950,10 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Only register the remote's descriptors if current rank pulls from it.
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
kv_block_len
=
self
.
get_backend_aware_kv_block_len
()
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
kv_block_len
\
if
not
(
self
.
use_mla
or
is_kv_replicated
)
else
0
# Register all remote blocks, but only the corresponding kv heads.
for
base_addr
in
nixl_agent_meta
.
kv_caches_base_addr
:
...
...
@@ -943,7 +964,16 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
self
.
block_len
,
remote_tp_rank
))
blocks_data
.
append
((
addr
,
kv_block_len
,
remote_tp_rank
))
if
self
.
_use_flashinfer
:
# With FlashInfer index V separately to allow head splitting.
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_len
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_len
//
2
blocks_data
.
append
((
v_addr
,
kv_block_len
,
remote_tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
...
...
@@ -1163,8 +1193,8 @@ class NixlConnectorWorker:
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
local_block_descs_ids
:
list
[
int
]
=
[]
remote_block_descs_ids
:
list
[
int
]
=
[]
local_block_descs_ids
:
np
.
ndarray
remote_block_descs_ids
:
np
.
ndarray
if
not
self
.
block_window_per_layer
:
# Default case: assume global attention
remote_block_descs_ids
=
self
.
_get_block_descs_ids
(
...
...
@@ -1174,6 +1204,8 @@ class NixlConnectorWorker:
else
:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
local_descs_list
=
[]
remote_descs_list
=
[]
for
layer_idx
,
block_window
in
enumerate
(
self
.
block_window_per_layer
):
# For each layer:
...
...
@@ -1193,8 +1225,11 @@ class NixlConnectorWorker:
layer_remote_desc_ids
=
self
.
_get_block_descs_ids
(
dst_engine_id
,
layer_remote_block_ids
,
layer_idx
)
local_block_descs_ids
.
extend
(
layer_local_desc_ids
)
remote_block_descs_ids
.
extend
(
layer_remote_desc_ids
)
local_descs_list
.
append
(
layer_local_desc_ids
)
remote_descs_list
.
append
(
layer_remote_desc_ids
)
local_block_descs_ids
=
np
.
concatenate
(
local_descs_list
)
remote_block_descs_ids
=
np
.
concatenate
(
remote_descs_list
)
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
...
...
@@ -1219,14 +1254,14 @@ class NixlConnectorWorker:
def
_get_block_descs_ids
(
self
,
engine_id
:
str
,
block_ids
:
list
[
int
],
layer_idx
:
Optional
[
int
]
=
None
)
->
list
[
int
]
:
layer_idx
:
Optional
[
int
]
=
None
)
->
np
.
ndarray
:
"""
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
"""
if
layer_idx
is
None
:
region_ids
=
range
(
self
.
num_regions
)
region_ids
=
np
.
a
range
(
self
.
num_regions
)
else
:
assert
layer_idx
<
self
.
num_layers
if
self
.
num_layers
<
self
.
num_regions
:
...
...
@@ -1234,20 +1269,35 @@ class NixlConnectorWorker:
# the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i
assert
2
*
self
.
num_layers
==
self
.
num_regions
region_ids
=
range
(
2
*
layer_idx
,
2
*
layer_idx
+
2
)
region_ids
=
np
.
a
range
(
2
*
layer_idx
,
2
*
layer_idx
+
2
)
else
:
# Otherwise, we assume we have MLA and select i-th layer
assert
self
.
num_layers
==
self
.
num_regions
region_ids
=
range
(
layer_idx
,
layer_idx
+
1
)
region_ids
=
np
.
a
range
(
layer_idx
,
layer_idx
+
1
)
num_blocks
=
self
.
dst_num_blocks
[
engine_id
]
# Compute the desc ids for each block.
descs_ids
:
list
[
int
]
=
[]
for
reg_id
in
region_ids
:
for
block_id
in
block_ids
:
descs_ids
.
append
(
reg_id
*
num_blocks
+
block_id
)
return
descs_ids
region_ids
=
region_ids
[:,
None
]
block_ids
=
np
.
array
(
block_ids
)[
None
,
:]
descs_ids
=
region_ids
*
num_blocks
+
block_ids
return
descs_ids
.
flatten
()
def
get_backend_aware_kv_block_len
(
self
):
"""
Get the block length for one K/V element (K and V have the same size).
For FA and other backends, this is equal to the length of the whole
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
if
self
.
_use_flashinfer
:
# For indexing only half (either just the K or V part).
block_len
=
self
.
block_len
//
2
else
:
block_len
=
self
.
block_len
return
block_len
@
contextlib
.
contextmanager
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
38d80967
...
...
@@ -15,7 +15,7 @@ import msgpack
import
torch
import
zmq
from
vllm.config
import
KVTransferConfig
from
vllm.config
.kv_transfer
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
)
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool
import
(
# noqa: E501
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
38d80967
...
...
@@ -3,7 +3,7 @@
import
hashlib
import
os
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Optional
import
safetensors
import
torch
...
...
@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
Optional
[
int
]
,
bool
]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
...
...
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
View file @
38d80967
...
...
@@ -13,7 +13,7 @@ import zmq
from
safetensors.torch
import
load
as
safetensors_load
from
safetensors.torch
import
save
as
safetensors_save
from
vllm.config
import
KVTransferConfig
from
vllm.config
.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_pipe.base
import
KVPipeBase
from
vllm.logger
import
init_logger
from
vllm.utils
import
join_host_port
,
make_zmq_path
,
split_host_port
...
...
Prev
1
…
17
18
19
20
21
22
23
24
25
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment