Commit 3459e245 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc3' into v0.15.0rc3-dev

parents 6fc61a0d fe18ce4d
...@@ -184,15 +184,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con ...@@ -184,15 +184,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}' --kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
``` ```
### Cross layers blocks
By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred.
To enable this feature:
```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```
## Example Scripts/Code ## Example Scripts/Code
Refer to these example scripts in the vLLM repository: Refer to these example scripts in the vLLM repository:
......
...@@ -9,7 +9,7 @@ requires = [ ...@@ -9,7 +9,7 @@ requires = [
"torch == 2.9.0", "torch == 2.9.0",
"wheel", "wheel",
"jinja2", "jinja2",
"grpcio-tools>=1.76.0", "grpcio-tools",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
......
...@@ -9,5 +9,5 @@ wheel ...@@ -9,5 +9,5 @@ wheel
jinja2>=3.1.6 jinja2>=3.1.6
regex regex
build build
protobuf>=6.33.2 protobuf
grpcio-tools>=1.76.0 grpcio-tools
...@@ -9,7 +9,7 @@ blake3 ...@@ -9,7 +9,7 @@ blake3
py-cpuinfo py-cpuinfo
transformers >= 4.56.0, < 5 transformers >= 4.56.0, < 5
tokenizers >= 0.21.1 # Required for fast incremental detokenization. tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf >= 6.30.0 # Required by LlamaTokenizer, gRPC. protobuf # Required by LlamaTokenizer, gRPC.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp aiohttp
openai >= 1.99.1 # For Responses API with reasoning content openai >= 1.99.1 # For Responses API with reasoning content
...@@ -51,5 +51,5 @@ openai-harmony >= 0.0.3 # Required for gpt-oss ...@@ -51,5 +51,5 @@ openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic >= 0.71.0 anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0 model-hosting-container-standards >= 0.1.13, < 1.0.0
mcp mcp
grpcio>=1.76.0 grpcio
grpcio-reflection>=1.76.0 grpcio-reflection
\ No newline at end of file \ No newline at end of file
...@@ -34,18 +34,11 @@ else ...@@ -34,18 +34,11 @@ else
KV_CONFIG_HETERO_LAYOUT='' KV_CONFIG_HETERO_LAYOUT=''
fi fi
CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers
if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then
KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}'
else
KV_EXTRA_CONFIG=''
fi
# Build the kv-transfer-config once # Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi fi
# Models to run # Models to run
......
...@@ -18,12 +18,8 @@ import ray ...@@ -18,12 +18,8 @@ import ray
import torch import torch
from vllm import LLM from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
...@@ -52,11 +48,8 @@ from vllm.sampling_params import SamplingParams ...@@ -52,11 +48,8 @@ from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import create_request, create_scheduler, create_vllm_config from .utils import create_request, create_scheduler, create_vllm_config
...@@ -373,7 +366,6 @@ def test_kv_transfer_handshake(dist_init): ...@@ -373,7 +366,6 @@ def test_kv_transfer_handshake(dist_init):
# Decode connector will be able to create handshake with the prefill connector. # Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector.register_kv_caches(kv_caches)
# Here we are testing the retrieval of NIXLAgentMetadata. # Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent # Knowing the implementation detail, we override the add_remote_agent
...@@ -410,23 +402,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -410,23 +402,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self.kv_cache_layout = kv_cache_layout self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it. # Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1} self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=test_shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
def _nixl_handshake( def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
...@@ -1395,7 +1370,6 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): ...@@ -1395,7 +1370,6 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
), ),
), ),
"TRITON_ATTN", "TRITON_ATTN",
"FLASHINFER",
], ],
) )
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
...@@ -1412,11 +1386,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ...@@ -1412,11 +1386,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
vllm_config = create_vllm_config(attention_backend=attn_backend) vllm_config = create_vllm_config(attention_backend=attn_backend)
# Enable cross layers blocks
vllm_config.kv_transfer_config.kv_connector_extra_config[
"enable_cross_layers_blocks"
] = True
# Import the appropriate backend based on the parameter # Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
...@@ -1426,99 +1395,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ...@@ -1426,99 +1395,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
backend_cls = RocmAttentionBackend backend_cls = RocmAttentionBackend
else: # TRITON else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
backend_cls = TritonAttentionBackend backend_cls = TritonAttentionBackend
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata"
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
expected_tensor_size: int
expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)
with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
kv_cache_config=kv_cache_config,
attn_groups=[
[
AttentionGroup(
backend=backend_cls,
layer_names=[],
kv_cache_spec=kv_cache_spec,
kv_cache_group_id=0,
)
]
],
cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(),
kernel_block_sizes=[block_size],
)
)
# Store tensor info for validation
expected_tensor_size = (
cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel()
)
expected_base_addrs = [
cross_layers_kv_cache.data_ptr(),
]
expected_num_entries = 1
expected_blocks_count = 8
kv_caches = {"all-layers": cross_layers_kv_cache}
else:
# Create test kv cache tensors using proper backend shape # Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape( kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
...@@ -1539,9 +1420,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ...@@ -1539,9 +1420,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first: if is_blocks_first:
expected_tensor_size = ( expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
shared_tensor.element_size() * shared_tensor.numel()
)
expected_base_addrs = [ expected_base_addrs = [
shared_tensor.data_ptr(), shared_tensor.data_ptr(),
unique_tensor.data_ptr(), unique_tensor.data_ptr(),
...@@ -1558,7 +1437,34 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ...@@ -1558,7 +1437,34 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
unique_tensor[1].data_ptr(), unique_tensor[1].data_ptr(),
] ]
expected_num_entries = 4 expected_num_entries = 4
expected_blocks_count = 8
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata"
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
# Execute register_kv_caches # Execute register_kv_caches
connector.register_kv_caches(kv_caches) connector.register_kv_caches(kv_caches)
...@@ -1583,14 +1489,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ...@@ -1583,14 +1489,11 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size # Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, ( assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
) )
if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
else:
num_blocks = 2 num_blocks = 2
if is_blocks_first: if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2 expected_block_len = expected_tensor_size // num_blocks // 2
...@@ -2146,17 +2049,6 @@ def test_compatibility_hash_validation( ...@@ -2146,17 +2049,6 @@ def test_compatibility_hash_validation(
) )
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
decode_connector.register_kv_caches(kv_caches)
remote_config_params: dict[str, Any] = { remote_config_params: dict[str, Any] = {
"model": "facebook/opt-125m", "model": "facebook/opt-125m",
...@@ -2179,9 +2071,7 @@ def test_compatibility_hash_validation( ...@@ -2179,9 +2071,7 @@ def test_compatibility_hash_validation(
) )
) )
remote_hash = compute_nixl_compatibility_hash( remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, remote_vllm_config, decode_worker.backend_name
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
) )
prefill_block_size = config_overrides.get("block_size", 16) prefill_block_size = config_overrides.get("block_size", 16)
...@@ -2260,27 +2150,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2260,27 +2150,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
tp_rank=decode_worker.tp_rank,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
tensor_shape=test_shape,
)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)
if error_scenario == "handshake_decode_error": if error_scenario == "handshake_decode_error":
msg_bytes = b"this is not valid msgpack data" msg_bytes = b"this is not valid msgpack data"
elif error_scenario == "handshake_validation_error": elif error_scenario == "handshake_validation_error":
......
...@@ -316,13 +316,12 @@ class TpKVTopology: ...@@ -316,13 +316,12 @@ class TpKVTopology:
attn_backend: type[AttentionBackend] attn_backend: type[AttentionBackend]
engine_id: EngineId engine_id: EngineId
remote_block_size: dict[EngineId, int] remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
def __post_init__(self): def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V # Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly. # or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape( kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=4, head_size=1 num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
) )
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below. # we just mock num_blocks to 1 for the dimension check below.
...@@ -330,32 +329,6 @@ class TpKVTopology: ...@@ -330,32 +329,6 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
) )
self._kv_heads_position: int | None = None
self._cross_layers_blocks = False
if self.tensor_shape is not None:
self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1
)
if self._cross_layers_blocks:
# prepend layers dimension
kv_cache_shape = (80,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
# permute kv_cache_shape according to stride_order
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
physical_block_size_position = kv_cache_shape.index(16)
assert physical_block_size_position is not None
self._physical_block_size_position = -(
len(kv_cache_shape) - physical_block_size_position
)
@property @property
def is_kv_layout_blocks_first(self) -> bool: def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first return self._is_kv_layout_blocks_first
...@@ -363,9 +336,7 @@ class TpKVTopology: ...@@ -363,9 +336,7 @@ class TpKVTopology:
@property @property
def split_k_and_v(self) -> bool: def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present). # Whether to register regions for K and V separately (when present).
return not ( return not (self.is_mla or self.is_kv_layout_blocks_first)
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
@property @property
def tp_size(self) -> int: def tp_size(self) -> int:
...@@ -375,14 +346,6 @@ class TpKVTopology: ...@@ -375,14 +346,6 @@ class TpKVTopology:
def block_size(self) -> int: def block_size(self) -> int:
return self.remote_block_size[self.engine_id] return self.remote_block_size[self.engine_id]
@property
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
@property
def block_size_position(self) -> int:
return self._physical_block_size_position
def tp_ratio( def tp_ratio(
self, self,
remote_tp_size: int, remote_tp_size: int,
......
...@@ -54,7 +54,7 @@ from vllm.forward_context import ForwardContext ...@@ -54,7 +54,7 @@ from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
...@@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata): ...@@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash( def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool vllm_config: VllmConfig, attn_backend_name: str
) -> str: ) -> str:
""" """
Compute compatibility hash for NIXL KV transfer. Compute compatibility hash for NIXL KV transfer.
...@@ -216,7 +216,6 @@ def compute_nixl_compatibility_hash( ...@@ -216,7 +216,6 @@ def compute_nixl_compatibility_hash(
# Attention backend and KV cache dtype affect memory layout # Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name, "attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype), "cache_dtype": str(cache_config.cache_dtype),
"cross_layers_blocks": cross_layers_blocks,
} }
compat_hash = hash_factors(factors) compat_hash = hash_factors(factors)
...@@ -299,19 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -299,19 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1): class NixlConnector(KVConnectorBase_V1):
@property
def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config)
if backend.get_name() not in (
"FLASH_ATTN",
"FLASHINFER",
):
# For now there is no benefit to run cross layers when backend
# does not support on HND
return False
extra_config = self.kv_transfer_config.kv_connector_extra_config
return bool(str(extra_config.get("enable_cross_layers_blocks", "False")))
def __init__( def __init__(
self, self,
...@@ -324,7 +310,6 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -324,7 +310,6 @@ class NixlConnector(KVConnectorBase_V1):
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
if role == KVConnectorRole.SCHEDULER: if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: NixlConnectorScheduler | None = ( self.connector_scheduler: NixlConnectorScheduler | None = (
...@@ -411,16 +396,6 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -411,16 +396,6 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches) self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
self.connector_worker.register_kv_caches(kv_caches)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation) self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
...@@ -1002,17 +977,20 @@ class NixlConnectorWorker: ...@@ -1002,17 +977,20 @@ class NixlConnectorWorker:
# Get the attention backend from the first layer # Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet # NOTE (NickLucche) models with multiple backends are not supported yet
self.attn_backend = get_current_attn_backend(vllm_config) backend = get_current_attn_backend(vllm_config)
self.backend_name = self.attn_backend.get_name() self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout) logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
# lazy initialized in register_kv_caches self.compat_hash = compute_nixl_compatibility_hash(
self.compat_hash: str | None = None self.vllm_config, self.backend_name
self.kv_topo: TpKVTopology | None = None )
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
...@@ -1021,11 +999,16 @@ class NixlConnectorWorker: ...@@ -1021,11 +999,16 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats() self.xfer_stats = NixlKVConnectorStats()
self._physical_blocks_per_logical_kv_block = 1 self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( engine_id=self.engine_id,
"enforce_handshake_compat", True remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
) )
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake( def _nixl_handshake(
self, self,
...@@ -1040,7 +1023,6 @@ class NixlConnectorWorker: ...@@ -1040,7 +1023,6 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current # Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP, # local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i. # this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {} remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port) path = make_zmq_path("tcp", host, port)
...@@ -1078,7 +1060,6 @@ class NixlConnectorWorker: ...@@ -1078,7 +1060,6 @@ class NixlConnectorWorker:
) )
# Check compatibility hash BEFORE decoding agent metadata # Check compatibility hash BEFORE decoding agent metadata
assert self.compat_hash is not None
if ( if (
self.enforce_compat_hash self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash and handshake_payload.compatibility_hash != self.compat_hash
...@@ -1287,20 +1268,6 @@ class NixlConnectorWorker: ...@@ -1287,20 +1268,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl.""" """Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=next(iter(kv_caches.values())).shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
if self.use_host_buffer: if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches) self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), ( assert len(self.host_xfer_buffers) == len(kv_caches), (
...@@ -1335,21 +1302,29 @@ class NixlConnectorWorker: ...@@ -1335,21 +1302,29 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB). # (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region # Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim). # to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None tensor_size_bytes = None
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
if self.device_type == "cpu":
block_size_position = -2
else:
block_size_position = -2 if self.use_mla else -3
# Enable different block lengths for different layers when MLA is used. # Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]() self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items(): for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = ( cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
)
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
if base_addr in seen_base_addresses: if base_addr in seen_base_addresses:
continue continue
kernel_block_size = cache.shape[self.kv_topo.block_size_position] kernel_block_size = cache.shape[block_size_position]
if self.block_size != kernel_block_size: if self.block_size != kernel_block_size:
logger.info_once( logger.info_once(
"User-specified logical block size (%s) does not match" "User-specified logical block size (%s) does not match"
...@@ -1411,7 +1386,6 @@ class NixlConnectorWorker: ...@@ -1411,7 +1386,6 @@ class NixlConnectorWorker:
self.device_kv_caches = kv_caches self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first: if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)): for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0 assert self.slot_size_per_layer[i] % 2 == 0
...@@ -1467,7 +1441,6 @@ class NixlConnectorWorker: ...@@ -1467,7 +1441,6 @@ class NixlConnectorWorker:
block_size=self.block_size, block_size=self.block_size,
) )
# Wrap metadata in payload with hash for defensive decoding # Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
encoder = msgspec.msgpack.Encoder() encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload( self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash, compatibility_hash=self.compat_hash,
...@@ -1489,8 +1462,6 @@ class NixlConnectorWorker: ...@@ -1489,8 +1462,6 @@ class NixlConnectorWorker:
register another local_xfer_handler using remote block len to ensure register another local_xfer_handler using remote block len to ensure
data copy correctness. data copy correctness.
""" """
assert self.kv_topo is not None
block_size_ratio = self.block_size // block_size block_size_ratio = self.block_size // block_size
blocks_data = [] blocks_data = []
for i, base_addr in enumerate(self.seen_base_addresses): for i, base_addr in enumerate(self.seen_base_addresses):
...@@ -1603,7 +1574,6 @@ class NixlConnectorWorker: ...@@ -1603,7 +1574,6 @@ class NixlConnectorWorker:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12| # local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks: if engine_id not in self.dst_num_blocks:
...@@ -1731,10 +1701,7 @@ class NixlConnectorWorker: ...@@ -1731,10 +1701,7 @@ class NixlConnectorWorker:
""" """
remote_engine_id = nixl_agent_meta.engine_id remote_engine_id = nixl_agent_meta.engine_id
assert ( assert self._tp_size[remote_engine_id] == remote_tp_size
self._tp_size[remote_engine_id] == remote_tp_size
and self.kv_topo is not None
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
...@@ -1871,7 +1838,6 @@ class NixlConnectorWorker: ...@@ -1871,7 +1838,6 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0: if len(self.device_kv_caches) == 0:
return return
assert block_size_ratio >= 1, "Only nP < nD supported currently." assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1: if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug( logger.debug(
"Post-processing device kv cache on receive by converting " "Post-processing device kv cache on receive by converting "
...@@ -1891,7 +1857,7 @@ class NixlConnectorWorker: ...@@ -1891,7 +1857,7 @@ class NixlConnectorWorker:
block_size_ratio, block_size_ratio,
) )
split_k_and_v = self.kv_topo.split_k_and_v split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
for block_ids in block_ids_list: for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
...@@ -1916,7 +1882,6 @@ class NixlConnectorWorker: ...@@ -1916,7 +1882,6 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done. to track which workers are done.
""" """
assert self.kv_topo is not None
done_sending = self._get_new_notifs() done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers) done_recving = self._pop_done_transfers(self._recving_transfers)
...@@ -1986,7 +1951,6 @@ class NixlConnectorWorker: ...@@ -1986,7 +1951,6 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling. for all consumers to be done pulling.
""" """
assert self.kv_topo is not None
notified_req_ids: set[str] = set() notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
...@@ -2146,7 +2110,7 @@ class NixlConnectorWorker: ...@@ -2146,7 +2110,7 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None and self.kv_topo is not None assert meta.remote is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id meta.remote.engine_id
) )
...@@ -2215,7 +2179,10 @@ class NixlConnectorWorker: ...@@ -2215,7 +2179,10 @@ class NixlConnectorWorker:
local_xfer_side_handle: int, local_xfer_side_handle: int,
remote_xfer_side_handle: int, remote_xfer_side_handle: int,
): ):
assert self.kv_topo is not None """
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1: if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks( local_block_ids = self.get_mapped_blocks(
...@@ -2448,7 +2415,6 @@ class NixlConnectorWorker: ...@@ -2448,7 +2415,6 @@ class NixlConnectorWorker:
For FlashInfer, this is half the length of the whole block, as K and V For FlashInfer, this is half the length of the whole block, as K and V
share the same region. share the same region.
""" """
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first: if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part). # For indexing only half (either just the K or V part).
block_len = self.block_len_per_layer[layer_idx] // 2 block_len = self.block_len_per_layer[layer_idx] // 2
......
...@@ -308,6 +308,26 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -308,6 +308,26 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
) )
supports_update_block_table: bool = True supports_update_block_table: bool = True
@classmethod
def get_cudagraph_support(
cls,
vllm_config: "VllmConfig",
kv_cache_spec: "AttentionSpec",
) -> AttentionCGSupport:
# FA2 does not support CUDA graphs with encoder-decoder models due to
# accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
if (
vllm_config.model_config.is_encoder_decoder
and get_flash_attn_version() == 2
):
logger.warning_once(
"FlashAttention2 does not support CUDA graphs with "
"encoder-decoder models due to accuracy issues reported in #33091. "
"Disabling CUDA graph."
)
return AttentionCGSupport.NEVER
return cls._cudagraph_support
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
......
...@@ -911,6 +911,17 @@ class EngineCoreProc(EngineCore): ...@@ -911,6 +911,17 @@ class EngineCoreProc(EngineCore):
set_process_title("EngineCore") set_process_title("EngineCore")
decorate_logs() decorate_logs()
if data_parallel and vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
)
logger.debug(
"Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id,
)
parallel_config.data_parallel_index = dp_rank parallel_config.data_parallel_index = dp_rank
if data_parallel and vllm_config.model_config.is_moe: if data_parallel and vllm_config.model_config.is_moe:
# Set data parallel rank for this engine process. # Set data parallel rank for this engine process.
...@@ -1285,17 +1296,6 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1285,17 +1296,6 @@ class DPEngineCoreProc(EngineCoreProc):
assert local_dp_rank is not None assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
if vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
)
logger.debug(
"Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id,
)
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
......
...@@ -313,6 +313,13 @@ class CoreEngineActorManager: ...@@ -313,6 +313,13 @@ class CoreEngineActorManager:
dp_vllm_config.parallel_config.placement_group = pg dp_vllm_config.parallel_config.placement_group = pg
local_client = index < local_engine_count local_client = index < local_engine_count
if dp_size > 1 and dp_vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
dp_vllm_config.kv_transfer_config.engine_id = (
f"{dp_vllm_config.kv_transfer_config.engine_id}_dp{local_index}"
)
# Ray XPU known issue: dpctl initializes the GPU runtime early, so # Ray XPU known issue: dpctl initializes the GPU runtime early, so
# setting device env vars in Ray actor's initialization method # setting device env vars in Ray actor's initialization method
# will not affect device selection. See: # will not affect device selection. See:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment