Unverified commit 96b23b8e, authored by Nicolò Lucchesi, committed by GitHub
Browse files

[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent 433c0f86
...@@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization(): ...@@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
req_index = 0 req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index) block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification. # Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks( expected_kernel_blocks = block_table.map_to_kernel_blocks(
np.array(kvcache_manager_blocks) np.array(kvcache_manager_blocks),
block_table.blocks_per_kv_block,
block_table._kernel_block_arange,
) )
# Verify block table state # Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
......
...@@ -49,6 +49,7 @@ from vllm.platforms import current_platform ...@@ -49,6 +49,7 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): ...@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
@dataclass @dataclass
class ReqMeta: class ReqMeta:
local_block_ids: list[int] local_block_ids: list[int]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
remote_block_ids: list[int] remote_block_ids: list[int]
remote_host: str remote_host: str
remote_port: int remote_port: int
...@@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
assert load_remote_cache ^ save_to_host assert load_remote_cache ^ save_to_host
_req = ReqMeta( _req = ReqMeta(
local_block_ids=local_block_ids, local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"], remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"], remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
...@@ -935,6 +939,7 @@ class NixlConnectorWorker: ...@@ -935,6 +939,7 @@ class NixlConnectorWorker:
attn_backend=backend, attn_backend=backend,
) )
self._use_pallas = self.kv_topo._use_pallas self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake( def _nixl_handshake(
self, self,
...@@ -1133,6 +1138,22 @@ class NixlConnectorWorker: ...@@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
if base_addr in seen_base_addresses: if base_addr in seen_base_addresses:
continue continue
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
seen_base_addresses.append(base_addr) seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size() curr_tensor_size_bytes = cache.numel() * cache.element_size()
...@@ -1479,7 +1500,7 @@ class NixlConnectorWorker: ...@@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
assert self.use_host_buffer assert self.use_host_buffer
assert self.copy_blocks is not None assert self.copy_blocks is not None
local_block_ids = meta.local_block_ids local_block_ids = meta.local_physical_block_ids
self.copy_blocks( self.copy_blocks(
self.host_xfer_buffers, self.host_xfer_buffers,
self.device_kv_caches, self.device_kv_caches,
...@@ -1492,7 +1513,7 @@ class NixlConnectorWorker: ...@@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
"synced recved kv of request[%s] to device kv buffer," "synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. ", "local_block_ids: %s. ",
req_id, req_id,
",".join(map(str, meta.local_block_ids)), ",".join(map(str, local_block_ids)),
) )
def save_kv_to_host(self, metadata: NixlConnectorMetadata): def save_kv_to_host(self, metadata: NixlConnectorMetadata):
...@@ -1501,19 +1522,22 @@ class NixlConnectorWorker: ...@@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
assert self.copy_blocks is not None assert self.copy_blocks is not None
for req_id, meta in metadata.reqs_to_save.items(): for req_id, meta in metadata.reqs_to_save.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug( logger.debug(
"save_load_kv for request[%s] to host xfer buffer." "save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. ", "local_block_ids: %s. ",
req_id, req_id,
",".join(map(str, meta.local_block_ids)), ",".join(map(str, meta.local_physical_block_ids)),
) )
# blocking # blocking
self.copy_blocks( self.copy_blocks(
self.device_kv_caches, self.device_kv_caches,
self.host_xfer_buffers, self.host_xfer_buffers,
meta.local_block_ids, meta.local_physical_block_ids,
meta.local_block_ids, meta.local_physical_block_ids,
"d2h", "d2h",
) )
...@@ -1582,7 +1606,7 @@ class NixlConnectorWorker: ...@@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
if self.use_host_buffer: if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta) self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv: if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_block_ids block_ids_to_permute += meta.local_physical_block_ids
if len(block_ids_to_permute) > 0: if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute) self.permute_device_kv(block_ids_to_permute)
...@@ -1669,7 +1693,7 @@ class NixlConnectorWorker: ...@@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
req_id, req_id,
xfer_state, xfer_state,
) )
# mark all blocks for this request as invalid # mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.pop(req_id, None): if meta := self._recving_metadata.pop(req_id, None):
self._invalid_block_ids.update(meta.local_block_ids) self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None) self._recving_metadata.pop(req_id, None)
...@@ -1686,13 +1710,19 @@ class NixlConnectorWorker: ...@@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
We check for these trnxs to complete in each step(). We check for these trnxs to complete in each step().
""" """
for req_id, meta in metadata.reqs_to_recv.items(): for req_id, meta in metadata.reqs_to_recv.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
meta.remote_block_ids = self._logical_to_kernel_block_ids(
meta.remote_block_ids
)
remote_engine_id = meta.remote_engine_id remote_engine_id = meta.remote_engine_id
logger.debug( logger.debug(
"start_load_kv for request %s from remote engine %s. " "start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ", "Num local_block_ids: %s. Num remote_block_ids: %s. ",
req_id, req_id,
remote_engine_id, remote_engine_id,
len(meta.local_block_ids), len(meta.local_physical_block_ids),
len(meta.remote_block_ids), len(meta.remote_block_ids),
) )
# always store metadata for failure recovery # always store metadata for failure recovery
...@@ -1740,7 +1770,7 @@ class NixlConnectorWorker: ...@@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
self._read_blocks( self._read_blocks(
request_id=req_id, request_id=req_id,
dst_engine_id=meta.remote_engine_id, dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids, local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids, remote_block_ids=meta.remote_block_ids,
) )
...@@ -1867,7 +1897,7 @@ class NixlConnectorWorker: ...@@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
"Marking blocks as invalid.", "Marking blocks as invalid.",
request_id, request_id,
) )
# mark all blocks for this request as invalid # mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.get(request_id): if meta := self._recving_metadata.get(request_id):
self._invalid_block_ids.update(meta.local_block_ids) self._invalid_block_ids.update(meta.local_block_ids)
self.xfer_stats.record_failed_transfer() self.xfer_stats.record_failed_transfer()
...@@ -1906,6 +1936,23 @@ class NixlConnectorWorker: ...@@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten() return descs_ids.flatten()
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
    """
    Translate logical block ids into kernel (physical) block ids.

    The translation only matters when the user-configured logical block
    size differs from the block size the attention backend works with. In
    that case every logical block id is expanded through
    `BlockTable.map_to_kernel_blocks` into
    `self._physical_blocks_per_logical_kv_block` kernel block ids.

    Args:
        block_ids: Logical block ids to translate.

    Returns:
        The kernel block ids as a flat list. When the two block sizes
        agree, the input list itself is returned (no copy is made).
    """
    ratio = self._physical_blocks_per_logical_kv_block
    if ratio == 1:
        # Logical and physical sizes agree: nothing to expand.
        return block_ids

    logical_ids = np.array(block_ids)
    # Row vector [0, 1, ..., ratio - 1]: the offsets of the kernel blocks
    # inside one logical block, shaped (1, ratio) as the helper receives it.
    offsets = np.arange(0, ratio).reshape(1, -1)
    kernel_ids = BlockTable.map_to_kernel_blocks(logical_ids, ratio, offsets)
    return kernel_ids.tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int): def get_backend_aware_kv_block_len(self, layer_idx: int):
""" """
Get the block length for one K/V element (K and V have the same size). Get the block length for one K/V element (K and V have the same size).
......
...@@ -98,7 +98,9 @@ class BlockTable: ...@@ -98,7 +98,9 @@ class BlockTable:
return return
if self.use_hybrid_blocks: if self.use_hybrid_blocks:
block_ids = self._map_to_kernel_blocks(np.array(block_ids)) block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids) num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx] start = self.num_blocks_per_row[row_idx]
...@@ -188,7 +190,12 @@ class BlockTable: ...@@ -188,7 +190,12 @@ class BlockTable:
self.block_table.gpu.fill_(0) self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0) self.block_table.cpu.fill_(0)
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: @staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs. """Convert kv_manager_block_id IDs to kernel block IDs.
Example: Example:
...@@ -203,12 +210,12 @@ class BlockTable: ...@@ -203,12 +210,12 @@ class BlockTable:
# kv_manager_block_id 1 → kernel block id [2, 3] # kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5] # kv_manager_block_id 2 → kernel block id [4, 5]
""" """
if not self.use_hybrid_blocks: if blocks_per_kv_block == 1:
return kv_manager_block_ids return kv_manager_block_ids
kernel_block_ids = ( kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ self._kernel_block_arange + kernel_block_arange
) )
return kernel_block_ids.reshape(-1) return kernel_block_ids.reshape(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment