Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
......@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB
8: 2 * MiB, # 2 MB
6: 1 * MiB, # 1 MB
8: 1 * MiB, # 1 MB
}
}
......
......@@ -252,7 +252,10 @@ class DeviceCommunicatorBase:
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
......
......@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
symm_mem_enabled=(self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled),
)
if current_platform.is_rocm():
......@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
......
......@@ -54,13 +54,14 @@ class CustomAllreduce:
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None:
max_size=8192 * 1024,
symm_mem_enabled=False) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
......@@ -111,7 +112,7 @@ class CustomAllreduce:
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
if (current_platform.is_cuda() and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
......@@ -159,7 +160,7 @@ class CustomAllreduce:
self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# Metadata composes of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.platforms import current_platform
if current_platform.is_neuron():
import torch_xla.core.xla_model as xm
class NeuronCommunicator(DeviceCommunicatorBase):
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
......@@ -31,7 +31,7 @@ class PyNcclCommunicator:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be bind to f"cuda:{local_rank}".
it will be bound to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
......
......@@ -78,7 +78,7 @@ class QuickAllReduce:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
......
......@@ -186,7 +186,7 @@ class RayPPCommunicator(Communicator):
"""
Receive a torch.Tensor from a peer and synchronize the current stream.
After this call returns, the receive buffer is safe to read from from
After this call returns, the receive buffer is safe to read from
any stream. An RayChannelError will be raised if an error occurred
(e.g., remote actor died), and the buffer is not safe to read.
......
......@@ -27,8 +27,13 @@ class SymmMemCommunicator:
"10.0": [6, 8],
}
def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
self.disabled = True
if not symm_mem_available:
......@@ -64,8 +69,17 @@ class SymmMemCommunicator:
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
# Use override max_size if provided, otherwise use default
if max_size_override is not None:
self.max_size = max_size_override
logger.info(
"SymmMemCommunicator: Using override max_size: %s bytes",
self.max_size,
)
else:
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability][self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
......@@ -76,6 +90,7 @@ class SymmMemCommunicator:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.force_multimem = force_multimem
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
......@@ -98,8 +113,18 @@ class SymmMemCommunicator:
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
# Determine which algorithm to use
use_multimem = False
if self.force_multimem is not None:
# Test override: use forced setting
use_multimem = self.force_multimem
else:
# Normal logic: use multimem for supported world sizes
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]
if use_multimem:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
......
......@@ -14,8 +14,9 @@ from typing import Any, Callable, Optional, Union
import msgspec
import zmq
from vllm.config import KVEventsConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
logger = init_logger(__name__)
......@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
block_hashes: list[ExternalBlockHash]
parent_block_hash: Optional[ExternalBlockHash]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
......@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
block_hashes: list[ExternalBlockHash]
medium: Optional[str]
......
......@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
has_kv_transfer_group, is_v1_kv_transfer_group)
KVConnectorBaseType, ensure_kv_transfer_initialized,
ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group)
__all__ = [
"get_kv_transfer_group", "has_kv_transfer_group",
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
"KVConnectorBaseType"
"ensure_kv_transfer_shutdown", "KVConnectorBaseType"
]
......@@ -14,7 +14,8 @@ from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import KVTransferConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
logger = init_logger(__name__)
......
......@@ -6,7 +6,7 @@ KV cache helper for store.
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Optional, cast
from typing import Literal, Optional, Union, cast
import torch
......@@ -196,3 +196,51 @@ class KVOutputAggregator:
output_future.add_done_callback(make_callback(i))
return result_future
def _make_src_and_dst_indices(
src_block_ids: list[int],
dst_block_ids: list[int],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
src_indices = torch.tensor(src_block_ids,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_block_ids,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
def copy_kv_blocks(
src_kv_caches: dict[str, torch.Tensor],
dst_kv_caches: dict[str, torch.Tensor],
src_block_ids: list[int],
dst_block_ids: list[int],
direction: Literal["h2d", "d2h"],
) -> None:
"""Copy kv blocks between different buffers."""
if not src_kv_caches or not dst_kv_caches or \
not src_block_ids or not dst_block_ids or \
len(src_block_ids) != len(dst_block_ids):
return
src_device = next(iter(src_kv_caches.values())).device
dst_device = next(iter(dst_kv_caches.values())).device
src_indices, dst_indices = _make_src_and_dst_indices(
src_block_ids=src_block_ids,
dst_block_ids=dst_block_ids,
src_device=src_device,
dst_device=dst_device)
from vllm.platforms import current_platform
if direction == "h2d":
copy_fn = current_platform.insert_blocks_to_device
else:
copy_fn = current_platform.swap_out_blocks_to_host
for layer_name in src_kv_caches:
src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
......@@ -226,6 +226,14 @@ class KVConnectorBase_V1(ABC):
"""
return None, None
def shutdown(self):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
......@@ -235,7 +243,7 @@ class KVConnectorBase_V1(ABC):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
......@@ -247,8 +255,11 @@ class KVConnectorBase_V1(ABC):
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
......
......@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
......
......@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import KVTransferConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
......@@ -87,6 +88,18 @@ class MultiConnector(KVConnectorBase_V1):
for c in self._connectors:
c.clear_connector_metadata()
def shutdown(self):
exception: Optional[Exception] = None
for c in self._connectors:
try:
c.shutdown()
except Exception as e:
logger.exception("Exception during connector %s shutdown.",
c.__class__.__name__)
exception = e
if exception:
raise exception
# ==============================
# Worker-side methods
# ==============================
......@@ -142,11 +155,15 @@ class MultiConnector(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned
# to this request.
if to_return[0] == 0 and toks > 0:
......
......@@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import numpy as np
import torch
import zmq
......@@ -60,6 +61,7 @@ except ImportError:
_NIXL_SUPPORTED_XPUS = {
"cuda": ("cuda", ),
"tpu": ("cpu", ),
"xpu": ("cpu", ),
}
......@@ -160,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
num_computed_tokens: int) -> tuple[Optional[int], bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
......@@ -715,7 +717,7 @@ class NixlConnectorWorker:
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = not (self.use_mla or self._use_pallas_v1
or self._use_flashinfer)
......@@ -758,12 +760,21 @@ class NixlConnectorWorker:
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence we double the number
# of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2
kv_block_len = self.get_backend_aware_kv_block_len()
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in seen_base_addresses:
......@@ -776,8 +787,18 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
# TODO: does device_id matter to DRAM?
blocks_data.append((addr, self.block_len, self.tp_rank))
blocks_data.append((addr, kv_block_len, self.tp_rank))
if self._use_flashinfer:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
......@@ -787,7 +808,7 @@ class NixlConnectorWorker:
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently diabled for
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
......@@ -903,7 +924,7 @@ class NixlConnectorWorker:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
if self._use_flashinfer:
# Account for joint KV in FlashInfer.
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout.
......@@ -929,10 +950,10 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Only register the remote's descriptors if current rank pulls from it.
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
kv_block_len = self.get_backend_aware_kv_block_len()
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
......@@ -943,7 +964,16 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
blocks_data.append((addr, kv_block_len, remote_tp_rank))
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_len // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
......@@ -1163,8 +1193,8 @@ class NixlConnectorWorker:
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
local_block_descs_ids: list[int] = []
remote_block_descs_ids: list[int] = []
local_block_descs_ids: np.ndarray
remote_block_descs_ids: np.ndarray
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
......@@ -1174,6 +1204,8 @@ class NixlConnectorWorker:
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
local_descs_list = []
remote_descs_list = []
for layer_idx, block_window in enumerate(
self.block_window_per_layer):
# For each layer:
......@@ -1193,8 +1225,11 @@ class NixlConnectorWorker:
layer_remote_desc_ids = self._get_block_descs_ids(
dst_engine_id, layer_remote_block_ids, layer_idx)
local_block_descs_ids.extend(layer_local_desc_ids)
remote_block_descs_ids.extend(layer_remote_desc_ids)
local_descs_list.append(layer_local_desc_ids)
remote_descs_list.append(layer_remote_desc_ids)
local_block_descs_ids = np.concatenate(local_descs_list)
remote_block_descs_ids = np.concatenate(remote_descs_list)
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
......@@ -1219,14 +1254,14 @@ class NixlConnectorWorker:
def _get_block_descs_ids(self,
engine_id: str,
block_ids: list[int],
layer_idx: Optional[int] = None) -> list[int]:
layer_idx: Optional[int] = None) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
"""
if layer_idx is None:
region_ids = range(self.num_regions)
region_ids = np.arange(self.num_regions)
else:
assert layer_idx < self.num_layers
if self.num_layers < self.num_regions:
......@@ -1234,20 +1269,35 @@ class NixlConnectorWorker:
# the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i
assert 2 * self.num_layers == self.num_regions
region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
else:
# Otherwise, we assume we have MLA and select i-th layer
assert self.num_layers == self.num_regions
region_ids = range(layer_idx, layer_idx + 1)
region_ids = np.arange(layer_idx, layer_idx + 1)
num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block.
descs_ids: list[int] = []
for reg_id in region_ids:
for block_id in block_ids:
descs_ids.append(reg_id * num_blocks + block_id)
return descs_ids
region_ids = region_ids[:, None]
block_ids = np.array(block_ids)[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
def get_backend_aware_kv_block_len(self):
"""
Get the block length for one K/V element (K and V have the same size).
For FA and other backends, this is equal to the length of the whole
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
if self._use_flashinfer:
# For indexing only half (either just the K or V part).
block_len = self.block_len // 2
else:
block_len = self.block_len
return block_len
@contextlib.contextmanager
......
......@@ -15,7 +15,7 @@ import msgpack
import torch
import zmq
from vllm.config import KVTransferConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
......
......@@ -3,7 +3,7 @@
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import safetensors
import torch
......@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
......
......@@ -13,7 +13,7 @@ import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment