Unverified commit 07cfc3a1, authored by Ryan Olson and committed by GitHub
Browse files

feat: kvbm + connector (#2258)


Signed-off-by: Ryan Olson <rolson@nvidia.com>
Co-authored-by: Olga Andreeva <oandreeva@nvidia.com>
Co-authored-by: Ziqi Fan <ziqif@nvidia.com>
Co-authored-by: John Thompson <jothomson@nvidia.com>
Co-authored-by: Richard Huo <rihuo@nvidia.com>
Co-authored-by: Zicheng Ma <zichengm@nvidia.com>
parent bf5862a1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Loader for the Rust-based vLLM integration objects.
"""
try:
    from dynamo._core import _vllm_integration

    # The classes below are looked up on the Rust extension submodule at import
    # time. The two connector classes are exported by the extension with a "Py"
    # prefix and are re-exposed here without it.
    KvbmCacheManager = _vllm_integration.KvbmCacheManager
    KvbmRequest = _vllm_integration.KvbmRequest
    KvbmBlockList = _vllm_integration.KvbmBlockList
    BlockState = _vllm_integration.BlockState
    BlockStates = _vllm_integration.BlockStates
    SlotUpdate = _vllm_integration.SlotUpdate
    KvConnectorWorker = _vllm_integration.PyKvConnectorWorker
    KvConnectorLeader = _vllm_integration.PyKvConnectorLeader
    SchedulerOutput = _vllm_integration.SchedulerOutput

    from dynamo.llm import BlockManager
except ImportError:
    print("Failed to import Dynamo KVBM. vLLM integration will not be available.")
    # Every public name is still defined on failure, bound to None.
    KvbmCacheManager = KvbmRequest = KvbmBlockList = None
    BlockState = BlockStates = SlotUpdate = None
    BlockManager = None
    KvConnectorWorker = KvConnectorLeader = SchedulerOutput = None

# Public surface of this loader module.
__all__ = [
    "KvbmCacheManager",
    "KvbmRequest",
    "KvbmBlockList",
    "BlockState",
    "BlockStates",
    "SlotUpdate",
    "BlockManager",
    "KvConnectorWorker",
    "KvConnectorLeader",
    "SchedulerOutput",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
import torch
# BlockManager is optional: when it cannot be imported (the skip reason below
# attributes this to the block-manager feature not being enabled), skip the
# whole module at collection time instead of failing.
try:
    from dynamo.llm import BlockManager
except ImportError:
    # Explicit module-level skip. The previous approach relied on
    # pytest.importorskip() failing to import a made-up module name, which
    # would not skip if a module with that name ever became importable.
    pytest.skip("block-manager feature is not enabled", allow_module_level=True)
# Module-level pytest marker: applies the `pre_merge` mark to every test in this file.
pytestmark = pytest.mark.pre_merge
# Arguments shared by every BlockManager constructed in this module.
WORKER_ID = 0

# The tests below expect each exported block tensor to have shape
# (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM).
NUM_LAYER = 5
OUTER_DIM = 2
PAGE_SIZE = 4
INNER_DIM = 13

# Dtype name handed to BlockManager, and the torch dtype the tests expect back.
DTYPE = "FP32"
TORCH_DTYPE = torch.float32

HOST_NUM_BLOCKS = 16
DEVICE_NUM_BLOCKS = 16
DEVICE_ID = 0
def new_block_manager():
    """Build a BlockManager from the module-level test constants.

    All nine arguments are passed positionally, in the order listed below.
    """
    positional_args = (
        WORKER_ID,
        NUM_LAYER,
        OUTER_DIM,
        PAGE_SIZE,
        INNER_DIM,
        DTYPE,
        HOST_NUM_BLOCKS,
        DEVICE_NUM_BLOCKS,
        DEVICE_ID,
    )
    return BlockManager(*positional_args)
@pytest.fixture
def block_manager():
    """Pytest fixture that hands each requesting test a newly built BlockManager."""
    manager = new_block_manager()
    return manager
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_manager_initialization():
    """Construct BlockManager with each combination of optional arguments used here.

    The test passes as long as none of the constructor calls raises.
    """
    required = (WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)

    # Each entry is (extra positional args, keyword args) appended to `required`.
    variants = (
        ((), {}),
        ((DTYPE,), {}),
        ((DTYPE, HOST_NUM_BLOCKS), {}),
        ((DTYPE,), {"device_num_blocks": DEVICE_NUM_BLOCKS}),
        ((DTYPE, HOST_NUM_BLOCKS, DEVICE_NUM_BLOCKS), {}),
        (
            (DTYPE,),
            {"device_num_blocks": DEVICE_NUM_BLOCKS, "device_id": DEVICE_ID},
        ),
        ((DTYPE, HOST_NUM_BLOCKS, DEVICE_NUM_BLOCKS, DEVICE_ID), {}),
    )

    for extra_args, extra_kwargs in variants:
        # The instance is never bound to a name, so Python may drop it right
        # away; actual collection timing is up to the garbage collector.
        BlockManager(*required, *extra_args, **extra_kwargs)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_cpu_block_access(block_manager: BlockManager):
    """Host blocks export as CPU tensors, and writes are visible through a re-export."""
    num_blocks = 2
    block_list = block_manager.allocate_host_blocks_blocking(num_blocks)
    first_export = block_list.to_list()
    assert len(first_export) == num_blocks

    expected_shape = (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)
    first_views = [torch.from_dlpack(blk) for blk in first_export]
    for view in first_views:
        assert view.get_device() == -1  # CPU
        assert view.shape == expected_shape
        assert view.dtype == TORCH_DTYPE

    # Write to the first and the last element of every block.
    last_element = (0, NUM_LAYER - 1, OUTER_DIM - 1, PAGE_SIZE - 1, INNER_DIM - 1)
    for view in first_views:
        view[0, 0, 0, 0, 0] = 1.0
        view[last_element] = 1.0

    # A second export yields new Python objects with the same contents.
    second_export = block_list.to_list()
    assert first_export is not second_export
    assert len(first_export) == len(second_export)
    second_views = [torch.from_dlpack(blk) for blk in second_export]
    for before, after in zip(first_views, second_views):
        assert before is not after
        assert before.shape == after.shape
        assert before.dtype == after.dtype
        assert torch.allclose(before, after)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_gpu_block_access(block_manager: BlockManager):
    """Device blocks export as tensors on DEVICE_ID, and writes survive a re-export."""
    num_blocks = 6
    block_list = block_manager.allocate_device_blocks_blocking(num_blocks)
    first_export = block_list.to_list()
    assert len(first_export) == num_blocks

    expected_shape = (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)
    first_views = [torch.from_dlpack(blk) for blk in first_export]
    for view in first_views:
        assert view.get_device() == DEVICE_ID  # GPU
        assert view.shape == expected_shape
        assert view.dtype == TORCH_DTYPE

    # Write to the first and the last element of every block.
    last_element = (0, NUM_LAYER - 1, OUTER_DIM - 1, PAGE_SIZE - 1, INNER_DIM - 1)
    for view in first_views:
        view[0, 0, 0, 0, 0] = 1.0
        view[last_element] = 1.0

    # A second export yields new Python objects with the same contents.
    second_export = block_list.to_list()
    assert first_export is not second_export
    assert len(first_export) == len(second_export)
    second_views = [torch.from_dlpack(blk) for blk in second_export]
    for before, after in zip(first_views, second_views):
        assert before is not after
        assert before.shape == after.shape
        assert before.dtype == after.dtype
        assert torch.allclose(before, after)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_list_iteration(block_manager: BlockManager):
    """Exercise len(), indexing and repeated iteration on a host block list."""
    num_blocks = 4
    block_list = await block_manager.allocate_host_blocks(num_blocks)

    # __len__()
    assert len(block_list) == num_blocks

    # __getitem__(): tag block i with the value 1.0 + i.
    for position in range(num_blocks):
        blk = block_list[position]
        view = torch.from_dlpack(blk)
        view[0, 0, 0, 0, 0] = 1.0 + position

    # __iter__() / __next__(): blocks must come back in index order.
    visited = 0
    for blk in block_list:
        view = torch.from_dlpack(blk)
        assert view[0, 0, 0, 0, 0] == 1.0 + visited
        view[0, 0, 0, 0, 0] += 0.5
        visited += 1
    assert visited == num_blocks

    # A second pass must start from the first block again and see the +0.5.
    visited = 0
    for blk in block_list:
        view = torch.from_dlpack(blk)
        assert view[0, 0, 0, 0, 0] == 1.0 + visited + 0.5
        visited += 1
    assert visited == num_blocks
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_copy_g1_g2(block_manager: BlockManager):
    """Round-trip one host (G2) block through one device (G1) block via permuted views."""

    def expected_value(layer, outer, page, inner):
        # Row-major offset of the element inside a block, so every element is unique.
        return ((layer * OUTER_DIM + outer) * PAGE_SIZE + page) * INNER_DIM + inner

    def element_indices():
        # Every (layer, outer, page, inner) index of a block, in row-major order.
        for layer in range(NUM_LAYER):
            for outer in range(OUTER_DIM):
                for page in range(PAGE_SIZE):
                    for inner in range(INNER_DIM):
                        yield layer, outer, page, inner

    # Allocate device (G1) and host (G2) block
    host_block_list = await block_manager.allocate_host_blocks(1)
    device_block_list = await block_manager.allocate_device_blocks(1)

    # Fill the host block with unique values.
    host_tensor = torch.from_dlpack(host_block_list[0])
    for i, j, k, w in element_indices():
        host_tensor[0, i, j, k, w] = expected_value(i, j, k, w)

    # Host -> device copy, performed through permuted views on both sides.
    permute_dims = (0, 2, 4, 3, 1)
    device_tensor_ = torch.from_dlpack(device_block_list[0]).permute(*permute_dims)
    device_tensor_.copy_(host_tensor.permute(*permute_dims))

    # A fresh, unpermuted export of the device block must show the copied values.
    device_tensor = torch.from_dlpack(device_block_list[0])
    for i, j, k, w in element_indices():
        assert device_tensor[0, i, j, k, w] == expected_value(i, j, k, w)

    # Zero the host block through a new permuted view; the first view must see it.
    host_tensor_ = torch.from_dlpack(host_block_list[0]).permute(*permute_dims)
    host_tensor_.zero_()
    assert torch.all(host_tensor == 0)

    # Device -> host copy restores the original values in the host block.
    host_tensor_.copy_(device_tensor_)
    for i, j, k, w in element_indices():
        assert host_tensor[0, i, j, k, w] == expected_value(i, j, k, w)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_cpu_layer_access(block_manager: BlockManager):
    """A host block exports one CPU tensor per layer; writes show up in a re-export."""
    block_list = block_manager.allocate_host_blocks_blocking(1)
    block = block_list[0]
    first_export = block.to_list()
    assert len(first_export) == NUM_LAYER

    layer_shape = (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM)
    first_views = [torch.from_dlpack(layer) for layer in first_export]
    for view in first_views:
        assert view.get_device() == -1  # CPU
        assert view.shape == layer_shape
        assert view.dtype == TORCH_DTYPE

    # Write to the first and the last element of every layer.
    last_element = (0, 0, OUTER_DIM - 1, PAGE_SIZE - 1, INNER_DIM - 1)
    for view in first_views:
        view[0, 0, 0, 0, 0] = 1.0
        view[last_element] = 1.0

    # A second export yields new Python objects with the same contents.
    second_export = block.to_list()
    assert first_export is not second_export
    assert len(first_export) == len(second_export)
    second_views = [torch.from_dlpack(layer) for layer in second_export]
    for before, after in zip(first_views, second_views):
        assert before is not after
        assert before.shape == after.shape
        assert before.dtype == after.dtype
        assert torch.allclose(before, after)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_gpu_layer_access(block_manager: BlockManager):
    """A device block exports one tensor per layer on DEVICE_ID; writes survive a re-export."""
    block_list = block_manager.allocate_device_blocks_blocking(1)
    block = block_list[0]
    first_export = block.to_list()
    assert len(first_export) == NUM_LAYER

    layer_shape = (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM)
    first_views = [torch.from_dlpack(layer) for layer in first_export]
    for view in first_views:
        assert view.get_device() == DEVICE_ID  # GPU
        assert view.shape == layer_shape
        assert view.dtype == TORCH_DTYPE

    # Write to the first and the last element of every layer.
    last_element = (0, 0, OUTER_DIM - 1, PAGE_SIZE - 1, INNER_DIM - 1)
    for view in first_views:
        view[0, 0, 0, 0, 0] = 1.0
        view[last_element] = 1.0

    # A second export yields new Python objects with the same contents.
    second_export = block.to_list()
    assert first_export is not second_export
    assert len(first_export) == len(second_export)
    second_views = [torch.from_dlpack(layer) for layer in second_export]
    for before, after in zip(first_views, second_views):
        assert before is not after
        assert before.shape == after.shape
        assert before.dtype == after.dtype
        assert torch.allclose(before, after)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_iteration(block_manager: BlockManager):
    """Exercise len(), indexing and repeated iteration over the layers of a host block."""
    block = (await block_manager.allocate_host_blocks(1))[0]

    # __len__()
    assert len(block) == NUM_LAYER

    # __getitem__(): tag layer i with the value 1.0 + i.
    for position in range(NUM_LAYER):
        layer = block[position]
        view = torch.from_dlpack(layer)
        view[0, 0, 0, 0, 0] = 1.0 + position

    # __iter__() / __next__(): layers must come back in index order.
    visited = 0
    for layer in block:
        view = torch.from_dlpack(layer)
        assert view[0, 0, 0, 0, 0] == 1.0 + visited
        view[0, 0, 0, 0, 0] += 0.5
        visited += 1
    assert visited == NUM_LAYER

    # A second pass must start from the first layer again and see the +0.5.
    visited = 0
    for layer in block:
        view = torch.from_dlpack(layer)
        assert view[0, 0, 0, 0, 0] == 1.0 + visited + 0.5
        visited += 1
    assert visited == NUM_LAYER
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_layer_copy_g1_g2(block_manager: BlockManager):
    """Host/device round trip where data is written and checked per layer."""

    def expected_value(layer, outer, page, inner):
        # Row-major offset of the element inside a block, so every element is unique.
        return ((layer * OUTER_DIM + outer) * PAGE_SIZE + page) * INNER_DIM + inner

    def within_layer_indices():
        # Every (outer, page, inner) index of a single layer, in row-major order.
        for outer in range(OUTER_DIM):
            for page in range(PAGE_SIZE):
                for inner in range(INNER_DIM):
                    yield outer, page, inner

    # Allocate device (G1) and host (G2) block
    host_block = (await block_manager.allocate_host_blocks(1))[0]
    device_block = (await block_manager.allocate_device_blocks(1))[0]

    # Fill the host block with unique values, one layer tensor at a time.
    host_layer_tensors = [torch.from_dlpack(layer) for layer in host_block]
    for i in range(NUM_LAYER):
        host_layer_view = host_layer_tensors[i]
        for j, k, w in within_layer_indices():
            host_layer_view[0, 0, j, k, w] = expected_value(i, j, k, w)

    # Host -> device copy at block level, through permuted views on both sides.
    permute_dims = (0, 2, 4, 3, 1)
    host_block_tensor_ = torch.from_dlpack(host_block).permute(*permute_dims)
    device_block_tensor_ = torch.from_dlpack(device_block).permute(*permute_dims)
    device_block_tensor_.copy_(host_block_tensor_)

    # Fresh per-layer exports of the device block must show the copied values.
    device_layer_tensors = [torch.from_dlpack(layer) for layer in device_block]
    for i in range(NUM_LAYER):
        device_layer_view = device_layer_tensors[i]
        for j, k, w in within_layer_indices():
            assert device_layer_view[0, 0, j, k, w] == expected_value(i, j, k, w)

    # Zero the host block through a new export; the permuted view must see it.
    host_block_tensor = torch.from_dlpack(host_block)
    host_block_tensor.zero_()
    assert torch.all(host_block_tensor_ == 0)

    # Device -> host copy restores the original values in the host block.
    host_block_tensor_.copy_(device_block_tensor_)
    for i in range(NUM_LAYER):
        for j, k, w in within_layer_indices():
            assert host_block_tensor[0, i, j, k, w] == expected_value(i, j, k, w)
async def main():
    """Run every test of this module in sequence, without going through pytest."""
    await test_block_manager_initialization()

    # The remaining tests each take a BlockManager; build a new one per test.
    tests_needing_manager = (
        test_cpu_block_access,
        test_gpu_block_access,
        test_block_list_iteration,
        test_block_copy_g1_g2,
        test_cpu_layer_access,
        test_gpu_layer_access,
        test_block_iteration,
        test_block_layer_copy_g1_g2,
    )
    for run_test in tests_needing_manager:
        await run_test(new_block_manager())


if __name__ == "__main__":
    asyncio.run(main())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import torch
# vLLM is optional: record whether it could be imported so tests can be skipped.
try:
    from vllm.multimodal.inputs import MultiModalKwargs
    from vllm.sampling_params import SamplingParams
    from vllm.v1.core.kv_cache_manager import Request
    from vllm.v1.kv_cache_interface import (
        FullAttentionSpec,
        KVCacheConfig,
        KVCacheGroupSpec,
    )
except ImportError:
    VLLM_NOT_AVAILABLE = True
else:
    VLLM_NOT_AVAILABLE = False

# Same pattern for the Dynamo KVBM bindings.
try:
    from dynamo.llm import BlockManager
    from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager
except ImportError:
    KVBM_NOT_AVAILABLE = True
else:
    KVBM_NOT_AVAILABLE = False
def new_kv_cache_manager(num_blocks: int = 11, page_size: int = 16):
    """
    Build a KvbmCacheManager on top of a newly created BlockManager.

    Args:
        num_blocks: Forwarded to BlockManager as ``device_num_blocks``.
        page_size: Forwarded to BlockManager as ``page_size``.

    Returns:
        KvbmCacheManager: The KVBM cache manager.
    """
    block_manager = BlockManager(
        worker_id=0,
        leader=None,
        page_size=page_size,
        device_num_blocks=num_blocks,
    )
    return KvbmCacheManager(block_manager)
def make_request(
    request_id,
    prompt_token_ids,
    mm_positions=None,
    mm_hashes=None,
    prompt_logprobs: Optional[int] = None,
    cache_salt: Optional[str] = None,
):
    """Build a vLLM ``Request`` with fixed sampling settings for these tests.

    Args:
        request_id: Identifier passed through to the request.
        prompt_token_ids: Prompt tokens of the request.
        mm_positions: Optional multi-modal placeholders. When given, one empty
            ``MultiModalKwargs`` entry is supplied per position.
        mm_hashes: Optional multi-modal hashes, passed through unchanged.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cache_salt: Forwarded to the request.

    Returns:
        The constructed ``Request`` (max_tokens=17, eos_token_id=100,
        arrival_time=0, no LoRA request).
    """
    multi_modal_inputs = (
        None if mm_positions is None else [MultiModalKwargs({})] * len(mm_positions)
    )
    sampling_params = SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs)
    return Request(
        request_id=request_id,
        prompt_token_ids=prompt_token_ids,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_hashes=mm_hashes,
        multi_modal_placeholders=mm_positions,
        sampling_params=sampling_params,
        eos_token_id=100,
        arrival_time=0,
        lora_request=None,
        cache_salt=cache_salt,
    )
def make_kv_cache_config(block_size: int, num_blocks: int) -> "KVCacheConfig":
    """Build a KVCacheConfig with a single full-attention group named "layer".

    The return annotation is a string on purpose: this module tolerates vLLM
    being absent (tests are then skipped), and a bare ``KVCacheConfig``
    annotation would be evaluated when the ``def`` statement runs and raise
    NameError at import time in that case.

    Args:
        block_size: Block size passed to ``FullAttentionSpec``.
        num_blocks: Number of blocks recorded in the config.

    Returns:
        KVCacheConfig: Config with no tensors and one KV cache group.
    """
    attention_spec = FullAttentionSpec(block_size, 1, 1, torch.float32, False)
    return KVCacheConfig(
        num_blocks=num_blocks,
        tensors={},
        kv_cache_groups=[KVCacheGroupSpec(["layer"], attention_spec)],
    )
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_prefill():
    """
    Tests the KvbmCacheManager's prefill functionality.

    Flow: allocate a 55-token request on an empty cache, advance it by one
    output token, check that a second request sharing the 48-token prefix is
    reported 3 computed blocks, then allocate an 11-block request and check
    the exact block ids it receives.
    """
    manager = new_kv_cache_manager()
    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    # Step 1: Initial lookup - nothing is cached yet, so no computed blocks
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    # Step 2: Allocate slots for all 55 prompt tokens of the request
    blocks_req0 = manager.allocate_slots(
        req0, 55, len(computed_blocks.blocks) * 16, computed_blocks
    )
    # Freshly allocated blocks carry no hash yet
    for block in blocks_req0.blocks:
        assert block._block_hash is None
    # Verify allocation was successful
    block_ids = manager.get_block_ids(req0.request_id)
    assert len(block_ids) == 1  # One sequence in the request
    assert len(block_ids[0]) == 4  # 4 blocks allocated (3 complete + 1 partial)
    # Step 3: Simulate model execution by updating the request's computed tokens
    req0.append_output_token_ids(100)
    req0.num_computed_tokens = 55
    _ = manager.allocate_slots(req0, num_new_tokens=1)
    # Step 4: Create a new request with the same 48-token prefix followed by a
    # different partial block (4 tokens)
    unique_token_ids = [3] * 4
    req1 = make_request("1", common_token_ids + unique_token_ids)
    # Step 5: Check for computed blocks - should find the common prefix
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks) == 3
    assert num_computed_tokens == len(computed_blocks.blocks) * 16
    for block in computed_blocks.blocks:
        assert block._block_hash is not None
    # Clean up
    del computed_blocks
    manager.free_block_hashes(req0)
    manager.free_block_hashes(req1)
    # Cache miss and eviction: 11 blocks' worth of tokens not seen before.
    req3 = make_request("3", [24] * (16 * 11))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks_req3 = manager.allocate_slots(
        req3, 16 * 11, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert len(blocks_req3.blocks) == 11
    # Expected hand-out order: ids 4-10 first, then 3, 2, 1, 0; none of the
    # returned blocks may still carry a hash.
    for block, expected_block_id in zip(
        blocks_req3.blocks, [4, 5, 6, 7, 8, 9, 10, 3, 2, 1, 0]
    ):
        assert block._block_hash is None
        assert block.block_id == expected_block_id
@pytest.mark.skip(reason="KVBM needs to support reset_prefix_cache")
def test_prefill_plp():
    """Test prefill with APC and some prompt logprobs (plp) requests.

    1. Schedule plp request and validate APC block allocation
    2. Schedule non-plp request and validate blocks
    3. Schedule plp request; no hit should occur; validate blocks

    This test is unconditionally skipped (see the decorator). The bare string
    literals in the body are no-op statements holding disabled assertions.
    """
    manager = new_kv_cache_manager()
    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]
    # Request #0 is a prompt logprobs request
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    # assert len(manager.req_to_block_hashes[req0.request_id]) == 0
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, 55, len(computed_blocks.blocks) * 16, computed_blocks
    )
    # assert blocks.get_block_ids() == [[1, 2, 3, 4]]
    assert blocks.get_block_ids() == [[0, 1, 2, 3]]
    # Remember request #0's block hashes for the comparison with request #2.
    req0_block_hashes = [b.block_hash for b in blocks.blocks]
    # Simulate model execution by updating the request's computed tokens
    req0.append_output_token_ids(100)
    req0.num_computed_tokens = 55
    _ = manager.allocate_slots(req0, num_new_tokens=1)
    # Check full block metadata
    # (disabled: kept as a string literal rather than executed)
    """
parent_block_hash = None
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
"""
    # Request #1 is a non-prompt-logprobs request:
    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    # assert computed_blocks.get_block_ids() == [[1, 2, 3]]
    assert computed_blocks.get_block_ids() == [[0, 1, 2]]
    assert num_computed_tokens == 3 * 16
    # Only the 5 tokens after the 48-token cached prefix need new slots.
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(
        req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks
    )
    # assert blocks.get_block_ids() == [[5]]
    assert blocks.get_block_ids() == [[4]]
    # for block in computed_blocks.blocks:
    # assert block.ref_cnt == 2
    # At this point, we should have 5 free blocks left.
    # assert manager.block_pool.free_block_queue.num_free_blocks == 5
    manager.free(req0)
    manager.free(req1)
    # (disabled: free-queue checks kept as a string literal rather than executed)
    """
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (4)]
# [unique_req1 (5)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
"""
    # Request #2 is a prompt-logprobs request:
    # NO cache hit in the common prefix; duplicates request #0 cached blocks
    unique_token_ids = [3] * 6
    req2 = make_request("2", common_token_ids + unique_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    # assert len(manager.req_to_block_hashes[req2.request_id]) == 0
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req2, 55, len(computed_blocks.blocks) * 16, computed_blocks
    )
    block_ids = blocks.get_block_ids()
    # Duplicate cached blocks have different ids but same hashes vs request #0
    assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
    assert block_ids != [[1, 2, 3, 4]]
    # Request #2 block hashes are valid since request #0 hashes are.
    # Check block reference counts.
    # NOTE(review): this reads manager.block_pool, which other tests in this
    # file describe as not reachable from the Python bindings - confirm before
    # un-skipping.
    for block_id in block_ids[0]:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
    manager.free(req2)
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_decode():
    """
    Tests slot allocation while a request grows by appended output tokens.

    Checks that appends which fit in the current partial block return no new
    blocks, that crossing a block boundary returns exactly one new unhashed
    block, and that block id 3 is later reported as a hashed computed block.
    """
    manager = new_kv_cache_manager()
    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    req0 = make_request("0", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, 55, len(computed_blocks.blocks) * 16, computed_blocks
    )
    # assert blocks.get_block_ids() == [[1, 2, 3, 4]]
    assert blocks.get_block_ids() == [[0, 1, 2, 3]]
    # Append slots without allocating a new block.
    # NOTE: computed_blocks is still the empty result of the first lookup, so
    # the third argument of the allocate_slots calls below is 0 until the
    # lookup is repeated near the end of the test.
    req0.num_computed_tokens = 55
    for _ in range(4):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(
        req0, 4, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert new_blocks is not None and len(new_blocks.blocks) == 0
    # NOTE(): There's no way to access the current active non-registered block
    # from the python bindings.
    # assert manager.single_type_manager.req_to_blocks[
    # req0.request_id][-1].block_hash is None
    # Append slots with allocating a new block.
    req0.num_computed_tokens = 59
    # 19 more tokens take the request from 59 to 78 tokens. With the default
    # 16-token page size, 5 of them fill the current block (59 -> 64) and the
    # remaining 14 need one additional block.
    for _ in range(9 + 10):
        req0.append_output_token_ids(7)
    print(len(computed_blocks.blocks))
    new_blocks = manager.allocate_slots(
        req0, 19, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert new_blocks is not None and len(new_blocks.blocks) == 1
    assert new_blocks.blocks[-1].block_hash is None
    req0.num_computed_tokens = 78
    req0.append_output_token_ids(100)
    # The following is required for KVBM to register the block with id=3
    _ = manager.allocate_slots(
        req0, 1, len(computed_blocks.blocks) * 16, computed_blocks
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    # assert manager.single_type_manager.req_to_blocks[
    # req0.request_id][-2].block_hash is not None
    # assert manager.single_type_manager.req_to_blocks[
    # req0.request_id][-1].block_hash is None
    # The last computed block must be the now-registered block 3.
    assert computed_blocks.blocks[-1].block_id == 3
    assert computed_blocks.blocks[-1].block_hash is not None
    # Clean up
    manager.free_block_hashes(req0)
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_evict():
    """
    Tests block reuse after two requests have been freed.

    Two requests occupy 6 + 3 distinct blocks and are then freed. A third
    request that repeats the start of the first one must get blocks 0 and 1
    back as computed blocks and be allocated block 9 for its remaining tokens.
    """
    manager = new_kv_cache_manager()
    # Ids of every block handed out to req0 and req1.
    used_blocks = set()
    last_token_id = 5 * 16 + 7
    # req0: tokens 0..86 (87 tokens)
    req0 = make_request("0", list(range(last_token_id)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, 5 * 16 + 7, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert len(blocks.blocks) == 6  # 5 full + 1 partial
    used_blocks.update(blocks.get_block_ids()[0])
    # Simulate one decode step for req0.
    req0.append_output_token_ids(100)
    req0.num_computed_tokens = 5 * 16 + 7
    manager.allocate_slots(req0, 1, len(computed_blocks.blocks) * 16, computed_blocks)
    # req1: 47 tokens that continue where req0's prompt ended.
    req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16 - 1)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req1, 3 * 16, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert (
        len(blocks.blocks) == 3
    )  # 2 full blocks and 1 partial (15 tokens) 1 more will be added during allocate_slots
    last_token_id += 3 * 16 - 1
    used_blocks.update(blocks.get_block_ids()[0])
    # 10 - (6 + 3) == 1
    assert len(used_blocks) == 6 + 3
    # Simulate one decode step for req1.
    req1.append_output_token_ids(100)
    req1.num_computed_tokens = 3 * 16 - 1
    blocks = manager.allocate_slots(
        req1, 1, len(computed_blocks.blocks) * 16, computed_blocks
    )
    manager.free(req0)
    manager.free(req1)
    # Can't access the free blocks queue from the python bindings.
    # assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # assert [
    # b.block_id
    # for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    # ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
    # Touch the first 2 blocks: req2 repeats the first 35 tokens of req0.
    req2 = make_request("2", list(range(2 * 16 + 3)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    # assert computed_blocks.get_block_ids() == [[1, 2]]
    assert computed_blocks.get_block_ids() == [[0, 1]]
    assert num_computed_tokens == 2 * 16
    # The 3 tokens beyond the cached prefix need one more block.
    blocks = manager.allocate_slots(
        req2, 3, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert blocks.get_block_ids() == [[9]]
    # Can't access the free blocks queue from the python bindings.
    # assert manager.block_pool.free_block_queue.num_free_blocks == 7
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_hash_block_correct_reuse():
    """
    This tests when a previously cached block is reused as a new block,
    its hash metadata should be correctly reset.
    """
    block_size = 16
    # Only two blocks exist, so the second request must reuse block 0.
    manager = new_kv_cache_manager(num_blocks=2)
    # Allocate 1 block and cache it.
    num_tokens = block_size
    req = make_request("0", list(range(num_tokens)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req, num_tokens, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert len(blocks.blocks) == 1
    # Simulate decoding 5 output tokens after the full prompt block.
    for t in range(5):
        req.append_output_token_ids(100)
    req.num_computed_tokens = num_tokens
    blocks = manager.allocate_slots(
        req, 5, len(computed_blocks.blocks) * 16, computed_blocks
    )
    # The prompt block must now be reported as computed, with a hash, as block 0.
    computed_blocks, _ = manager.get_computed_blocks(req)
    assert computed_blocks.blocks[0].block_hash is not None
    assert computed_blocks.blocks[0].block_id == 0
    # Deallocate the block.
    del computed_blocks
    manager.free(req)
    # Allocate new blocks, last one is partial not full, make sure hash info on the
    # blocks are cleared.
    # KVBM will allocate block 1 first, then block 0. Need to verify,
    # that block's 0 hash is cleared
    # Token ids start at 256 so they do not overlap with the first request.
    req = make_request("1", list(range(256, 256 + 2 * num_tokens - 1)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req, 2 * num_tokens - 1, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert len(blocks.blocks) == 2
    assert blocks.blocks[1].block_id == 0
    assert blocks.blocks[1].block_hash is None
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_computed_blocks_not_evicted():
    """
    Test that the computed blocks are not evicted when getting new blocks
    for a request if there are any other free blocks.
    """
    block_size = 16
    manager = new_kv_cache_manager(num_blocks=3)
    # Allocate a block and cache it.
    num_tokens = block_size * 1
    # req0: tokens 0..15, exactly one full block
    req0 = make_request("0", list(range(num_tokens)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, num_tokens, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert len(blocks.blocks) == 1
    # assert blocks.blocks[0].block_id == 1
    assert blocks.blocks[0].block_id == 0
    # Allocate another block.
    # req1: tokens 16..31, exactly one full block
    req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req1, num_tokens, len(computed_blocks.blocks) * 16, computed_blocks
    )
    assert len(blocks.blocks) == 1
    # assert blocks.blocks[0].block_id == 2
    assert blocks.blocks[0].block_id == 1
    # Need to simulate the forward pass to get blocks registered
    # NOTE: computed_blocks is the empty result of the req1 lookup, so the
    # third argument of both allocate_slots calls below is 0.
    req0.append_output_token_ids(100)
    req0.num_computed_tokens = num_tokens
    _ = manager.allocate_slots(
        req0, 1, len(computed_blocks.blocks) * 16, computed_blocks
    )
    req1.append_output_token_ids(100)
    req1.num_computed_tokens = num_tokens
    _ = manager.allocate_slots(
        req1, 1, len(computed_blocks.blocks) * 16, computed_blocks
    )
    # Free the blocks.
    manager.free(req0)
    manager.free(req1)
    del computed_blocks
    # Now if we have a cache hit on the block_id 0, we should evict the block_id 1
    # cached block rather than the first one.
    # req2: tokens 0..47; its first block repeats req0's prompt.
    req2 = make_request("2", list(range(num_tokens * 3)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks.blocks) == 1
    # assert computed_blocks.blocks[0].block_id == 1
    assert computed_blocks.blocks[0].block_id == 0
    assert num_computed_tokens == block_size
    # Allocate should return a free block with id 2 first, and then block with id 1
    # which was evicted.
    blocks = manager.allocate_slots(
        req2,
        num_tokens * 3 - num_computed_tokens,
        len(computed_blocks.blocks) * 16,
        computed_blocks,
    )
    assert len(blocks.blocks) == 2
    assert blocks.blocks[0].block_id == 2
    assert blocks.blocks[1].block_id == 1
def _test_basic_prefix_caching_disabled():
    """
    Intentionally empty placeholder; the leading underscore keeps pytest from
    collecting it.

    Currently, KVBM does not support `enable_caching` or setting it to False
    to disable prefix caching.
    """
    return None
# @pytest.mark.parametrize("hash_fn", [sha256, hash])
def _test_cache_blocks(hash_fn):
    """
    Intentionally empty placeholder; `hash_fn` is accepted but unused.

    Hashing is done by KVBM and tested by the core library.
    """
    return None
def _test_mm_prefix_caching():
    """
    Intentionally empty placeholder for a multi-modal prefix caching test.

    KVBM currently does not support multi-modal prefix caching, so there is
    nothing to check yet.
    """
    return None
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_cache_key_salting():
    """
    Verify that cache salts are applied during hashing, so that cached blocks
    are only shared between requests that use the same salt.

    The test is mostly the same as the one for vLLM's native KV cache manager.
    The only difference is that KVBM does not need a `BlockHashType` object on
    the Python side, so the value of the salt is not checked directly. Salting
    is validated through its effect: a cache hit with the same salt and a
    cache miss with a different salt.
    """
    block_size = 16
    manager = new_kv_cache_manager()

    # 3 complete blocks and an incomplete block with 11 tokens.
    common_token_ids = [i for i in range(3) for _ in range(block_size)]
    token_ids = common_token_ids + [3] * 11
    req0 = make_request("0", token_ids, cache_salt="salt1")
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

    # Nothing has been cached yet, so there is no prefix hit.
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0

    # Disabled check kept for reference (it relies on `req_to_block_hashes`,
    # which this test does not use with KVBM):
    #   Completed block should have hashes with extra keys.
    #   block_hashes = manager.req_to_block_hashes[req0.request_id]
    #   assert len(block_hashes) == 3
    #   assert block_hashes[0].extra_keys == ("salt1", )
    #   assert block_hashes[1].extra_keys is None
    #   assert block_hashes[2].extra_keys is None

    # 59 prompt tokens = 3 full blocks + 11 tokens -> 4 blocks.
    blocks = manager.allocate_slots(
        req0, 59, len(computed_blocks.blocks) * block_size, computed_blocks
    )
    assert blocks.get_block_ids() == [[0, 1, 2, 3]]  # [[1, 2, 3, 4]]
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block: 59 + 5 tokens still fit in
    # the 4 blocks allocated above.
    for _ in range(5):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(
        req0, 5, len(computed_blocks.blocks) * block_size, computed_blocks
    )
    assert new_blocks is not None and len(new_blocks.blocks) == 0

    # Disabled check kept for reference:
    #   Now one more block that should not have extra keys.
    #   assert len(block_hashes) == 4
    #   assert block_hashes[3].extra_keys is None

    # Test cache hit with a new request that has the same salt.
    token_ids = common_token_ids + [4] * 11
    req1 = make_request("1", token_ids, cache_salt="salt1")
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # Should match only a prefix of 3 blocks.
    assert len(computed_blocks.blocks) == 3
    assert num_computed_tokens == 3 * block_size

    # Test cache miss with same content but different salt.
    token_ids = common_token_ids + [4] * 11
    req2 = make_request("2", token_ids, cache_salt="salt2")
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks.blocks) == 0
    assert num_computed_tokens == 0

    # Disabled check kept for reference:
    #   block_hashes = manager.req_to_block_hashes[req2.request_id]
    #   assert len(block_hashes) == 3
    #   assert block_hashes[0].extra_keys == ("salt2", )
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
def test_prefill_not_enough_free_blocks_with_computed_blocks():
    """
    Unit test for `allocate_slots` when there are not enough free blocks.

    Specifically, when a request has computed blocks but cannot be allocated
    because too few free blocks remain, `allocate_slots` must return None and
    the computed blocks should not be touched.
    """
    block_size = 16
    manager = new_kv_cache_manager()

    # Complete 3 blocks (48 tokens)
    # | Common-0 | Common-1 | Common-2 | ... |
    common_token_ids = [i for i in range(3) for _ in range(block_size)]
    req0 = make_request("0", common_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    manager.allocate_slots(
        req0, 48, len(computed_blocks.blocks) * block_size, computed_blocks
    )
    # Number of blocks now held by req0 (vLLM-native equivalent:
    # manager.single_type_manager.req_to_blocks[req0.request_id]).
    block_part0 = len(manager.get_block_ids(req0.request_id)[0])

    # Simulate model execution by updating the request's computed tokens
    req0.append_output_token_ids(100)
    req0.num_computed_tokens = 48
    _ = manager.allocate_slots(req0, num_new_tokens=1)

    # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
    req1 = make_request("1", common_token_ids * 2)  # Double the common tokens
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # First 3 blocks are computed from req0
    assert len(computed_blocks.blocks) == block_part0
    assert num_computed_tokens == 3 * block_size  # 3 blocks * 16 tokens per block
    manager.allocate_slots(req1, 48, num_computed_tokens, computed_blocks)
    # Number of blocks now held by req1 (vLLM-native equivalent:
    # manager.single_type_manager.req_to_blocks[req1.request_id]).
    block_part1 = len(manager.get_block_ids(req1.request_id)[0])

    # Simulate forward pass for req1 to compute all 6 blocks
    req1.append_output_token_ids(100)
    req1.num_computed_tokens = 96
    _ = manager.allocate_slots(req1, num_new_tokens=1)

    # Free req1 to make its blocks available. The Python-side handle to the
    # computed blocks is dropped first.
    del computed_blocks
    manager.free(req1)

    # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
    # | Req1-5(F)| Req2-0   | Req2-1   | ... |
    req2 = make_request("2", [7] * block_size * 2)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert not computed_blocks.blocks
    assert num_computed_tokens == 0
    manager.allocate_slots(
        req2, block_size * 2, len(computed_blocks.blocks) * block_size, computed_blocks
    )

    # Req3 repeats the common tokens three times (9 blocks). Its first 6 blocks
    # match the blocks computed for req1; the remaining 3 blocks would have to
    # be newly allocated, but only 2 free blocks are left (see the log line
    # below), so the allocation must fail.
    # In this case, the ref_cnt of the computed blocks should not be changed.
    req3 = make_request("3", common_token_ids * 3)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert len(computed_blocks.blocks) == block_part1  # Should find 6 computed blocks
    assert num_computed_tokens == 6 * block_size  # 6 blocks * 16 tokens per block

    # Req3 cannot be allocated due to insufficient free blocks
    # DYN LOG print:
    # DEBUG dynamo_llm::block_manager::pool::state: not enough blocks available, requested: 3, available: 2
    assert (
        manager.allocate_slots(
            req3, 48, len(computed_blocks.blocks) * block_size, computed_blocks
        )
        is None
    )

    # Clean up
    manager.free_block_hashes(req0)
    manager.free_block_hashes(req2)
    manager.free_block_hashes(req3)
def _test_reset_prefix_cache():
    """
    Intentionally empty placeholder; the leading underscore keeps it out of
    pytest's default `test_*` collection.

    `reset_prefix_cache` is currently not implemented.
    It returns False every time it is called
    """
    pass
def _test_prefix_cache_stats_disabled():
    """
    Intentionally empty placeholder for the prefix-cache-stats-disabled test.

    NOTE(review): the stated reason below is word-for-word the docstring of
    `_test_reset_prefix_cache` and looks copy-pasted; confirm why this test
    is actually disabled.

    `reset_prefix_cache` is currently not implemented.
    It returns False every time it is called
    """
    pass
# @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def _test_kv_cache_events(blocks_to_cache: int):
    """
    Intentionally empty placeholder (the parametrization above is disabled, so
    `blocks_to_cache` is accepted but never used).

    KVBM's Event Manager is responsible for emitting events.
    Currently tested separately as a part of dynamo integration tests.
    """
    pass
def _test_eagle_enabled_removes_last_block():
    """NOTE: KVBM does not support spec decoding at the moment, so this is an
    intentionally empty placeholder.

    Intended check: verify Eagle does NOT remove blocks when request
    length is divisible by block size."""
    pass
def _test_eagle_with_partial_blocks():
    """NOTE: KVBM does not support spec decoding at the moment, so this is an
    intentionally empty placeholder.

    Intended check: test Eagle behavior with requests containing partial blocks."""
    pass
def _test_eagle_with_sliding_window():
    """NOTE: KVBM does not support spec decoding at the moment, so this is an
    intentionally empty placeholder.

    Intended check: test Eagle behavior with sliding window."""
    pass
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
def test_kvbm_wrong_blocks_provided():
    """
    Check that `allocate_slots` rejects computed blocks that do not belong to
    the request being allocated.

    Blocks obtained for one request are handed to a second request with
    different tokens; both attempts must raise, each with its own slot error.
    """
    manager = new_kv_cache_manager()

    # Two 3-block requests with different token patterns:
    # sequential tokens 0..47 and even tokens 0, 2, ..., 94.
    req0 = make_request("0", list(range(48)))
    req1 = make_request("1", list(range(0, 96, 2)))

    # Allocate blocks for req0; nothing has been computed for it yet.
    computed_blocks_req0, _ = manager.get_computed_blocks(req0)
    manager.allocate_slots(req0, 48, 0, computed_blocks_req0)

    # Simulate the forward pass: one output token, all 48 prompt tokens computed,
    # then a slot for the output token.
    req0.append_output_token_ids(100)
    req0.num_computed_tokens = 48
    manager.allocate_slots(req0, num_new_tokens=1)

    # First attempt: req1 claims 48 computed tokens using the block list that
    # was fetched for req0 before the forward pass - this should fail.
    insufficient_capacity_msg = "slot error: Insufficient capacity: need 48 tokens but only 0 available in mutable blocks"
    with pytest.raises(Exception) as exc_info:
        manager.allocate_slots(req1, 48, 48, computed_blocks_req0)
    assert insufficient_capacity_msg in str(exc_info.value)

    # Fetch req0's computed blocks again, now that the forward pass is done.
    computed_blocks_req0, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(computed_blocks_req0.blocks) == 3  # Should have 3 complete blocks
    assert num_computed_tokens == 48  # All input tokens should be computed

    # Second attempt: req0's real computed blocks do not match req1's tokens -
    # this should fail as well.
    hash_mismatch_msg = "slot error: computed block sequence hash mismatch"
    with pytest.raises(Exception) as exc_info:
        manager.allocate_slots(req1, 48, 48, computed_blocks_req0)
    assert hash_mismatch_msg in str(exc_info.value)

    # Clean up
    for request in (req0, req1):
        manager.free_block_hashes(request)
@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available")
@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
@patch("dynamo.llm.vllm_integration.kv_cache_manager.KvbmCacheManager")
def test_kvbm_new_matched_tokens_edge_case(MockCacheManager):
    """
    Exercise `KvbmCacheManager.get_num_new_matched_tokens` against a mocked
    manager for host-only and disk-only offloaded block matches.

    The method is called unbound with a `MagicMock` standing in for `self`.
    `MockCacheManager` is injected by `@patch` and is not used in the body.
    When every block of the sequence is matched, the expected token count is
    `SEQ_LEN - 1` rather than `SEQ_LEN` (presumably so that at least one token
    is left to compute - confirm against the implementation).
    """
    PAGE_SIZE = 4
    NUM_BLOCKS = 3
    SEQ_LEN = PAGE_SIZE * NUM_BLOCKS

    def fake_block_list(num_blocks: Optional[int]):
        # `None` models "no blocks at this tier"; otherwise build a mock whose
        # `block_count()` and `len()` both report `num_blocks`.
        if num_blocks is not None:
            block_list = MagicMock()
            block_list.block_count.return_value = num_blocks
            block_list.__len__.return_value = num_blocks
            return block_list
        return None

    def fake_manager(num_host_blocks: Optional[int], num_disk_blocks: Optional[int]):
        # Stand-in for `self`, reporting the given host/disk offloaded blocks.
        manager_mock = MagicMock()
        manager_mock.block_size = PAGE_SIZE
        manager_mock._create_slot.return_value = [0, 1, 2]
        offloaded = (
            fake_block_list(num_host_blocks),
            fake_block_list(num_disk_blocks),
        )
        manager_mock.cache_manager.get_num_offloaded_computed_blocks.return_value = (
            offloaded
        )
        return manager_mock

    def recorded_pending_entry(manager_mock, request_id):
        # Positional args of the last `pending_onboard_blocks[key] = value` call.
        setitem_args = manager_mock.pending_onboard_blocks.__setitem__.call_args[0]
        recorded_id, pending = setitem_args
        assert recorded_id == request_id
        return pending

    def check_case(
        num_host_blocks: Optional[int],
        num_disk_blocks: Optional[int],
        expected_num_external_computed_tokens: int,
    ):
        request = make_request("0", [0] * SEQ_LEN)
        manager_mock = fake_manager(num_host_blocks, num_disk_blocks)

        result = KvbmCacheManager.get_num_new_matched_tokens(manager_mock, request, 0)
        num_external_computed_tokens, async_load = result
        assert num_external_computed_tokens == expected_num_external_computed_tokens
        assert not async_load

        # The pending entry must carry the host list first and the disk list second.
        pending = recorded_pending_entry(manager_mock, request.request_id)
        host_entry, disk_entry = pending[0], pending[1]
        if num_host_blocks is None:
            assert host_entry is None
        else:
            assert len(host_entry) == num_host_blocks
        if num_disk_blocks is None:
            assert disk_entry is None
        else:
            assert len(disk_entry) == num_disk_blocks

    # Case 1: Some blocks on host, no blocks on disk
    check_case(2, None, 2 * PAGE_SIZE)
    # Case 2: No blocks on host, some blocks on disk
    check_case(None, 2, 2 * PAGE_SIZE)
    # Case 3: All blocks on host.
    check_case(3, None, SEQ_LEN - 1)
    # Case 4: All blocks on disk.
    check_case(None, 3, SEQ_LEN - 1)
...@@ -26,13 +26,13 @@ description = "Dynamo LLM Library" ...@@ -26,13 +26,13 @@ description = "Dynamo LLM Library"
[features] [features]
default = [] default = []
# todo(ops): get this working in CI as a default.
# todo: enable this as default
# default = ["block-manager", "testing-full"] # default = ["block-manager", "testing-full"]
testing-full = ["testing-cuda", "testing-nixl"] testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"] testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"] testing-nixl = ["dep:nixl-sys"]
testing-etcd = []
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"]
cuda = ["dep:cudarc"] cuda = ["dep:cudarc"]
integration = [] integration = []
...@@ -58,6 +58,7 @@ derive_builder = {workspace = true } ...@@ -58,6 +58,7 @@ derive_builder = {workspace = true }
either = { workspace = true } either = { workspace = true }
etcd-client = { workspace = true } etcd-client = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
futures-util = "0.3.31"
hf-hub = { workspace = true } hf-hub = { workspace = true }
humantime = { workspace = true } # input/batch humantime = { workspace = true } # input/batch
rand = { workspace = true } rand = { workspace = true }
...@@ -68,6 +69,7 @@ serde_json = { workspace = true } ...@@ -68,6 +69,7 @@ serde_json = { workspace = true }
strum = { workspace = true } strum = { workspace = true }
tempfile = { workspace = true } tempfile = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
tmq = "0.5.0"
tokio = { workspace = true } tokio = { workspace = true }
tokio-stream = { workspace = true } tokio-stream = { workspace = true }
tokio-util = { workspace = true } tokio-util = { workspace = true }
......
## Block States
<!-- Component Diagram - Table Sync -->
```mermaid
stateDiagram-v2
%% ─────────── State machine for mutable blocks ───────────
[*] --> Empty:::concrete %% initial pseudostate
Empty --> Partial:::concrete : initialize w/ salt hash
%% ── Partial: accepts tokens until full ──
Partial --> Partial : addTokens\n(space remains)
Partial --> ReadyForScheduling:::concrete : addTokens\n(space > 0)
%% ── Scheduling & compute phases ──
ReadyForScheduling --> Inflight:::concrete : scheduleCompute
ReadyForScheduling --> Partial : cancelSchedule
Inflight --> Partial : computeDone (not full)
Inflight --> Complete:::concrete : computeDone (full)
%% ── Finalisation ──
Complete --> Registered:::trait : register
%% ── External System Connections ──
Registered --> EventManager:::defaultConstructable : registerEvents
Registered --> OffloadManager:::defaultConstructable : offloadBlock
classDef concrete fill:#66B2B2,stroke:#2A4949,color:#1A2626
classDef trait fill:#B39DDB,stroke:#4A367A,color:#1A1426
classDef defaultConstructable fill:#E6C06E,stroke:#8B7355,color:#2B1810
```
Note: The color scheme is designed to be accessible in both light and dark modes, with:
- Teal representing concrete states in the block lifecycle (mutable blocks)
- Purple representing traits (immutable interface - Registered state)
- Muted gold representing default constructable components (external managers)
| State | Description |
|-------|-------------|
| Empty | Initial state before block initialization |
| Partial | State when block is partially filled with tokens |
| ReadyForScheduling | State when block is ready for compute scheduling |
| Inflight | State when block is being computed |
| Complete | State when block computation is complete |
| Registered | Final immutable state after block computation is finalized |
| EventManager | External system for managing block events (see separate diagram) |
| OffloadManager | External system for managing block offloading (see separate diagram) |
## OffloadManager
The OffloadManager orchestrates the movement of immutable registered blocks (Arc<MutableBlock>) between different memory hierarchies (e.g., GPU → CPU → SSD). It manages a pipeline of block transfers through three primary components:
1. **Transfer Engines**: Actively copies sequences of blocks between memory hierarchies. Optimized for transport bandwidth.
2. **On-Deck Stage**: Blocks are held in their shared immutable state (Arc<MutableBlock>), ready to be transferred next. This queue is filled first.
3. **In-Queue Stage**: A priority queue holding demoted weak references (Weak<MutableBlock>) to blocks. This queue is used if the On-Deck stage is full.
The system maintains a continuous flow: when Transfer Engines finish a set of transfers, prepared blocks are pulled from the On-Deck queue. Subsequently, In-Queue blocks are upgraded to strong references (Arc<MutableBlock>) and moved to the On-Deck queue. Weak blocks that cannot be upgraded are discarded, and new blocks are pulled from In-Queue until On-Deck is populated.
<!-- Component Diagram - Table Sync -->
```mermaid
stateDiagram-v2
direction LR
[*] --> InQueueWP:::weakRef : new block (weak ref)
InQueueWP --> OnDeckQ:::trait : upgrade weak ref
OnDeckQ --> TransferEng:::concrete : schedule transfer
TransferEng --> TransferredPS : transfer complete
TransferredPS --> [*]
%% Styling
classDef concrete fill:#66B2B2,stroke:#2A4949,color:#1A2626
classDef trait fill:#B39DDB,stroke:#4A367A,color:#1A1426
classDef defaultConstructable fill:#E6C06E,stroke:#8B7355,color:#2B1810
classDef weakRef fill:#D3D3D3,stroke:#808080,color:#333333
```
| Component | Description |
|-------------------|-----------------------------------------------------------------------------|
| InQueueWP | Priority queue of weak references (Weak<MutableBlock>) to blocks. |
| OnDeckQ | Queue of blocks in shared immutable state (Arc<MutableBlock>), ready for transfer. |
| TransferEng | Active transfer operations between memory hierarchies. |
| TransferredPS | Pseudo-state indicating blocks have been successfully transferred. |
<!-- Component Diagram - Table Sync -->
```mermaid
graph TD
subgraph "Memory Hierarchy"
direction LR
M_GPU[GPU Memory]:::concrete
M_CPU[CPU Memory]:::concrete
M_SSD[SSD Storage]:::concrete
end
subgraph "Offload Manager"
direction LR
IQ[In-Queue Weak Refs]:::weakRef
OD[On-Deck Arcs]:::trait
TE[Transfer Engines]:::concrete
end
%% Block Flow
NewBlock([New Immutable Block]) -.-> IQ
IQ -- upgrade viable --> OD
IQ -- discard unviable --> Discarded([X])
OD -- prepare batch --> TE
TE -- transfer to --> M_CPU
TE -- transfer to --> M_SSD
TE -- transfer to --> M_GPU
TE -- transfer complete --> TC([✓ Transferred])
%% Styling
classDef concrete fill:#66B2B2,stroke:#2A4949,color:#1A2626
classDef trait fill:#B39DDB,stroke:#4A367A,color:#1A1426
classDef defaultConstructable fill:#E6C06E,stroke:#8B7355,color:#2B1810
classDef weakRef fill:#D3D3D3,stroke:#808080,color:#333333
```
| Component | Description |
|----------------------------|---------------------------------------------------------------------------------|
| M_GPU | GPU Memory: Source memory hierarchy. |
| M_CPU | CPU Memory: Intermediate/Destination memory hierarchy. |
| M_SSD | SSD Storage: Destination memory hierarchy. |
| IQ (In-Queue Weak Refs) | Priority queue of weak references (Weak<MutableBlock>) to blocks awaiting offload. |
| OD (On-Deck Arcs) | Queue of shared immutable blocks (Arc<MutableBlock>) ready for transfer. |
| TE (Transfer Engines) | Manages the active copying of block data between memory locations. |
| NewBlock | Represents a new immutable block entering the offload system. |
| Discarded | Represents weak-referenced blocks that could not be upgraded and are discarded. |
| TC (Transferred) | Represents the state where a block transfer is successfully completed. |
Note: The color scheme is designed to be accessible in both light and dark modes, with:
- Teal (`concrete`): Concrete components, memory locations, and active processes.
- Purple (`trait`): Shared immutable blocks (Arc<T>).
- Muted Gold (`defaultConstructable`): Components that might be optionally constructed (not heavily used here).
- Light Gray (`weakRef`): Blocks held as weak references (Weak<T>).
...@@ -23,6 +23,8 @@ pub mod config; ...@@ -23,6 +23,8 @@ pub mod config;
mod state; mod state;
pub mod block; pub mod block;
pub mod connector;
pub mod distributed;
pub mod events; pub mod events;
pub mod layout; pub mod layout;
pub mod metrics; pub mod metrics;
...@@ -30,19 +32,20 @@ pub mod offload; ...@@ -30,19 +32,20 @@ pub mod offload;
pub mod pool; pub mod pool;
pub mod storage; pub mod storage;
// dynamo rt integration
pub mod controller;
pub use crate::common::dtype::DType; pub use crate::common::dtype::DType;
pub use block::{ pub use block::{
nixl::{ locality::{self, LocalityProvider, LogicalResources},
AsBlockDescriptorSet, BlockDescriptorList, IsImmutable, IsMutable, MutabilityKind, nixl::{BlockDescriptorList, IsImmutable, IsMutable, MutabilityKind, RemoteBlock},
RemoteBlock, BasicMetadata, BlockMetadata, Blocks, ImmutableBlock, MutableBlock,
},
transfer::{BlockTransferEngineV1, TransferRequestPut},
BasicMetadata, BlockMetadata, Blocks, ImmutableBlock,
}; };
pub use config::*; pub use config::*;
pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType}; pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType};
use offload::request::BlockResult; pub use offload::request::BlockResult;
pub use pool::BlockPool; pub use pool::{BlockPool, ManagedBlockPool};
pub use storage::{ pub use storage::{
nixl::NixlRegisterableStorage, DeviceStorage, DiskStorage, PinnedStorage, Storage, nixl::NixlRegisterableStorage, DeviceStorage, DiskStorage, PinnedStorage, Storage,
StorageAllocator, StorageAllocator,
...@@ -53,19 +56,21 @@ use anyhow::{Context, Result}; ...@@ -53,19 +56,21 @@ use anyhow::{Context, Result};
use block::nixl::{BlockMutability, NixlBlockSet, RemoteBlocks, SerializedNixlBlockSet}; use block::nixl::{BlockMutability, NixlBlockSet, RemoteBlocks, SerializedNixlBlockSet};
use derive_builder::Builder; use derive_builder::Builder;
use nixl_sys::Agent as NixlAgent; use nixl_sys::Agent as NixlAgent;
use serde::{Deserialize, Serialize};
use std::{ use std::{
collections::HashMap, collections::HashMap,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
use storage::nixl::MemType; use storage::nixl::MemType;
use tokio::sync::oneshot;
use validator::Validate; use validator::Validate;
pub type WorkerID = u64; pub type WorkerID = u64;
pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>; pub type ReferenceBlockManager = KvBlockManager<locality::Local, BasicMetadata>;
/// Represents the different cache levels for KV blocks /// Represents the different cache levels for KV blocks
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum CacheLevel { pub enum CacheLevel {
/// Represents KV blocks in GPU memory /// Represents KV blocks in GPU memory
G1, G1,
...@@ -80,6 +85,20 @@ pub enum CacheLevel { ...@@ -80,6 +85,20 @@ pub enum CacheLevel {
G4, G4,
} }
/// Type of channel used to reset the block manager to a specific cache level
pub type BlockResetChannel = tokio::sync::broadcast::Receiver<CacheLevel>;
/// Guard that cancels its [`CancellationToken`] when it is dropped.
///
/// `KvBlockManager` stores this behind an `Arc`, so the guard itself is only
/// dropped once the last clone sharing that `Arc` has gone away.
#[derive(Debug)]
struct CancelOnLastDrop {
    /// Token cancelled by the `Drop` implementation.
    cancellation_token: CancellationToken,
}
impl Drop for CancelOnLastDrop {
    /// Cancels the wrapped token as the guard is destroyed.
    fn drop(&mut self) {
        self.cancellation_token.cancel();
    }
}
// When we construct the pool: // When we construct the pool:
// 1. instantiate the runtime, // 1. instantiate the runtime,
// 2. build layout::LayoutConfigs for each of the requested storage types // 2. build layout::LayoutConfigs for each of the requested storage types
...@@ -87,33 +106,90 @@ pub enum CacheLevel { ...@@ -87,33 +106,90 @@ pub enum CacheLevel {
// 4. construct a Blocks object for each layout providing a unique block_set_idx // 4. construct a Blocks object for each layout providing a unique block_set_idx
// for each layout type. // for each layout type.
// 5. initialize the pools for each set of blocks // 5. initialize the pools for each set of blocks
pub struct KvBlockManager<Metadata: BlockMetadata> { #[derive(Debug)]
state: Arc<state::KvBlockManagerState<Metadata>>, pub struct KvBlockManager<Locality: LocalityProvider, Metadata: BlockMetadata> {
cancellation_token: CancellationToken, state: Arc<state::KvBlockManagerState<Locality, Metadata>>,
_cancellation_token: Arc<CancelOnLastDrop>,
block_size: usize,
}
impl<Locality: LocalityProvider, Metadata: BlockMetadata> Clone
for KvBlockManager<Locality, Metadata>
{
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
_cancellation_token: self._cancellation_token.clone(),
block_size: self.block_size,
}
}
}
impl<Locality: LocalityProvider, Metadata: BlockMetadata> KvBlockManager<Locality, Metadata> {
/// Get the block size.
///
/// This is the value stored at construction time, which the constructors
/// take from `config.model.page_size`.
pub fn block_size(&self) -> usize {
    self.block_size
}
/// Get a reference to the disk block pool.
///
/// Delegates to the internal state. Presumably `None` when no disk pool was
/// configured; confirm against `KvBlockManagerState::disk`.
pub fn disk(&self) -> Option<&dyn BlockPool<DiskStorage, Locality, Metadata>> {
    self.state.disk()
}
/// Get a reference to the host (pinned memory) block pool.
///
/// Delegates to the internal state. Presumably `None` when no host pool was
/// configured; confirm against `KvBlockManagerState::host`.
pub fn host(&self) -> Option<&dyn BlockPool<PinnedStorage, Locality, Metadata>> {
    self.state.host()
}
/// Get a reference to the device block pool.
///
/// Delegates to the internal state. Presumably `None` when no device pool was
/// configured; confirm against `KvBlockManagerState::device`.
pub fn device(&self) -> Option<&dyn BlockPool<DeviceStorage, Locality, Metadata>> {
    self.state.device()
}
/// Get the worker ID, as reported by the internal state.
pub fn worker_id(&self) -> WorkerID {
    self.state.worker_id()
}
/// Onboard a set of blocks to the device pool.
///
/// This call is not `async`: it forwards the request to the internal state
/// and returns a [`oneshot::Receiver`] that the caller awaits to obtain the
/// [`BlockResult`] for the device blocks.
///
/// * `blocks` - immutable source blocks, from any storage type `S`.
/// * `targets` - optional mutable device blocks. Presumably the destination
///   blocks to copy into, with `None` leaving the choice to the state;
///   confirm against `KvBlockManagerState::onboard_blocks`.
pub fn onboard_blocks<S: Storage>(
    &self,
    blocks: Vec<ImmutableBlock<S, Locality, Metadata>>,
    targets: Option<Vec<MutableBlock<DeviceStorage, Locality, Metadata>>>,
) -> oneshot::Receiver<BlockResult<DeviceStorage, Locality, Metadata>> {
    self.state.onboard_blocks(blocks, targets)
}
} }
impl<Metadata: BlockMetadata> KvBlockManager<Metadata> { fn build_cancel_token(config: &mut KvBlockManagerConfig) -> Arc<CancelOnLastDrop> {
// The frontend of the KvBlockManager will take ownership of the cancellation token
// and will be responsible for cancelling the task when the KvBlockManager is dropped
let cancellation_token = config.runtime.cancellation_token.clone();
// The internal state will use a child token of the original token
config.runtime.cancellation_token = cancellation_token.child_token();
Arc::new(CancelOnLastDrop { cancellation_token })
}
impl<Metadata: BlockMetadata> KvBlockManager<locality::Local, Metadata> {
/// Create a new [KvBlockManager] /// Create a new [KvBlockManager]
/// ///
/// The returned object is a frontend to the [KvBlockManager] which owns the cancellation /// The returned object is a frontend to the [KvBlockManager] which owns the cancellation
/// tokens. When this object is dropped, the cancellation token will be cancelled and begin /// tokens. When this object is dropped, the cancellation token will be cancelled and begin
/// the graceful shutdown of the block manager's internal state. /// the graceful shutdown of the block manager's internal state.
pub fn new(config: KvBlockManagerConfig) -> Result<Self> { pub async fn new(mut config: KvBlockManagerConfig) -> Result<Self> {
let mut config = config; let _cancellation_token = build_cancel_token(&mut config);
// The frontend of the KvBlockManager will take ownership of the cancellation token
// and will be responsible for cancelling the task when the KvBlockManager is dropped
let cancellation_token = config.runtime.cancellation_token.clone();
// The internal state will use a child token of the original token let block_size = config.model.page_size;
config.runtime.cancellation_token = cancellation_token.child_token();
// Create the internal state // Create the internal state
let state = state::KvBlockManagerState::new(config)?; let state = state::KvBlockManagerState::<locality::Local, Metadata>::new(config).await?;
Ok(Self { Ok(Self {
state, state,
cancellation_token, _cancellation_token,
block_size,
}) })
} }
...@@ -145,53 +221,44 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> { ...@@ -145,53 +221,44 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
) -> Result<Vec<RemoteBlock<IsMutable>>> { ) -> Result<Vec<RemoteBlock<IsMutable>>> {
self.state.get_remote_blocks_mutable(bds) self.state.get_remote_blocks_mutable(bds)
} }
}
/// Get a reference to the disk block pool impl<R: LogicalResources, Metadata: BlockMetadata> KvBlockManager<locality::Logical<R>, Metadata> {
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> { pub async fn new(mut config: KvBlockManagerConfig, logical_resources: R) -> Result<Self> {
self.state.disk() let block_size = config.model.page_size;
}
/// Get a reference to the host block pool
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.state.host()
}
/// Get a reference to the device block pool let _cancellation_token = build_cancel_token(&mut config);
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.state.device()
}
/// Get the worker ID let state = state::KvBlockManagerState::<locality::Logical<R>, Metadata>::new(
pub fn worker_id(&self) -> WorkerID { config,
self.state.worker_id() logical_resources,
} )
.await?;
pub async fn onboard_blocks<S: Storage>( Ok(Self {
&self, state,
blocks: Vec<ImmutableBlock<S, Metadata>>, _cancellation_token,
) -> BlockResult<DeviceStorage, Metadata> { block_size,
self.state.onboard_blocks(blocks).await })
}
}
impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
fn drop(&mut self) {
self.cancellation_token.cancel();
} }
} }
#[cfg(all(test, feature = "testing-full"))] #[cfg(all(test, feature = "testing-full"))]
mod tests { mod tests {
use super::*; use super::*;
use crate::block_manager::block::BlockExt;
use crate::tokens::Tokens; use crate::tokens::Tokens;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
// Atomic Counter for Worker ID // Atomic Counter for Worker ID
static WORKER_ID: AtomicU64 = AtomicU64::new(1337); static WORKER_ID: AtomicU64 = AtomicU64::new(1337);
fn create_reference_block_manager() -> ReferenceBlockManager { pub fn create_reference_block_manager_config_with_counts(
device: usize,
host: usize,
disk: usize,
) -> KvBlockManagerConfig {
let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst); let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
// Check if we're already in a Tokio runtime context // Check if we're already in a Tokio runtime context
...@@ -202,7 +269,7 @@ mod tests { ...@@ -202,7 +269,7 @@ mod tests {
Some(Arc::new(tokio::runtime::Runtime::new().unwrap())) Some(Arc::new(tokio::runtime::Runtime::new().unwrap()))
}; };
let config = KvBlockManagerConfig::builder() let builder = KvBlockManagerConfig::builder()
.runtime( .runtime(
KvManagerRuntimeConfig::builder() KvManagerRuntimeConfig::builder()
.worker_id(worker_id) .worker_id(worker_id)
...@@ -219,46 +286,73 @@ mod tests { ...@@ -219,46 +286,73 @@ mod tests {
.inner_dim(16) .inner_dim(16)
.build() .build()
.unwrap(), .unwrap(),
) );
.disk_layout(
let builder = if disk > 0 {
builder.disk_layout(
KvManagerLayoutConfig::builder() KvManagerLayoutConfig::builder()
.num_blocks(16) .num_blocks(disk)
.allocator(storage::DiskAllocator) .allocator(storage::DiskAllocator)
.build() .build()
.unwrap(), .unwrap(),
) )
.host_layout( } else {
builder
};
let builder = if host > 0 {
builder.host_layout(
KvManagerLayoutConfig::builder() KvManagerLayoutConfig::builder()
.num_blocks(16) .num_blocks(host)
.allocator(storage::PinnedAllocator::default()) .allocator(storage::PinnedAllocator::default())
.build() .build()
.unwrap(), .unwrap(),
) )
.device_layout( } else {
builder
};
let builder = if device > 0 {
builder.device_layout(
KvManagerLayoutConfig::builder() KvManagerLayoutConfig::builder()
.num_blocks(8) .num_blocks(device)
.allocator(storage::DeviceAllocator::new(0).unwrap()) .allocator(storage::DeviceAllocator::new(0).unwrap())
.build() .build()
.unwrap(), .unwrap(),
) )
.build() } else {
.unwrap(); builder
};
ReferenceBlockManager::new(config).unwrap() builder.build().unwrap()
} }
#[tokio::test] pub fn create_reference_block_manager_config() -> KvBlockManagerConfig {
async fn test_reference_block_manager_inherited_async_runtime() { create_reference_block_manager_config_with_counts(8, 16, 16)
dynamo_runtime::logging::init();
let _block_manager = create_reference_block_manager();
} }
// todo: solve the async runtime issue pub async fn create_reference_block_manager() -> ReferenceBlockManager {
#[ignore] ReferenceBlockManager::new(create_reference_block_manager_config())
#[test] .await
fn test_reference_block_manager_blocking() { .unwrap()
}
pub async fn create_reference_block_manager_with_counts(
device: usize,
host: usize,
disk: usize,
) -> ReferenceBlockManager {
ReferenceBlockManager::new(create_reference_block_manager_config_with_counts(
device, host, disk,
))
.await
.unwrap()
}
#[tokio::test]
async fn test_reference_block_manager_inherited_async_runtime() {
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let _block_manager = create_reference_block_manager(); let _block_manager = create_reference_block_manager().await;
} }
// This test mimics the behavior of two unique kvbm workers exchanging blocksets // This test mimics the behavior of two unique kvbm workers exchanging blocksets
...@@ -267,13 +361,15 @@ mod tests { ...@@ -267,13 +361,15 @@ mod tests {
// //
// This test is meant to mimic the behavior of the basic nixl integration test found here: // This test is meant to mimic the behavior of the basic nixl integration test found here:
// https://github.com/ai-dynamo/nixl/blob/main/src/bindings/rust/src/tests.rs // https://github.com/ai-dynamo/nixl/blob/main/src/bindings/rust/src/tests.rs
// TODO: This test doesn't work because NIXL doesn't support partial metadata in the rust bindings.
#[ignore]
#[tokio::test] #[tokio::test]
async fn test_reference_block_managers() { async fn test_reference_block_managers() {
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
// create two block managers - mimics two unique dynamo workers // create two block managers - mimics two unique dynamo workers
let kvbm_0 = create_reference_block_manager(); let kvbm_0 = create_reference_block_manager().await;
let kvbm_1 = create_reference_block_manager(); let kvbm_1 = create_reference_block_manager().await;
assert_ne!(kvbm_0.worker_id(), kvbm_1.worker_id()); assert_ne!(kvbm_0.worker_id(), kvbm_1.worker_id());
...@@ -287,16 +383,16 @@ mod tests { ...@@ -287,16 +383,16 @@ mod tests {
// Worker 0 // Worker 0
// Allocate 4 mutable blocks on the host // Allocate 4 mutable blocks on the host
let blocks_0 = kvbm_0.host().unwrap().allocate_blocks(4).await.unwrap(); let _blocks_0 = kvbm_0.host().unwrap().allocate_blocks(4).await.unwrap();
// Create a BlockDescriptorList for the mutable blocks // // Create a BlockDescriptorList for the mutable blocks
// let blockset_0 = BlockDescriptorList::from_mutable_blocks(&blocks_0).unwrap(); // // let blockset_0 = BlockDescriptorList::from_mutable_blocks(&blocks_0).unwrap();
let blockset_0 = blocks_0.as_block_descriptor_set().unwrap(); // let blockset_0 = blocks_0.as_block_descriptor_set().unwrap();
// Worker 1 // // Worker 1
// Create a RemoteBlock list from blockset_0 // // Create a RemoteBlock list from blockset_0
let _blocks_1 = kvbm_1.host().unwrap().allocate_blocks(4).await.unwrap(); // let _blocks_1 = kvbm_1.host().unwrap().allocate_blocks(4).await.unwrap();
let mut _remote_blocks_0 = kvbm_1.get_remote_blocks_mutable(&blockset_0).unwrap(); // let mut _remote_blocks_0 = kvbm_1.get_remote_blocks_mutable(&blockset_0).unwrap();
// TODO(#967) - Enable with TransferEngine // TODO(#967) - Enable with TransferEngine
...@@ -339,7 +435,7 @@ mod tests { ...@@ -339,7 +435,7 @@ mod tests {
async fn test_offload() -> Result<()> { async fn test_offload() -> Result<()> {
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let block_manager = create_reference_block_manager(); let block_manager = create_reference_block_manager().await;
let device = block_manager.device().unwrap(); let device = block_manager.device().unwrap();
...@@ -359,7 +455,7 @@ mod tests { ...@@ -359,7 +455,7 @@ mod tests {
let host_blocks = block_manager let host_blocks = block_manager
.host() .host()
.unwrap() .unwrap()
.match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()?].as_slice()) .match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()].as_slice())
.await .await
.unwrap(); .unwrap();
assert_eq!(host_blocks.len(), 1); assert_eq!(host_blocks.len(), 1);
...@@ -367,7 +463,7 @@ mod tests { ...@@ -367,7 +463,7 @@ mod tests {
let disk_blocks = block_manager let disk_blocks = block_manager
.disk() .disk()
.unwrap() .unwrap()
.match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()?].as_slice()) .match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()].as_slice())
.await .await
.unwrap(); .unwrap();
assert_eq!(disk_blocks.len(), 1); assert_eq!(disk_blocks.len(), 1);
......
...@@ -13,27 +13,29 @@ ...@@ -13,27 +13,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod factory;
pub mod locality;
pub mod data;
pub mod registry; pub mod registry;
pub mod state; pub mod state;
pub mod transfer; pub mod transfer;
pub mod view;
pub use data::{view, BlockData, BlockDataExt, BlockDataProvider, BlockDataProviderMut};
pub use locality::LocalityProvider;
pub use crate::tokens::TokenBlockError; pub use crate::tokens::TokenBlockError;
pub use anyhow::Result; pub use anyhow::Result;
use nixl_sys::NixlDescriptor;
pub use registry::{GlobalRegistry, RegistrationHandle}; pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid}; pub use state::{BlockState, BlockStateInvalid};
pub use transfer::TransferContext;
use crate::block_manager::{ use crate::block_manager::{
state::KvBlockManagerState as BlockManager, state::KvBlockManagerState as BlockManager,
storage::{Local, Remote, Storage}, storage::{Local, Remote, Storage, StorageTypeProvider},
}; };
use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens}; use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens};
use transfer::{Immutable, Mutable, Readable, Writable};
use super::{ use super::{
events::PublishHandle, events::PublishHandle,
layout::{BlockLayout, LayoutError, LayoutType}, layout::{BlockLayout, LayoutError, LayoutType},
...@@ -49,7 +51,8 @@ use std::{ ...@@ -49,7 +51,8 @@ use std::{
}; };
use thiserror::Error; use thiserror::Error;
mod private { pub mod private {
#[derive(Clone, Copy)]
pub struct PrivateToken; pub struct PrivateToken;
} }
...@@ -71,8 +74,23 @@ pub enum BlockError { ...@@ -71,8 +74,23 @@ pub enum BlockError {
#[error("Invalid state: {0}")] #[error("Invalid state: {0}")]
InvalidState(String), InvalidState(String),
#[error("Invalid block ID: {0}")]
InvalidBlockID(BlockId),
#[error("Misconfigured block data parallelism: {0}")]
MisconfiguredBlockDataParallelism(String),
#[error("Incompatible storage type: {0}")]
IncompatibleStorageType(String),
#[error("Views are not available on logical blocks")]
ViewsNotAvailableOnLogicalBlocks,
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
#[error("Immutable block already has a duplicate")]
IncompatibleImmutableBlock,
} }
pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + 'static { pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + 'static {
...@@ -91,23 +109,28 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + ...@@ -91,23 +109,28 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
fn offload_priority(&self) -> Option<u64>; fn offload_priority(&self) -> Option<u64>;
} }
/// Marker trait for types that are mutable blocks /// A trait for blocks that can be returned to the pool.
pub trait WritableBlock: BlockDataProviderMut { ///
type StorageType: Storage + NixlDescriptor; /// This is used to determine if a block can be dropped when it is returned to the pool.
/// If the block is droppable, it will be returned to the pool.
fn storage_type_id(&self) -> std::any::TypeId { /// If the block is not droppable, it will be kept alive until the pool is reset.
std::any::TypeId::of::<<Self as WritableBlock>::StorageType>() pub trait MaybeReturnableBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
} /// At the time of the call, the block is singularly owned and therefore will be returned to the pool
/// if dropped.
fn is_returnable(&self) -> bool;
/// Try to take ownership of the block.
///
/// This is an internal function guarded by the PrivateToken and is used to implement the public facing
/// [`super::pool::BlockPool::return_block`] and [`super::pool::BlockPool::return_block_blocking`] functions.
fn try_take_block(self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>>;
} }
/// Marker trait for types that are immutable blocks /// Marker trait for types that are mutable blocks
pub trait ReadableBlock: BlockDataProvider { pub trait WritableBlock: BlockDataProviderMut {}
type StorageType: Storage + NixlDescriptor;
fn storage_type_id(&self) -> std::any::TypeId { /// Marker trait for types that are immutable blocks
std::any::TypeId::of::<<Self as ReadableBlock>::StorageType>() pub trait ReadableBlock: BlockDataProvider {}
}
}
pub trait ReadableBlocks {} pub trait ReadableBlocks {}
...@@ -132,42 +155,54 @@ pub trait AsBlockMutSlice<'a, B: 'a> { ...@@ -132,42 +155,54 @@ pub trait AsBlockMutSlice<'a, B: 'a> {
} }
/// Blanket trait for anything that can be converted into a mutable block /// Blanket trait for anything that can be converted into a mutable block
pub trait IntoWritableBlocks<M: BlockMetadata> { pub trait IntoWritableBlocks<Locality: LocalityProvider, M: BlockMetadata> {
type Output: WritableBlocks; type Output: WritableBlocks;
fn into_writable_blocks(self, manager: &BlockManager<M>) -> BlockResult<Self::Output>; fn into_writable_blocks(self, manager: &BlockManager<Locality, M>)
-> BlockResult<Self::Output>;
} }
impl<T: WritableBlocks, M: BlockMetadata> IntoWritableBlocks<M> for T { impl<T: WritableBlocks, Locality: LocalityProvider, M: BlockMetadata>
IntoWritableBlocks<Locality, M> for T
{
type Output = T; type Output = T;
fn into_writable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> { fn into_writable_blocks(
self,
_manager: &BlockManager<Locality, M>,
) -> BlockResult<Self::Output> {
Ok(self) Ok(self)
} }
} }
pub trait IntoReadableBlocks<M: BlockMetadata> { pub trait IntoReadableBlocks<Locality: LocalityProvider, M: BlockMetadata> {
type Output: ReadableBlocks; type Output: ReadableBlocks;
fn into_readable_blocks(self, manager: &BlockManager<M>) -> BlockResult<Self::Output>; fn into_readable_blocks(self, manager: &BlockManager<Locality, M>)
-> BlockResult<Self::Output>;
} }
impl<T: ReadableBlocks, M: BlockMetadata> IntoReadableBlocks<M> for T { impl<T: ReadableBlocks, Locality: LocalityProvider, M: BlockMetadata>
IntoReadableBlocks<Locality, M> for T
{
type Output = T; type Output = T;
fn into_readable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> { fn into_readable_blocks(
self,
_manager: &BlockManager<Locality, M>,
) -> BlockResult<Self::Output> {
Ok(self) Ok(self)
} }
} }
/// A block with storage and associated metadata/state /// A block with storage and associated metadata/state
#[derive(Debug)] #[derive(Debug)]
pub struct Block<S: Storage, M: BlockMetadata> { pub struct Block<S: Storage, L: LocalityProvider, M: BlockMetadata> {
data: BlockData<S>, data: L::BlockData<S>,
metadata: M, metadata: M,
state: BlockState, state: BlockState,
manager: Option<Arc<BlockManager<M>>>, manager: Option<Arc<BlockManager<L, M>>>,
} }
impl<S: Storage, M: BlockMetadata> Block<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
/// Create a new block with default metadata/state /// Create a new block with default metadata/state
pub fn new(data: BlockData<S>, metadata: M) -> BlockResult<Self> { pub fn new(data: L::BlockData<S>, metadata: M) -> BlockResult<Self> {
Ok(Self { Ok(Self {
data, data,
metadata, metadata,
...@@ -196,16 +231,108 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> { ...@@ -196,16 +231,108 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
} }
} }
pub(crate) fn reset(&mut self) { /// Reset the state of the block (public method replacing old crate-only version)
pub fn reset(&mut self) {
self.state = BlockState::Reset; self.state = BlockState::Reset;
self.metadata.reset_metadata(); self.metadata.reset_metadata();
} }
pub(crate) fn set_manager(&mut self, manager: Arc<BlockManager<M>>) { /// Initialize a sequence on the block using a [SaltHash]
///
/// The block must be in the [BlockState::Reset] state.
///
/// After initialization, the block will be in the [BlockState::Partial] state.
pub fn init_sequence(&mut self, salt_hash: SaltHash) -> Result<()> {
    // Read the page size before taking the mutable borrow of `state`.
    let capacity = self.page_size();
    // Any error from the state machine is converted by `?` into this function's error type.
    self.state.initialize_sequence(capacity, salt_hash)?;
    Ok(())
}
/// Appends a single token to the block if it is in the Partial state and not full.
/// Returns `Err` if the block is not Partial or already full.
pub fn add_token(&mut self, token: Token) -> Result<()> {
self.state.add_token(token)
}
/// Appends multiple tokens to the block if it is in the Partial state
/// and has enough remaining capacity for *all* provided tokens.
/// The block must be in the [BlockState::Partial] state.
/// Returns `Err` if the block is not Partial or if there isn't enough space.
pub fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens> {
self.state.add_tokens(tokens)
}
/// Removes the last token from the block.
/// Requires the block to be in the Partial state and not empty.
/// Returns `Err` otherwise.
pub fn pop_token(&mut self) -> Result<()> {
self.state.pop_token()
}
/// Removes the last `count` tokens from the block.
/// Requires the block to be in the Partial state and have at least `count` tokens.
/// Returns `Err` otherwise.
pub fn pop_tokens(&mut self, count: usize) -> Result<()> {
self.state.pop_tokens(count)
}
/// Commit the block
/// Requires the block to be in the [BlockState::Partial] state and completely full.
/// Transitions the state to [BlockState::Complete]. Returns `Err` otherwise.
pub fn commit(&mut self) -> Result<()> {
self.state.commit()
}
/// Applies a [TokenBlock] to the block.
///
/// The block must be in the [BlockState::Reset] state, and the number of
/// tokens in the [TokenBlock] must equal the block's page size
/// ([BlockLayout::page_size()]). On success the state becomes
/// [BlockState::Complete]; otherwise an `Err` is returned.
pub fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
    let expected = self.page_size();
    let provided = token_block.tokens().len();

    // Sizes agree: hand the token block to the state machine.
    if provided == expected {
        return self.state.apply_token_block(token_block);
    }

    // Size mismatch is rejected before the state is touched.
    Err(BlockStateInvalid(format!(
        "TokenBlock size ({}) does not match Block page size ({})",
        provided, expected
    ))
    .into())
}
/// Number of tokens currently held by the block.
///
/// When the state does not report a length (`None`), the block's page size
/// is returned instead.
pub fn len(&self) -> usize {
    self.state.len().unwrap_or_else(|| self.page_size())
}
/// Returns the number of additional tokens that can be added (only valid for Partial state).
pub fn remaining(&self) -> usize {
self.state.remaining()
}
/// Returns true if the block contains no tokens (only true for Reset or empty Partial state).
pub fn is_empty(&self) -> bool {
self.state.is_empty()
}
/// Whether the block is full, i.e. `len()` equals `page_size()`.
pub fn is_full(&self) -> bool {
    let held = self.len();
    held == self.page_size()
}
/// Returns a list of tokens in the block.
pub fn tokens(&self) -> Option<&Tokens> {
self.state.tokens()
}
pub(crate) fn set_manager(&mut self, manager: Arc<BlockManager<L, M>>) {
self.manager = Some(manager); self.manager = Some(manager);
} }
pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> { pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<L, M>>> {
self.manager.as_ref() self.manager.as_ref()
} }
...@@ -230,24 +357,41 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> { ...@@ -230,24 +357,41 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
&self.state &self.state
} }
/// Get a mutable reference to the state of the block
pub fn state_mut(&mut self) -> &mut BlockState {
&mut self.state
}
/// Get the number of blocks in the block /// Get the number of blocks in the block
/// todo(ryan): validate this can be removed
pub fn num_blocks(&self) -> usize { pub fn num_blocks(&self) -> usize {
1 1
} }
/// Get the block ID of the block
pub fn block_id(&self) -> BlockId {
self.data.block_id()
}
/// Get the number of layers in the block /// Get the number of layers in the block
pub fn num_layers(&self) -> usize { pub fn num_layers(&self) -> usize {
self.data.layout.num_layers() self.data.num_layers()
} }
/// Get the page size of the block /// Get the page size of the block
pub fn page_size(&self) -> usize { pub fn page_size(&self) -> usize {
self.data.layout.page_size() self.data.page_size()
} }
/// Get the inner dimension of the block /// Get the inner dimension of the block
pub fn inner_dim(&self) -> usize { pub fn inner_dim(&self) -> usize {
self.data.layout.inner_dim() self.data.num_inner_dims()
}
/// Get the number of outer dimensions in this block
/// Works for all localities through BlockLayoutConfig
pub fn num_outer_dims(&self) -> usize {
self.data.num_outer_dims()
} }
pub(crate) fn metadata_on_acquired(&mut self, tick: u64) { pub(crate) fn metadata_on_acquired(&mut self, tick: u64) {
...@@ -266,7 +410,7 @@ pub(crate) trait PrivateBlockExt { ...@@ -266,7 +410,7 @@ pub(crate) trait PrivateBlockExt {
) -> Result<Option<PublishHandle>, registry::BlockRegistrationError>; ) -> Result<Option<PublishHandle>, registry::BlockRegistrationError>;
} }
impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> PrivateBlockExt for Block<S, L, M> {
fn register( fn register(
&mut self, &mut self,
registry: &mut registry::BlockRegistry, registry: &mut registry::BlockRegistry,
...@@ -275,6 +419,28 @@ impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> { ...@@ -275,6 +419,28 @@ impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
} }
} }
// Tags `Block` with the `Local` marker trait; the impl body is empty, so it adds no methods.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Local for Block<S, L, M> {}
// Exposes the block's storage type parameter `S` as the associated `StorageType`.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider for Block<S, L, M> {
type StorageType = S;
}
/// Read access to the data behind a [`Block`]; the locality is the block's
/// own `L` parameter.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProvider for Block<S, L, M> {
    type Locality = L;

    /// Borrows the block's `data` field.
    fn block_data(&self) -> &impl BlockDataExt<S> {
        let Self { data, .. } = self;
        data
    }
}
/// Mutable access to the data behind a [`Block`]; the locality is the
/// block's own `L` parameter.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProviderMut for Block<S, L, M> {
    type Locality = L;

    /// Mutably borrows the block's `data` field.
    fn block_data_mut(&mut self) -> &mut impl BlockDataExt<S> {
        let Self { data, .. } = self;
        data
    }
}
pub trait BlockExt { pub trait BlockExt {
/// Reset the state of the block /// Reset the state of the block
fn reset(&mut self); fn reset(&mut self);
...@@ -334,204 +500,6 @@ pub trait BlockExt { ...@@ -334,204 +500,6 @@ pub trait BlockExt {
fn tokens(&self) -> Option<&Tokens>; fn tokens(&self) -> Option<&Tokens>;
} }
/// [`BlockExt`] implementation for [`Block`].
///
/// `reset` calls `Block::reset`; every other method operates on the block's
/// `state` field, consulting `page_size()` wherever a capacity is needed.
impl<S: Storage, M: BlockMetadata> BlockExt for Block<S, M> {
    fn reset(&mut self) {
        Block::reset(self);
    }

    fn init_sequence(&mut self, salt_hash: SaltHash) -> Result<()> {
        // Read the page size before taking the mutable borrow of `state`.
        let capacity = self.page_size();
        self.state.initialize_sequence(capacity, salt_hash)?;
        Ok(())
    }

    fn add_token(&mut self, token: Token) -> Result<()> {
        let Self { state, .. } = self;
        state.add_token(token)
    }

    fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens> {
        let Self { state, .. } = self;
        state.add_tokens(tokens)
    }

    fn pop_token(&mut self) -> Result<()> {
        let Self { state, .. } = self;
        state.pop_token()
    }

    fn pop_tokens(&mut self, count: usize) -> Result<()> {
        let Self { state, .. } = self;
        state.pop_tokens(count)
    }

    fn commit(&mut self) -> Result<()> {
        let Self { state, .. } = self;
        state.commit()
    }

    fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
        let expected = self.page_size();
        let provided = token_block.tokens().len();

        // Sizes agree: hand the token block to the state machine.
        if provided == expected {
            return self.state.apply_token_block(token_block);
        }

        // Size mismatch is rejected before the state is touched.
        Err(BlockStateInvalid(format!(
            "TokenBlock size ({}) does not match Block page size ({})",
            provided, expected
        ))
        .into())
    }

    fn len(&self) -> usize {
        // A state without a reported length counts as holding a full page.
        self.state.len().unwrap_or_else(|| self.page_size())
    }

    fn remaining(&self) -> usize {
        let Self { state, .. } = self;
        state.remaining()
    }

    fn is_empty(&self) -> bool {
        let Self { state, .. } = self;
        state.is_empty()
    }

    fn is_full(&self) -> bool {
        let held = self.len();
        held == self.page_size()
    }

    fn tokens(&self) -> Option<&Tokens> {
        let Self { state, .. } = self;
        state.tokens()
    }
}
/// Queries and views over the storage backing a block.
///
/// `S` is the storage type; it must implement both `Storage` and
/// `NixlDescriptor`. View accessors return a `BlockResult`, so they can fail.
pub trait BlockDataExt<S: Storage + NixlDescriptor> {
/// Returns true if the block data is fully contiguous
fn is_fully_contiguous(&self) -> bool;
/// Returns the number of layers in the block
fn num_layers(&self) -> usize;
/// Returns the number of outer dimensions in the block
fn num_outer_dims(&self) -> usize;
/// Get a read-only view of this block's storage for a layer
///
/// `layer_idx` and `outer_idx` select the region to view.
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>>;
/// Get a mutable view of this block's storage for a layer
///
/// `layer_idx` and `outer_idx` select the region to view.
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>>;
/// Get a read-only view of this block's storage
fn block_view(&self) -> BlockResult<view::BlockView<S>>;
/// Get a mutable view of this block's storage
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>>;
}
/// Individual block storage - cannot be cloned to ensure uniqueness
///
/// Only `Debug` is derived; there is no `Clone` implementation on this type.
#[derive(Debug)]
pub struct BlockData<S: Storage> {
// Shared layout; this block's memory regions are looked up through it.
layout: Arc<dyn BlockLayout<StorageType = S>>,
// Index of this block; passed as the first argument to `layout.memory_region`.
block_idx: usize,
// Stored at construction; not read by the code visible in this part of the file.
block_set_idx: usize,
// Stored at construction; not read by the code visible in this part of the file.
worker_id: WorkerID,
}
impl<S> BlockData<S>
where
S: Storage,
{
/// Create a new block storage
pub(crate) fn new(
layout: Arc<dyn BlockLayout<StorageType = S>>,
block_idx: usize,
block_set_idx: usize,
worker_id: WorkerID,
) -> Self {
Self {
layout,
block_idx,
block_set_idx,
worker_id,
}
}
pub fn storage_type(&self) -> StorageType {
self.layout.storage_type()
}
}
/// [`BlockDataExt`] for [`BlockData`]; every answer is derived from the
/// stored layout and this block's index.
impl<S> BlockDataExt<S> for BlockData<S>
where
    S: Storage + NixlDescriptor,
{
    /// True when the layout reports [`LayoutType::FullyContiguous`].
    fn is_fully_contiguous(&self) -> bool {
        let kind = self.layout.layout_type();
        kind == LayoutType::FullyContiguous
    }

    /// Layer count reported by the layout.
    fn num_layers(&self) -> usize {
        let Self { layout, .. } = self;
        layout.num_layers()
    }

    /// Outer dimension reported by the layout.
    fn num_outer_dims(&self) -> usize {
        let Self { layout, .. } = self;
        layout.outer_dim()
    }

    /// Read-only view of the region for (`layer_idx`, `outer_idx`) of this block.
    ///
    /// Fails if the layout cannot provide that memory region.
    fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
        let region = self
            .layout
            .memory_region(self.block_idx, layer_idx, outer_idx)?;
        let (addr, len) = (region.addr(), region.size());
        // NOTE(review): the safety contract of `LayerView::new` is not visible here;
        // presumably `addr`/`len` must describe memory owned by this layout - confirm.
        unsafe { view::LayerView::new(self, addr, len) }
    }

    /// Mutable view of the region for (`layer_idx`, `outer_idx`) of this block.
    ///
    /// Fails if the layout cannot provide that memory region.
    fn layer_view_mut(
        &mut self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<view::LayerViewMut<S>> {
        let region = self
            .layout
            .memory_region(self.block_idx, layer_idx, outer_idx)?;
        let (addr, len) = (region.addr(), region.size());
        // NOTE(review): the safety contract of `LayerViewMut::new` is not visible here;
        // presumably `addr`/`len` must describe memory owned by this layout - confirm.
        unsafe { view::LayerViewMut::new(self, addr, len) }
    }

    /// Read-only view spanning the whole block.
    ///
    /// Only available for fully contiguous layouts; otherwise returns
    /// `BlockError::InvalidState`. The view starts at the address of region
    /// (0, 0) and covers that region's size multiplied by the number of layers.
    fn block_view(&self) -> BlockResult<view::BlockView<S>> {
        if !self.is_fully_contiguous() {
            return Err(BlockError::InvalidState(
                "Block is not fully contiguous".to_string(),
            ));
        }
        let first = self.layout.memory_region(self.block_idx, 0, 0)?;
        let start = first.addr();
        let total = first.size() * self.num_layers();
        // NOTE(review): the safety contract of `BlockView::new` is not visible here - confirm.
        unsafe { view::BlockView::new(self, start, total) }
    }

    /// Mutable view spanning the whole block.
    ///
    /// Only available for fully contiguous layouts; otherwise returns
    /// `BlockError::InvalidState`. The view starts at the address of region
    /// (0, 0) and covers that region's size multiplied by the number of layers.
    fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
        if !self.is_fully_contiguous() {
            return Err(BlockError::InvalidState(
                "Block is not fully contiguous".to_string(),
            ));
        }
        let first = self.layout.memory_region(self.block_idx, 0, 0)?;
        let start = first.addr();
        let total = first.size() * self.num_layers();
        // NOTE(review): the safety contract of `BlockViewMut::new` is not visible here - confirm.
        unsafe { view::BlockViewMut::new(self, start, total) }
    }
}
/// Gives shared access to the `BlockData` behind a block-like type.
pub trait BlockDataProvider {
/// Storage type of the underlying `BlockData`.
type StorageType: Storage + NixlDescriptor;
/// Borrows the underlying `BlockData`.
///
/// NOTE(review): the `private::PrivateToken` argument presumably restricts
/// callers to code that can name that type - confirm.
fn block_data(&self, _: private::PrivateToken) -> &BlockData<Self::StorageType>;
}
/// Mutable counterpart of `BlockDataProvider`; implementors must also implement that trait.
pub trait BlockDataProviderMut: BlockDataProvider {
/// Mutably borrows the underlying `BlockData`.
///
/// NOTE(review): the `private::PrivateToken` argument presumably restricts
/// callers to code that can name that type - confirm.
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>;
}
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Getters)] #[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Getters)]
pub struct BasicMetadata { pub struct BasicMetadata {
#[getter(copy)] #[getter(copy)]
...@@ -592,7 +560,7 @@ impl<L: BlockLayout + 'static, M: BlockMetadata> Blocks<L, M> { ...@@ -592,7 +560,7 @@ impl<L: BlockLayout + 'static, M: BlockMetadata> Blocks<L, M> {
} }
/// Convert collection into Vec<Block> with default metadata/state /// Convert collection into Vec<Block> with default metadata/state
pub fn into_blocks(self) -> BlockResult<Vec<Block<L::StorageType, M>>> { pub fn into_blocks(self) -> BlockResult<Vec<Block<L::StorageType, locality::Local, M>>> {
// convert box to arc // convert box to arc
let layout: Arc<dyn BlockLayout<StorageType = L::StorageType>> = Arc::new(*self.layout); let layout: Arc<dyn BlockLayout<StorageType = L::StorageType>> = Arc::new(*self.layout);
layout_to_blocks(layout, self.block_set_idx, self.worker_id) layout_to_blocks(layout, self.block_set_idx, self.worker_id)
...@@ -603,38 +571,59 @@ pub(crate) fn layout_to_blocks<S: Storage, M: BlockMetadata>( ...@@ -603,38 +571,59 @@ pub(crate) fn layout_to_blocks<S: Storage, M: BlockMetadata>(
layout: Arc<dyn BlockLayout<StorageType = S>>, layout: Arc<dyn BlockLayout<StorageType = S>>,
block_set_idx: usize, block_set_idx: usize,
worker_id: WorkerID, worker_id: WorkerID,
) -> BlockResult<Vec<Block<S, M>>> { ) -> BlockResult<Vec<Block<S, locality::Local, M>>> {
(0..layout.num_blocks()) (0..layout.num_blocks())
.map(|idx| { .map(|idx| {
let data = BlockData::new(layout.clone(), idx, block_set_idx, worker_id); let data = BlockData::new(layout.clone(), idx, block_set_idx, worker_id);
let data = data;
Block::new(data, M::default()) Block::new(data, M::default())
}) })
.collect() .collect()
} }
pub struct MutableBlock<S: Storage, M: BlockMetadata> { pub struct MutableBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
block: Option<Block<S, M>>, block: Option<Block<S, L, M>>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>, return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
// Use to track parent relationship, as well as ensure that parents of registered blocks stay // Use to track parent relationship, as well as ensure that parents of registered blocks stay
// alive as long as the child is alive. // alive as long as the child is alive.
parent: Option<Arc<MutableBlock<S, M>>>, parent: Option<Arc<MutableBlock<S, L, M>>>,
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> WritableBlock for MutableBlock<S, M> { // MutableBlock inherits identification methods from Block via Deref
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider
for MutableBlock<S, L, M>
{
type StorageType = S; type StorageType = S;
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for MutableBlock<S, M> {
type StorageType = S; impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProvider
for MutableBlock<S, L, M>
{
type Locality = L;
fn block_data(&self) -> &impl BlockDataExt<S> {
&self.block.as_ref().expect("block was dropped").data
}
} }
// Marker-trait impls for `MutableBlock`: each body is empty, so these only tag the
// type as `Writable`, `Readable`, `Mutable` and `Local` and add no methods.
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Writable for MutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Readable for MutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Mutable for MutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Local for MutableBlock<S, M> {}
impl<S: Storage, M: BlockMetadata> MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProviderMut
for MutableBlock<S, L, M>
{
type Locality = L;
fn block_data_mut(&mut self) -> &mut impl BlockDataExt<S> {
&mut self.block.as_mut().expect("block was dropped").data
}
}
// Marker trait implementations for MutableBlock
// `Local` is implemented with an empty body, so it only tags the type and adds no methods.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Local for MutableBlock<S, L, M> {}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MutableBlock<S, L, M> {
pub(crate) fn new( pub(crate) fn new(
block: Block<S, M>, block: Block<S, L, M>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>, return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
) -> Self { ) -> Self {
Self { Self {
block: Some(block), block: Some(block),
...@@ -643,19 +632,31 @@ impl<S: Storage, M: BlockMetadata> MutableBlock<S, M> { ...@@ -643,19 +632,31 @@ impl<S: Storage, M: BlockMetadata> MutableBlock<S, M> {
} }
} }
pub fn set_parent(&mut self, parent: Arc<MutableBlock<S, M>>) { pub fn set_parent(&mut self, parent: Arc<MutableBlock<S, L, M>>) {
self.parent = Some(parent); self.parent = Some(parent);
} }
} }
impl<S: Storage, M: BlockMetadata> std::fmt::Debug for MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> std::fmt::Debug for MutableBlock<S, L, M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MutableBlock {{ block: {:?} }}", self.block) match &self.block {
Some(block) => {
write!(
f,
"MutableBlock(storage_type: {:?}, block_id: {}, sequence_hash: {:?})",
block.block_data().storage_type(),
block.block_id(),
block.sequence_hash().ok()
)
}
None => write!(f, "MutableBlock(block: None)"),
}
} }
} }
impl<S: Storage, M: BlockMetadata> Drop for MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Drop for MutableBlock<S, L, M> {
fn drop(&mut self) { fn drop(&mut self) {
tracing::debug!("drop: {:?}", self);
if let Some(block) = self.block.take() { if let Some(block) = self.block.take() {
if self.return_tx.send(block).is_err() { if self.return_tx.send(block).is_err() {
tracing::warn!("block pool shutdown before block was returned"); tracing::warn!("block pool shutdown before block was returned");
...@@ -664,227 +665,245 @@ impl<S: Storage, M: BlockMetadata> Drop for MutableBlock<S, M> { ...@@ -664,227 +665,245 @@ impl<S: Storage, M: BlockMetadata> Drop for MutableBlock<S, M> {
} }
} }
impl<S: Storage, M: BlockMetadata> Deref for MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Deref for MutableBlock<S, L, M> {
type Target = Block<S, M>; type Target = Block<S, L, M>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.block.as_ref().expect("block was dropped") self.block.as_ref().expect("block was dropped")
} }
} }
impl<S: Storage, M: BlockMetadata> DerefMut for MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> DerefMut for MutableBlock<S, L, M> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
self.block.as_mut().expect("block was dropped") self.block.as_mut().expect("block was dropped")
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableBlock<S, M> { // MutableBlock provides access to block data through simpler methods
fn is_fully_contiguous(&self) -> bool { // Simplified MutableBlock API - direct delegation to underlying data
self.data.is_fully_contiguous() // MutableBlock inherits methods from Block via Deref - no need for separate implementations
}
// // Local-specific BlockDataProvider implementations
fn num_layers(&self) -> usize { // impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProvider
self.data.num_layers() // for MutableBlock<S, locality::Local, M>
} // {
// type StorageType = S;
fn num_outer_dims(&self) -> usize {
self.data.num_outer_dims() // fn block_data(&self, _: private::PrivateToken) -> &BlockData<S> {
} // &self.block.as_ref().expect("block was dropped").data
// }
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> { // }
self.data.layer_view(layer_idx, outer_idx)
} // impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProviderMut
// for MutableBlock<S, locality::Local, M>
fn layer_view_mut( // {
&mut self, // fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<S> {
layer_idx: usize, // &mut self.block.as_mut().expect("block was dropped").data
outer_idx: usize, // }
) -> BlockResult<view::LayerViewMut<S>> { // }
self.data.layer_view_mut(layer_idx, outer_idx)
} impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
AsBlockSlice<'a, MutableBlock<S, L, M>> for [MutableBlock<S, L, M>]
fn block_view(&self) -> BlockResult<view::BlockView<S>> {
self.data.block_view()
}
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
self.data.block_view_mut()
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProvider for MutableBlock<S, M> {
type StorageType = S;
fn block_data(&self, _: private::PrivateToken) -> &BlockData<S> {
&self.block.as_ref().expect("block was dropped").data
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProviderMut for MutableBlock<S, M> {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<S> {
&mut self.block.as_mut().expect("block was dropped").data
}
}
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockSlice<'a, MutableBlock<S, M>>
for [MutableBlock<S, M>]
{ {
fn as_block_slice(&'a self) -> &'a [MutableBlock<S, M>] { fn as_block_slice(&'a self) -> &'a [MutableBlock<S, L, M>] {
self self
} }
} }
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockSlice<'a, MutableBlock<S, M>> impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
for Vec<MutableBlock<S, M>> AsBlockSlice<'a, MutableBlock<S, L, M>> for Vec<MutableBlock<S, L, M>>
{ {
fn as_block_slice(&'a self) -> &'a [MutableBlock<S, M>] { fn as_block_slice(&'a self) -> &'a [MutableBlock<S, L, M>] {
self.as_slice() self.as_slice()
} }
} }
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockMutSlice<'a, MutableBlock<S, M>> impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
for [MutableBlock<S, M>] AsBlockMutSlice<'a, MutableBlock<S, L, M>> for [MutableBlock<S, L, M>]
{ {
fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, M>] { fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, L, M>] {
self self
} }
} }
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockMutSlice<'a, MutableBlock<S, M>> impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
for Vec<MutableBlock<S, M>> AsBlockMutSlice<'a, MutableBlock<S, L, M>> for Vec<MutableBlock<S, L, M>>
{ {
fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, M>] { fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, L, M>] {
self.as_mut_slice() self.as_mut_slice()
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> IntoWritableBlocks<M> for MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> IntoWritableBlocks<L, M>
type Output = Vec<MutableBlock<S, M>>; for MutableBlock<S, L, M>
fn into_writable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> { {
type Output = Vec<MutableBlock<S, L, M>>;
fn into_writable_blocks(self, _manager: &BlockManager<L, M>) -> BlockResult<Self::Output> {
Ok(vec![self]) Ok(vec![self])
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> IntoReadableBlocks<M> for MutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> IntoReadableBlocks<L, M>
type Output = Vec<MutableBlock<S, M>>; for MutableBlock<S, L, M>
fn into_readable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> { {
type Output = Vec<MutableBlock<S, L, M>>;
fn into_readable_blocks(self, _manager: &BlockManager<L, M>) -> BlockResult<Self::Output> {
Ok(vec![self]) Ok(vec![self])
} }
} }
#[derive(Debug)] impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MaybeReturnableBlock<S, L, M>
pub struct ImmutableBlock<S: Storage, M: BlockMetadata> { for MutableBlock<S, L, M>
block: Arc<MutableBlock<S, M>>, {
} fn is_returnable(&self) -> bool {
self.block.is_some()
}
impl<S: Storage, M: BlockMetadata> Clone for ImmutableBlock<S, M> { fn try_take_block(mut self, _: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
fn clone(&self) -> Self { self.block.take().map(|block| vec![block])
Self {
block: self.block.clone(),
}
} }
} }
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { pub struct ImmutableBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self { block: Arc<MutableBlock<S, L, M>>,
Self { block } sequence_hash: SequenceHash,
} duplicate: Option<Arc<MutableBlock<S, L, M>>>,
}
pub(crate) fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> std::fmt::Debug
&self.block for ImmutableBlock<S, L, M>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"ImmutableBlock(storage: {:?}, block_id: {}, sequence_hash: {})",
self.block
.block
.as_ref()
.expect("block was dropped")
.block_data()
.storage_type(),
self.block_id(),
self.sequence_hash
)
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for ImmutableBlock<S, M> { // ImmutableBlock inherits identification methods from Block via Deref
type StorageType = S;
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Readable for ImmutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Immutable for ImmutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Local for ImmutableBlock<S, M> {}
impl<S: Storage, M: BlockMetadata> Deref for ImmutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Clone for ImmutableBlock<S, L, M> {
type Target = Block<S, M>; fn clone(&self) -> Self {
fn deref(&self) -> &Self::Target { Self {
self.block block: self.block.clone(),
.as_ref() sequence_hash: self.sequence_hash,
.block duplicate: self.duplicate.clone(),
.as_ref() }
.expect("block was dropped")
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for ImmutableBlock<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ImmutableBlock<S, L, M> {
fn is_fully_contiguous(&self) -> bool { pub(crate) fn new(block: Arc<MutableBlock<S, L, M>>) -> Self {
self.block.is_fully_contiguous() let sequence_hash = block.sequence_hash().expect("block is in the wrong state");
Self {
block,
sequence_hash,
duplicate: None,
}
} }
fn num_layers(&self) -> usize { /// Attempts to add a duplicate block to the ImmutableBlock.
self.block.num_layers() pub(crate) fn with_duplicate(
self,
duplicate: Arc<MutableBlock<S, L, M>>,
) -> Result<Self, BlockError> {
if self.duplicate.is_some() {
return Err(BlockError::IncompatibleImmutableBlock);
}
Ok(Self {
duplicate: Some(duplicate),
..self
})
} }
fn num_outer_dims(&self) -> usize { pub(crate) fn mutable_block(&self) -> &Arc<MutableBlock<S, L, M>> {
self.block.num_outer_dims() &self.block
} }
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> { pub fn sequence_hash(&self) -> SequenceHash {
self.block.layer_view(layer_idx, outer_idx) self.sequence_hash
} }
fn layer_view_mut(&mut self, _: usize, _: usize) -> BlockResult<view::LayerViewMut<S>> { /// If the ImmutableBlock is a duplicate, returns the block ID of the duplicate;
// This should never be called since ImmutableBlock is immutable, /// otherwise, returns the block ID of the primary block.
// but we need to implement the full trait pub fn block_id(&self) -> BlockId {
Err(BlockError::InvalidState( self.duplicate
"Cannot get mutable layer view from immutable block".to_string(), .as_ref()
)) .map_or(self.block.block_id(), |duplicate| duplicate.block_id())
} }
fn block_view(&self) -> BlockResult<view::BlockView<S>> { /// Returns true if the ImmutableBlock holds a duplicate block.
self.block.block_view() #[allow(unused)]
pub(crate) fn is_duplicate(&self) -> bool {
self.duplicate.is_some()
} }
}
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider
// This should never be called since ImmutableBlock is immutable, for ImmutableBlock<S, L, M>
// but we need to implement the full trait {
Err(BlockError::InvalidState( type StorageType = S;
"Cannot get mutable block view from immutable block".to_string(), }
))
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProvider
for ImmutableBlock<S, L, M>
{
type Locality = L;
fn block_data(&self) -> &impl BlockDataExt<S> {
&self.block.block.as_ref().expect("block was dropped").data
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProvider for ImmutableBlock<S, M> { // Marker trait implementations for ImmutableBlock
type StorageType = S; impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Local for ImmutableBlock<S, L, M> {}
fn block_data(&self, _: private::PrivateToken) -> &BlockData<S> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Deref for ImmutableBlock<S, L, M> {
&self type Target = Block<S, L, M>;
.block fn deref(&self) -> &Self::Target {
self.block
.as_ref() .as_ref()
.block .block
.as_ref() .as_ref()
.expect("block was dropped") .expect("block was dropped")
.data
} }
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> IntoReadableBlocks<M> for ImmutableBlock<S, M> { // ImmutableBlock provides access to block data through simpler methods
type Output = Vec<ImmutableBlock<S, M>>; // Simplified block API - direct delegation to underlying data
fn into_readable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> { // ImmutableBlock inherits methods from Block via Deref - no need for separate implementations
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> IntoReadableBlocks<L, M>
for ImmutableBlock<S, L, M>
{
type Output = Vec<ImmutableBlock<S, L, M>>;
fn into_readable_blocks(self, _manager: &BlockManager<L, M>) -> BlockResult<Self::Output> {
Ok(vec![self]) Ok(vec![self])
} }
} }
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>> impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
for [ImmutableBlock<S, M>] AsBlockSlice<'a, ImmutableBlock<S, L, M>> for [ImmutableBlock<S, L, M>]
{ {
fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, M>] { fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, L, M>] {
self self
} }
} }
impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>> impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
for Vec<ImmutableBlock<S, M>> AsBlockSlice<'a, ImmutableBlock<S, L, M>> for Vec<ImmutableBlock<S, L, M>>
{ {
fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, M>] { fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, L, M>] {
self.as_slice() self.as_slice()
} }
} }
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { impl<S: Storage + 'static, L: LocalityProvider, M: BlockMetadata> ImmutableBlock<S, L, M> {
pub async fn enqueue_offload(&self, priority: u64) -> Result<()> { pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
if let Some(manager) = self.manager() { if let Some(manager) = self.manager() {
manager.enqueue_offload_block(self, priority).await?; manager.enqueue_offload_block(self, priority).await?;
...@@ -895,6 +914,43 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { ...@@ -895,6 +914,43 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
} }
} }
impl<S, L, M> MaybeReturnableBlock<S, L, M> for ImmutableBlock<S, L, M>
where
    S: Storage,
    L: LocalityProvider,
    M: BlockMetadata,
{
    /// Reports whether this handle is the only strong reference to the block
    /// it tracks.
    ///
    /// When a duplicate is attached, only the duplicate's `Arc` is inspected;
    /// otherwise the primary block's `Arc` is inspected.
    fn is_returnable(&self) -> bool {
        let tracked = self.duplicate.as_ref().unwrap_or(&self.block);
        Arc::strong_count(tracked) == 1
    }

    /// Consumes the handle and recovers every underlying block that can be
    /// taken.
    ///
    /// The primary `Arc` is unwrapped first, then the duplicate (if any). Each
    /// `MutableBlock` that was uniquely owned is asked for its blocks via its
    /// own `try_take_block`. Returns `None` when nothing could be recovered.
    fn try_take_block(self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
        let Self {
            block, duplicate, ..
        } = self;

        // Both unwrap attempts happen before any block is taken, primary first.
        let primary = Arc::try_unwrap(block).ok();
        let secondary = duplicate.and_then(|dup| Arc::try_unwrap(dup).ok());

        let mut recovered = Vec::new();
        for owned in primary.into_iter().chain(secondary) {
            if let Some(taken) = owned.try_take_block(token) {
                recovered.extend(taken);
            }
        }

        (!recovered.is_empty()).then_some(recovered)
    }
}
// Blanket impl: every `BlockDataProvider` is a `ReadableBlock`.
impl<B> ReadableBlock for B where B: BlockDataProvider {}

// Blanket impl: every `BlockDataProviderMut` is a `WritableBlock`.
impl<B> WritableBlock for B where B: BlockDataProviderMut {}
pub mod nixl { pub mod nixl {
use super::*; use super::*;
...@@ -1005,6 +1061,7 @@ pub mod nixl { ...@@ -1005,6 +1061,7 @@ pub mod nixl {
} }
} }
// Comment out Nixl-related code for now
pub trait NixlBlockDataImmutable<S: Storage + NixlDescriptor>: BlockDataExt<S> { pub trait NixlBlockDataImmutable<S: Storage + NixlDescriptor>: BlockDataExt<S> {
/// Get the NIXL memory descriptor for the entire block /// Get the NIXL memory descriptor for the entire block
fn as_block_descriptor( fn as_block_descriptor(
...@@ -1019,22 +1076,6 @@ pub mod nixl { ...@@ -1019,22 +1076,6 @@ pub mod nixl {
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>; ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>;
} }
/// Mutable counterpart of [`NixlBlockDataImmutable`]: produces NIXL memory
/// descriptors marked [`IsMutable`] from a mutable borrow of the block data.
pub trait NixlBlockDataMutable<S>: BlockDataExt<S> + NixlBlockDataImmutable<S>
where
    S: Storage + NixlDescriptor,
{
    /// Returns the mutable NIXL memory descriptor covering the entire block.
    fn as_block_descriptor_mut(
        &mut self,
    ) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsMutable>>;

    /// Returns the mutable NIXL memory descriptor for the layer selected by
    /// `layer_idx` and `outer_idx`.
    fn as_layer_descriptor_mut(
        &mut self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>>;
}
impl<S: Storage + NixlDescriptor> NixlBlockDataImmutable<S> for BlockData<S> { impl<S: Storage + NixlDescriptor> NixlBlockDataImmutable<S> for BlockData<S> {
fn as_block_descriptor( fn as_block_descriptor(
&self, &self,
...@@ -1051,24 +1092,6 @@ pub mod nixl { ...@@ -1051,24 +1092,6 @@ pub mod nixl {
} }
} }
impl<S> NixlBlockDataMutable<S> for BlockData<S>
where
    S: Storage + NixlDescriptor,
{
    /// Builds the mutable descriptor from the block's mutable view; any error
    /// from `block_view_mut` is propagated.
    fn as_block_descriptor_mut(
        &mut self,
    ) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsMutable>> {
        let descriptor = self.block_view_mut()?.as_nixl_descriptor_mut();
        Ok(descriptor)
    }

    /// Builds the mutable descriptor from the mutable view of the layer at
    /// (`layer_idx`, `outer_idx`); any error from `layer_view_mut` is propagated.
    fn as_layer_descriptor_mut(
        &mut self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
        let descriptor = self
            .layer_view_mut(layer_idx, outer_idx)?
            .as_nixl_descriptor_mut();
        Ok(descriptor)
    }
}
/// Error type for NixlBlockSet serialization/deserialization failures. /// Error type for NixlBlockSet serialization/deserialization failures.
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum NixlSerializationError { pub enum NixlSerializationError {
...@@ -1231,13 +1254,13 @@ pub mod nixl { ...@@ -1231,13 +1254,13 @@ pub mod nixl {
impl<M: MutabilityKind> Remote for RemoteBlock<M> {} impl<M: MutabilityKind> Remote for RemoteBlock<M> {}
impl<M: MutabilityKind> ReadableBlock for RemoteBlock<M> { // impl<M: MutabilityKind> ReadableBlock for RemoteBlock<M> {
type StorageType = NixlStorage; // type StorageType = NixlStorage;
} // }
impl WritableBlock for RemoteBlock<IsMutable> { // impl WritableBlock for RemoteBlock<IsMutable> {
type StorageType = NixlStorage; // type StorageType = NixlStorage;
} // }
impl<M: MutabilityKind> RemoteBlock<M> { impl<M: MutabilityKind> RemoteBlock<M> {
pub fn new( pub fn new(
...@@ -1254,84 +1277,23 @@ pub mod nixl { ...@@ -1254,84 +1277,23 @@ pub mod nixl {
} }
} }
impl<M: MutabilityKind> BlockDataExt<NixlStorage> for RemoteBlock<M> { impl<M: MutabilityKind> StorageTypeProvider for RemoteBlock<M> {
fn is_fully_contiguous(&self) -> bool { type StorageType = NixlStorage;
self.data.is_fully_contiguous()
}
fn num_layers(&self) -> usize {
self.data.num_layers()
}
fn num_outer_dims(&self) -> usize {
self.data.num_outer_dims()
}
fn layer_view(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerView<NixlStorage>> {
self.data.layer_view(layer_idx, outer_idx)
}
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<NixlStorage>> {
self.data.layer_view_mut(layer_idx, outer_idx)
}
fn block_view(&self) -> BlockResult<view::BlockView<NixlStorage>> {
self.data.block_view()
}
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<NixlStorage>> {
self.data.block_view_mut()
}
} }
impl<M: MutabilityKind> BlockDataProvider for RemoteBlock<M> { impl<M: MutabilityKind> BlockDataProvider for RemoteBlock<M> {
type StorageType = NixlStorage; type Locality = locality::Local;
fn block_data(&self, _: private::PrivateToken) -> &BlockData<NixlStorage> { fn block_data(&self) -> &impl BlockDataExt<NixlStorage> {
&self.data &self.data
} }
} }
// A `RemoteBlock` forwards descriptor requests to the block data it wraps.
impl<M> NixlBlockDataImmutable<NixlStorage> for RemoteBlock<M>
where
    M: MutabilityKind,
{
    /// Returns the whole-block descriptor produced by the wrapped data.
    fn as_block_descriptor(
        &self,
    ) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsImmutable>> {
        let data = &self.data;
        data.as_block_descriptor()
    }

    /// Returns the descriptor of layer (`layer_idx`, `outer_idx`) produced by
    /// the wrapped data.
    fn as_layer_descriptor(
        &self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
        let data = &self.data;
        data.as_layer_descriptor(layer_idx, outer_idx)
    }
}
impl BlockDataProviderMut for RemoteBlock<IsMutable> { impl BlockDataProviderMut for RemoteBlock<IsMutable> {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<NixlStorage> { type Locality = locality::Local;
&mut self.data
}
}
impl NixlBlockDataMutable<NixlStorage> for RemoteBlock<IsMutable> {
fn as_block_descriptor_mut(
&mut self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsMutable>> {
self.data.as_block_descriptor_mut()
}
fn as_layer_descriptor_mut( fn block_data_mut(&mut self) -> &mut impl BlockDataExt<NixlStorage> {
&mut self, &mut self.data
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
self.data.as_layer_descriptor_mut(layer_idx, outer_idx)
} }
} }
...@@ -1375,40 +1337,6 @@ pub mod nixl { ...@@ -1375,40 +1337,6 @@ pub mod nixl {
pub mutability: BlockMutability, pub mutability: BlockMutability,
} }
/// Placeholder trait: real pool handles must provide this information.
///
/// It lets the `BlockDescriptorList` constructors stay generic over the kind
/// of handle they are given. The three accessors expose a worker id, a
/// block-set index and a block index.
pub trait BlockHandleInfo {
/// Worker id for this handle.
///
/// Implementations need access to the parent KvBlockManager's ID.
fn worker_id(&self) -> WorkerID;
/// Index of the block set this handle refers to.
fn block_set_idx(&self) -> usize;
/// Index of the block this handle refers to.
// NOTE(review): presumably the index is relative to the block set; confirm.
fn block_idx(&self) -> usize;
}
// `BlockData` answers each query straight from its own fields.
impl<S> BlockHandleInfo for BlockData<S>
where
    S: Storage,
{
    /// Returns the `worker_id` field.
    fn worker_id(&self) -> WorkerID {
        let Self { worker_id, .. } = self;
        *worker_id
    }

    /// Returns the `block_set_idx` field.
    fn block_set_idx(&self) -> usize {
        let Self { block_set_idx, .. } = self;
        *block_set_idx
    }

    /// Returns the `block_idx` field.
    fn block_idx(&self) -> usize {
        let Self { block_idx, .. } = self;
        *block_idx
    }
}
// A `Block` answers each query from the fields of its `data` member.
impl<S, M> BlockHandleInfo for Block<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    /// Returns `data.worker_id`.
    fn worker_id(&self) -> WorkerID {
        let data = &self.data;
        data.worker_id
    }

    /// Returns `data.block_set_idx`.
    fn block_set_idx(&self) -> usize {
        let data = &self.data;
        data.block_set_idx
    }

    /// Returns `data.block_idx`.
    fn block_idx(&self) -> usize {
        let data = &self.data;
        data.block_idx
    }
}
/// A validated, homogeneous, and serializable collection of BlockDescriptors. /// A validated, homogeneous, and serializable collection of BlockDescriptors.
/// Primarily used to describe sets of remote blocks for transfer operations. /// Primarily used to describe sets of remote blocks for transfer operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)]
...@@ -1427,13 +1355,6 @@ pub mod nixl { ...@@ -1427,13 +1355,6 @@ pub mod nixl {
// derived from block_set_idx via the NixlBlockSet on the receiving side. // derived from block_set_idx via the NixlBlockSet on the receiving side.
} }
impl<M> IntoWritableBlocks<M> for BlockDescriptorList
where
    M: BlockMetadata,
{
    type Output = Vec<RemoteBlock<IsMutable>>;

    /// Resolves this descriptor list into mutable remote blocks by asking
    /// `manager` for them; a lookup failure is converted and propagated.
    fn into_writable_blocks(self, manager: &BlockManager<M>) -> BlockResult<Self::Output> {
        let remote_blocks = manager.get_remote_blocks_mutable(&self)?;
        Ok(remote_blocks)
    }
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum BlockDescriptorSetError { pub enum BlockDescriptorSetError {
#[error("Input block list cannot be empty")] #[error("Input block list cannot be empty")]
...@@ -1451,165 +1372,21 @@ pub mod nixl { ...@@ -1451,165 +1372,21 @@ pub mod nixl {
)] )]
InvalidBlockHandle, InvalidBlockHandle,
} }
impl BlockDescriptorList {
    /// Builds a validated list from borrowed block data.
    ///
    /// Every entry must report the same worker id and block-set index as the
    /// first entry. Block indices are recorded in input order.
    ///
    /// # Errors
    ///
    /// * [`BlockDescriptorSetError::EmptyInput`] when `blocks` is empty.
    /// * [`BlockDescriptorSetError::NotHomogeneous`] when an entry disagrees
    ///   with the first one on worker id or block-set index.
    fn new<S: Storage>(
        blocks: &[&BlockData<S>], // Use the generic trait bound
        mutability: BlockMutability,
    ) -> Result<Self, BlockDescriptorSetError> {
        let Some((&head, tail)) = blocks.split_first() else {
            return Err(BlockDescriptorSetError::EmptyInput);
        };

        let worker_id = head.worker_id();
        let block_set_idx = head.block_set_idx();

        let mut block_indices = Vec::with_capacity(blocks.len());
        block_indices.push(head.block_idx());

        for candidate in tail {
            // Validate homogeneity against the first entry.
            let same_origin = candidate.worker_id() == worker_id
                && candidate.block_set_idx() == block_set_idx;
            if !same_origin {
                return Err(BlockDescriptorSetError::NotHomogeneous);
            }
            block_indices.push(candidate.block_idx());
        }

        // TODO: Potentially validate MemType derived from block_set_idx here if possible

        Ok(Self {
            worker_id,
            block_set_idx,
            mutability,
            block_indices,
        })
    }

    /// Creates a `BlockDescriptorList` describing a slice of immutable blocks.
    ///
    /// # Errors
    ///
    /// Returns [`BlockDescriptorSetError::InvalidBlockHandle`] when any handle
    /// no longer holds its inner block, plus any error produced by the
    /// homogeneity validation.
    pub fn from_immutable_blocks<S: Storage, M: BlockMetadata>(
        blocks: &[ImmutableBlock<S, M>],
    ) -> Result<Self, BlockDescriptorSetError> {
        // Resolve each handle to its block data; a handle without an inner
        // block aborts the collection with an error.
        let data = blocks
            .iter()
            .map(|handle| {
                handle
                    .block
                    .block
                    .as_ref()
                    .map(|inner| &inner.data)
                    .ok_or(BlockDescriptorSetError::InvalidBlockHandle)
            })
            .collect::<Result<Vec<&BlockData<S>>, _>>()?;

        Self::new(&data, BlockMutability::Immutable)
    }

    /// Creates a `BlockDescriptorList` describing a slice of mutable blocks.
    ///
    /// # Errors
    ///
    /// Returns [`BlockDescriptorSetError::InvalidBlockHandle`] when any handle
    /// no longer holds its inner block, plus any error produced by the
    /// homogeneity validation.
    pub fn from_mutable_blocks<S: Storage, M: BlockMetadata>(
        blocks: &[MutableBlock<S, M>],
    ) -> Result<Self, BlockDescriptorSetError> {
        // Resolve each handle to its block data; a handle without an inner
        // block aborts the collection with an error.
        let data = blocks
            .iter()
            .map(|handle| {
                handle
                    .block
                    .as_ref()
                    .map(|inner| &inner.data)
                    .ok_or(BlockDescriptorSetError::InvalidBlockHandle)
            })
            .collect::<Result<Vec<&BlockData<S>>, _>>()?;

        Self::new(&data, BlockMutability::Mutable)
    }

    // /// Serializes the BlockDescriptorList into a byte vector.
    // pub fn serialize(&self) -> Result<Vec<u8>, BlockDescriptorSetError> {
    // Ok(serde_json::to_vec(self)?)
    // }
    // /// Deserializes a BlockDescriptorList from a byte slice.
    // pub fn deserialize(data: &[u8]) -> Result<Self, BlockDescriptorSetError> {
    // Ok(serde_json::from_slice(data)?)
    // }
}
/// Conversion from a collection of blocks into a [`BlockDescriptorList`].
pub trait AsBlockDescriptorSet {
/// Element type of the collection being described.
type Block;
/// Builds the descriptor list for `self`, or reports why the collection
/// cannot be described as a [`BlockDescriptorSetError`].
fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError>;
}
impl<S: Storage, M: BlockMetadata> AsBlockDescriptorSet for [ImmutableBlock<S, M>] {
    type Block = ImmutableBlock<S, M>;

    /// Describes the slice via [`BlockDescriptorList::from_immutable_blocks`].
    fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
        let described = BlockDescriptorList::from_immutable_blocks(self);
        described
    }
}
impl<S: Storage, M: BlockMetadata> AsBlockDescriptorSet for [MutableBlock<S, M>] {
    type Block = MutableBlock<S, M>;

    /// Describes the slice via [`BlockDescriptorList::from_mutable_blocks`].
    fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
        let described = BlockDescriptorList::from_mutable_blocks(self);
        described
    }
}
// A `Vec<T>` is described exactly like the slice it derefs to.
impl<T> AsBlockDescriptorSet for Vec<T>
where
    [T]: AsBlockDescriptorSet<Block = T>,
{
    type Block = T;

    /// Forwards to the slice implementation.
    fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
        let elements = self.as_slice();
        elements.as_block_descriptor_set()
    }
}
// A fixed-size array is described exactly like the slice view of its elements.
impl<T, const N: usize> AsBlockDescriptorSet for [T; N]
where
    [T]: AsBlockDescriptorSet<Block = T>,
{
    type Block = T;

    /// Forwards to the slice implementation.
    fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
        let elements = self.as_slice();
        elements.as_block_descriptor_set()
    }
}
}
#[cfg(test)]
pub mod test_utils {
use super::private::PrivateToken;
pub fn get_private_token() -> PrivateToken {
PrivateToken
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use super::nixl::*; use super::super::layout::tests::setup_layout;
use super::super::layout::{
nixl::{NixlLayout, SerializedNixlBlockLayout, ToSerializedNixlBlockLayout},
tests::setup_layout,
FullyContiguous, LayoutConfig,
};
use crate::block_manager::storage::SystemAllocator;
use crate::tokens::TokenBlockSequence;
use dynamo_runtime::logging::init as init_logging; use crate::tokens::{TokenBlockSequence, Tokens};
use nixl_sys::Agent as NixlAgent;
const BLOCK_SIZE: u32 = 4; const BLOCK_SIZE: u32 = 4;
const SALT_HASH: SaltHash = 12345; const SALT_HASH: SaltHash = 12345;
// Helper to create a default reset block // Helper to create a default reset block
fn create_reset_block() -> Block<impl Storage, BasicMetadata> { fn create_reset_block() -> Block<impl Storage, locality::Local, BasicMetadata> {
let layout = setup_layout(None).unwrap(); let layout = setup_layout(None).unwrap();
let data = BlockData::new(Arc::new(layout), 0, 42, 0); let data = BlockData::new(Arc::new(layout), 0, 42, 0);
Block::new(data, BasicMetadata::default()).unwrap() Block::new(data, BasicMetadata::default()).unwrap()
...@@ -1813,170 +1590,177 @@ mod tests { ...@@ -1813,170 +1590,177 @@ mod tests {
); );
} }
#[test] // #[test]
fn test_nixl_block_data_ext() { // fn test_nixl_block_data_ext() {
init_logging(); // init_logging();
let config = LayoutConfig::builder() // let config = LayoutConfig::builder()
.num_blocks(10) // .num_blocks(10)
.num_layers(3) // .num_layers(3)
.outer_dim(2) // .outer_dim(2)
.page_size(4) // .page_size(4)
.inner_dim(13) // .inner_dim(13)
.build() // .build()
.unwrap(); // .unwrap();
let mut layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap(); // let mut layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap();
let agent = NixlAgent::new("test").unwrap(); // let agent = NixlAgent::new("test").unwrap();
tracing::info!("Registering layout"); // tracing::info!("Registering layout");
layout.nixl_register(&agent, None).unwrap(); // layout.nixl_register(&agent, None).unwrap();
tracing::info!("Layout registered"); // tracing::info!("Layout registered");
let serialized = layout.serialize().unwrap(); // let serialized = layout.serialize().unwrap();
let layout = Arc::new(layout); // let layout = Arc::new(layout);
let data = BlockData::new(layout.clone(), 0, 42, 0); // let data = BlockData::new(layout.clone(), 0, 42, 0);
assert_eq!(data.block_idx(), 0); // assert_eq!(data.block_id(), 0);
assert_eq!(data.block_set_idx(), 42); // assert_eq!(data.block_set_id(), 42);
let block_desc = data.as_block_descriptor().unwrap(); // let block_desc = data.as_block_descriptor().unwrap();
println!("Block descriptor: {:?}", block_desc); // println!("Block descriptor: {:?}", block_desc);
let data = BlockData::new(layout.clone(), 1, 42, 0); // let data = BlockData::new(layout.clone(), 1, 42, 0);
assert_eq!(data.block_idx(), 1); // assert_eq!(data.block_id(), 1);
assert_eq!(data.block_set_idx(), 42); // assert_eq!(data.block_set_id(), 42);
let block_desc = data.as_block_descriptor().unwrap(); // let block_desc = data.as_block_descriptor().unwrap();
println!("Block descriptor: {:?}", block_desc); // println!("Block descriptor: {:?}", block_desc);
let remote_layout = SerializedNixlBlockLayout::deserialize(&serialized).unwrap(); // let remote_layout = SerializedNixlBlockLayout::deserialize(&serialized).unwrap();
println!("Nixl layout: {:?}", remote_layout); // println!("Nixl layout: {:?}", remote_layout);
let remote_block = RemoteBlock::<IsMutable>::new(remote_layout.clone(), 0, 42, 0); // let remote_block = RemoteBlock::<IsMutable>::new(remote_layout.clone(), 0, 42, 0);
let remote_desc = remote_block.as_block_descriptor().unwrap(); // let remote_desc = remote_block.as_block_descriptor().unwrap();
println!("Remote Descriptor: {:?}", remote_desc); // println!("Remote Descriptor: {:?}", remote_desc);
// drop(layout); // // drop(layout);
tracing::info!("Layout dropped"); // tracing::info!("Layout dropped");
} // }
#[test] // #[test]
fn test_mutable_block_data_ext() { // fn test_mutable_block_data_ext() {
init_logging(); // init_logging();
// Create a layout with multiple layers and blocks for testing all methods // // Create a layout with multiple layers and blocks for testing all methods
let config = LayoutConfig::builder() // let config = LayoutConfig::builder()
.num_blocks(10) // .num_blocks(10)
.num_layers(2) // .num_layers(2)
.outer_dim(1) // .outer_dim(1)
.page_size(4) // .page_size(4)
.inner_dim(13) // .inner_dim(13)
.build() // .build()
.unwrap(); // .unwrap();
let layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap(); // let layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap();
let layout = Arc::new(layout); // let layout = Arc::new(layout);
// Create a channel for returning blocks // // Create a channel for returning blocks
let (return_tx, _return_rx) = tokio::sync::mpsc::unbounded_channel(); // let (return_tx, _return_rx) = tokio::sync::mpsc::unbounded_channel();
// Create a block and wrap it in a MutableBlock // // Create a block and wrap it in a MutableBlock
let block_data = BlockData::new(layout.clone(), 0, 42, 0); // let block_data = BlockData::new(layout.clone(), 0, 42, 0);
let block = Block::new(block_data, BasicMetadata::default()).unwrap(); // let block = Block::new(block_data.into(), BasicMetadata::default()).unwrap();
let mut mutable_block = MutableBlock::new(block, return_tx.clone()); // let mut mutable_block = MutableBlock::new(block, return_tx.clone());
// Test is_fully_contiguous() // // Test is_fully_contiguous()
assert!(mutable_block.is_fully_contiguous()); // assert!(mutable_block.is_fully_contiguous());
// Test num_layers() // // Test num_layers()
assert_eq!(mutable_block.num_layers(), 2); // assert_eq!(mutable_block.num_layers(), 2);
// Test layer_view() // // Test layer_view()
let layer_view = mutable_block.layer_view(0, 0).unwrap(); // let layer_view = mutable_block.layer_view(0, 0).unwrap();
assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes // assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view.as_ptr() }.is_null()); // assert!(!unsafe { layer_view.as_ptr() }.is_null());
// Test layer_view_mut() // // Test layer_view_mut()
let mut layer_view_mut = mutable_block.layer_view_mut(1, 0).unwrap(); // let mut layer_view_mut = mutable_block.layer_view_mut(1, 0).unwrap();
assert_eq!(layer_view_mut.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes // assert_eq!(layer_view_mut.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view_mut.as_mut_ptr() }.is_null()); // assert!(!unsafe { layer_view_mut.as_mut_ptr() }.is_null());
// Test block_view() // // Test block_view()
let block_view = mutable_block.block_view().unwrap(); // let block_view = mutable_block.block_view().unwrap();
assert_eq!(block_view.size(), 2 * 4 * 13 * 2); // num_layers x page_size x inner_dim x dtype_bytes // assert_eq!(block_view.size(), 2 * 4 * 13 * 2); // num_layers x page_size x inner_dim x dtype_bytes
assert!(!unsafe { block_view.as_ptr() }.is_null()); // assert!(!unsafe { block_view.as_ptr() }.is_null());
// Test block_view_mut() // // Test block_view_mut()
let mut block_view_mut = mutable_block.block_view_mut().unwrap(); // let mut block_view_mut = mutable_block.block_view_mut().unwrap();
assert_eq!(block_view_mut.size(), 2 * 4 * 13 * 2); // num_layers x page_size x inner_dim x dtype_bytes // assert_eq!(block_view_mut.size(), 2 * 4 * 13 * 2); // num_layers x page_size x inner_dim x dtype_bytes
assert!(!unsafe { block_view_mut.as_mut_ptr() }.is_null()); // assert!(!unsafe { block_view_mut.as_mut_ptr() }.is_null());
tracing::info!("MutableBlock BlockDataExt tests completed successfully"); // tracing::info!("MutableBlock BlockDataExt tests completed successfully");
} // }
#[test] // #[test]
fn test_immutable_block_data_ext() { // fn test_immutable_block_data_ext() {
init_logging(); // init_logging();
// Create a layout with multiple layers and blocks for testing all methods // // Create a layout with multiple layers and blocks for testing all methods
let config = LayoutConfig::builder() // let config = LayoutConfig::builder()
.num_blocks(10) // .num_blocks(10)
.num_layers(2) // .num_layers(2)
.outer_dim(1) // .outer_dim(1)
.page_size(4) // .page_size(4)
.inner_dim(13) // .inner_dim(13)
.build() // .build()
.unwrap(); // .unwrap();
let layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap(); // let layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap();
let layout = Arc::new(layout); // let layout = Arc::new(layout);
// Create a channel for returning blocks // // Create a channel for returning blocks
let (return_tx, _return_rx) = tokio::sync::mpsc::unbounded_channel(); // let (return_tx, _return_rx) = tokio::sync::mpsc::unbounded_channel();
// Create a block and wrap it in a MutableBlock // // Create a block and wrap it in a MutableBlock
let block_data = BlockData::new(layout.clone(), 0, 42, 0); // let block_data = BlockData::new(layout.clone(), 0, 42, 0);
let block = Block::new(block_data, BasicMetadata::default()).unwrap(); // let block = Block::new(block_data, BasicMetadata::default()).unwrap();
let mutable_block = MutableBlock::new(block, return_tx.clone()); // let mut mutable_block = MutableBlock::new(block, return_tx.clone());
// Wrap the mutable block in an Arc and create an ImmutableBlock from it // let tbs = TokenBlockSequence::new(Tokens::from(vec![0, 0, 0, 0]), 4, None);
let arc_mutable_block = Arc::new(mutable_block); // let token_block = tbs.blocks().iter().next().unwrap();
let immutable_block = ImmutableBlock::new(arc_mutable_block);
// mutable_block
// Test is_fully_contiguous() // .apply_token_block(token_block.clone())
assert!(immutable_block.is_fully_contiguous()); // .unwrap();
// Test num_layers() // // Wrap the mutable block in an Arc and create an ImmutableBlock from it
assert_eq!(immutable_block.num_layers(), 2); // let arc_mutable_block = Arc::new(mutable_block);
// let immutable_block = ImmutableBlock::new(arc_mutable_block);
// Test layer_view()
let layer_view = immutable_block.layer_view(0, 0).unwrap(); // // Test is_fully_contiguous()
assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes // assert!(immutable_block.is_fully_contiguous());
assert!(!unsafe { layer_view.as_ptr() }.is_null());
// // Test num_layers()
// Test block_view() // assert_eq!(immutable_block.num_layers(), 2);
let block_view = immutable_block.block_view().unwrap();
assert_eq!(block_view.size(), 2 * 4 * 13 * 2); // num_layers x page_size x inner_dim x dtype_bytes // // Test layer_view()
assert!(!unsafe { block_view.as_ptr() }.is_null()); // let layer_view = immutable_block.layer_view(0, 0).unwrap();
// assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
// Test that mutable methods return errors // assert!(!unsafe { layer_view.as_ptr() }.is_null());
let mut mut_immutable_block = immutable_block; // We need a mutable reference for these tests
// // Test block_view()
let layer_view_mut_res = mut_immutable_block.layer_view_mut(0, 0); // let block_view = immutable_block.block_view().unwrap();
assert!(layer_view_mut_res.is_err()); // assert_eq!(block_view.size(), 2 * 4 * 13 * 2); // num_layers x page_size x inner_dim x dtype_bytes
if let Err(BlockError::InvalidState(msg)) = layer_view_mut_res { // assert!(!unsafe { block_view.as_ptr() }.is_null());
assert!(msg.contains("immutable block"));
} else { // // Test that mutable methods return errors
panic!("Expected InvalidState error"); // let mut mut_immutable_block = immutable_block; // We need a mutable reference for these tests
}
// let layer_view_mut_res = mut_immutable_block.layer_view_mut(0, 0);
let block_view_mut_res = mut_immutable_block.block_view_mut(); // assert!(layer_view_mut_res.is_err());
assert!(block_view_mut_res.is_err()); // if let Err(BlockError::InvalidState(msg)) = layer_view_mut_res {
if let Err(BlockError::InvalidState(msg)) = block_view_mut_res { // assert!(msg.contains("immutable block"));
assert!(msg.contains("immutable block")); // } else {
} else { // panic!("Expected InvalidState error");
panic!("Expected InvalidState error"); // }
}
// let block_view_mut_res = mut_immutable_block.block_view_mut();
tracing::info!("ImmutableBlock BlockDataExt tests completed successfully"); // assert!(block_view_mut_res.is_err());
} // if let Err(BlockError::InvalidState(msg)) = block_view_mut_res {
// assert!(msg.contains("immutable block"));
// } else {
// panic!("Expected InvalidState error");
// }
// tracing::info!("ImmutableBlock BlockDataExt tests completed successfully");
// }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
pub mod local;
pub mod logical;
pub mod view;
pub use local::LocalBlockData as BlockData;
/// Common interface over the data that backs a single block.
///
/// Implementors report identity and shape information and, through
/// [`BlockDataExt::is_local`] / [`BlockDataExt::is_local_mut`], decide whether
/// memory views can be handed out. The four view accessors are provided
/// methods layered on top of those two hooks.
pub trait BlockDataExt<S: Storage>: Send + Sync + 'static + std::fmt::Debug {
    /// The index of the block in the block set
    fn block_id(&self) -> BlockId;

    /// The identifier of the block set within the worker
    fn block_set_id(&self) -> usize;

    /// The identifier of the worker that owns the block
    ///
    /// Note: If the block is a logical block, this will be the worker id of the worker
    /// that owns the logical block, not the worker id of the worker that owns the physical
    /// block, because there could be multiple workers contributing to the same logical block.
    fn worker_id(&self) -> WorkerID;

    /// The storage type of the block
    fn storage_type(&self) -> &StorageType;

    /// Whether the block is fully contiguous
    fn is_fully_contiguous(&self) -> bool;

    /// Returns the number of layers in the block
    fn num_layers(&self) -> usize;

    /// The size of the page in the block
    fn page_size(&self) -> usize;

    /// Returns the number of outer dimensions in the block
    fn num_outer_dims(&self) -> usize;

    /// Returns the number of inner dimensions in the block
    fn num_inner_dims(&self) -> usize;

    /// Returns the read-only view provider when one can acquire read-only views
    /// to the block's storage, `None` otherwise
    fn is_local(&self) -> Option<&dyn BlockDataViews<S>>;

    /// Returns the mutable view provider when one can acquire mutable views
    /// to the block's storage, `None` otherwise
    fn is_local_mut(&mut self) -> Option<&mut dyn BlockDataViews<S>>;

    /// Get a read-only view of this block's storage for a layer
    ///
    /// Fails with [`BlockError::ViewsNotAvailableOnLogicalBlocks`] when
    /// [`BlockDataExt::is_local`] yields `None`.
    fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
        let Some(local) = self.is_local() else {
            return Err(BlockError::ViewsNotAvailableOnLogicalBlocks);
        };
        local.local_layer_view(layer_idx, outer_idx)
    }

    /// Get a mutable view of this block's storage for a layer
    ///
    /// Fails with [`BlockError::ViewsNotAvailableOnLogicalBlocks`] when
    /// [`BlockDataExt::is_local_mut`] yields `None`.
    fn layer_view_mut(
        &mut self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<view::LayerViewMut<S>> {
        let Some(local) = self.is_local_mut() else {
            return Err(BlockError::ViewsNotAvailableOnLogicalBlocks);
        };
        local.local_layer_view_mut(layer_idx, outer_idx)
    }

    /// Get a read-only view of this block's storage
    ///
    /// Fails with [`BlockError::ViewsNotAvailableOnLogicalBlocks`] when
    /// [`BlockDataExt::is_local`] yields `None`.
    fn block_view(&self) -> BlockResult<view::BlockView<S>> {
        let Some(local) = self.is_local() else {
            return Err(BlockError::ViewsNotAvailableOnLogicalBlocks);
        };
        local.local_block_view()
    }

    /// Get a mutable view of this block's storage
    ///
    /// Fails with [`BlockError::ViewsNotAvailableOnLogicalBlocks`] when
    /// [`BlockDataExt::is_local_mut`] yields `None`.
    fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
        let Some(local) = self.is_local_mut() else {
            return Err(BlockError::ViewsNotAvailableOnLogicalBlocks);
        };
        local.local_block_view_mut()
    }
}
/// View constructors for block data whose storage can be viewed directly.
///
/// `BlockDataExt::is_local` / `is_local_mut` hand this trait out as a trait
/// object, and the `BlockDataExt` view accessors forward to the methods below.
pub trait BlockDataViews<S: Storage> {
/// Get a read-only view of this block's storage for a layer
///
/// `layer_idx` and `outer_idx` select the region within the block.
fn local_layer_view(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerView<S>>;
/// Get a mutable view of this block's storage for a layer
///
/// `layer_idx` and `outer_idx` select the region within the block.
fn local_layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>>;
/// Get a read-only view of this block's storage
fn local_block_view(&self) -> BlockResult<view::BlockView<S>>;
/// Get a mutable view of this block's storage
fn local_block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>>;
}
/// Gives shared access to the block data held by a block-like type.
pub trait BlockDataProvider: StorageTypeProvider {
/// Locality associated with the block data exposed by this provider.
type Locality: LocalityProvider;
/// Borrow the underlying block data.
fn block_data(&self) -> &impl BlockDataExt<Self::StorageType>;
}
/// Gives exclusive access to the block data held by a block-like type.
pub trait BlockDataProviderMut: BlockDataProvider {
/// Locality associated with the block data exposed by this provider.
// NOTE(review): this is a second associated type named `Locality`, declared
// separately from `BlockDataProvider::Locality`; implementors set both.
// Presumably they are always meant to be the same type - confirm.
type Locality: LocalityProvider;
/// Mutably borrow the underlying block data.
fn block_data_mut(&mut self) -> &mut impl BlockDataExt<Self::StorageType>;
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
/// Individual block storage
///
/// Identifies one block inside a shared layout; the layout is held behind an
/// `Arc`, so this handle stores only the layout reference and identifiers.
#[derive(Debug)]
pub struct LocalBlockData<S: Storage> {
// Layout used to resolve this block's memory regions.
layout: Arc<dyn BlockLayout<StorageType = S>>,
// Index of this block within the layout; reported as the block id.
block_idx: usize,
// Identifier of the block set this block belongs to.
block_set_idx: usize,
// Identifier of the worker that owns the block.
worker_id: WorkerID,
}
impl<S: Storage> Clone for LocalBlockData<S> {
fn clone(&self) -> Self {
Self {
layout: self.layout.clone(),
block_idx: self.block_idx,
block_set_idx: self.block_set_idx,
worker_id: self.worker_id,
}
}
}
impl<S: Storage> LocalBlockData<S> {
    /// Create a new block storage handle.
    ///
    /// * `layout` - shared layout used to resolve the block's memory regions
    /// * `block_idx` - index of the block within the layout
    /// * `block_set_idx` - identifier of the block set the block belongs to
    /// * `worker_id` - identifier of the owning worker
    pub(crate) fn new(
        layout: Arc<dyn BlockLayout<StorageType = S>>,
        block_idx: usize,
        block_set_idx: usize,
        worker_id: WorkerID,
    ) -> Self {
        Self { layout, block_idx, block_set_idx, worker_id }
    }
}
/// Local block data answers identity queries from its own fields and shape
/// queries by delegating to the shared layout. Views are always available.
impl<S: Storage> BlockDataExt<S> for LocalBlockData<S> {
    #[inline(always)]
    fn block_id(&self) -> BlockId {
        self.block_idx
    }

    #[inline(always)]
    fn block_set_id(&self) -> usize {
        self.block_set_idx
    }

    #[inline(always)]
    fn worker_id(&self) -> WorkerID {
        self.worker_id
    }

    #[inline(always)]
    fn storage_type(&self) -> &StorageType {
        self.layout.storage_type()
    }

    /// True when the layout reports itself as `LayoutType::FullyContiguous`.
    fn is_fully_contiguous(&self) -> bool {
        LayoutType::FullyContiguous == self.layout.layout_type()
    }

    fn num_layers(&self) -> usize {
        self.layout.num_layers()
    }

    fn num_outer_dims(&self) -> usize {
        self.layout.outer_dim()
    }

    fn num_inner_dims(&self) -> usize {
        self.layout.inner_dim()
    }

    fn page_size(&self) -> usize {
        self.layout.page_size()
    }

    /// Always `Some`: local data can produce read-only views itself.
    fn is_local(&self) -> Option<&dyn BlockDataViews<S>> {
        let views: &dyn BlockDataViews<S> = self;
        Some(views)
    }

    /// Always `Some`: local data can produce mutable views itself.
    fn is_local_mut(&mut self) -> Option<&mut dyn BlockDataViews<S>> {
        let views: &mut dyn BlockDataViews<S> = self;
        Some(views)
    }
}
/// View construction for locally addressable block data.
///
/// Layer views cover exactly one `(layer, outer)` memory region as reported by
/// the layout. Block views cover the whole block and are only available when
/// the layout is fully contiguous.
impl<S: Storage> BlockDataViews<S> for LocalBlockData<S> {
    /// Read-only view over the memory region at `(layer_idx, outer_idx)`.
    ///
    /// Errors from the layout's `memory_region` lookup are propagated.
    fn local_layer_view(
        &self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<view::LayerView<S>> {
        let mr = self
            .layout
            .memory_region(self.block_idx, layer_idx, outer_idx)?;
        let storage_type = mr.storage_type();
        // SAFETY: the address and size come straight from the layout's own
        // description of this region, and the view borrows `self`, which holds
        // the layout for as long as the view exists.
        unsafe { view::LayerView::new(self, mr.addr(), mr.size(), storage_type) }
    }

    /// Mutable view over the memory region at `(layer_idx, outer_idx)`.
    ///
    /// Errors from the layout's `memory_region` lookup are propagated.
    fn local_layer_view_mut(
        &mut self,
        layer_idx: usize,
        outer_idx: usize,
    ) -> BlockResult<view::LayerViewMut<S>> {
        let mr = self
            .layout
            .memory_region(self.block_idx, layer_idx, outer_idx)?;
        // SAFETY: the address and size come straight from the layout's own
        // description of this region; the view holds the exclusive borrow of
        // `self`, so no other view can be created through this handle.
        unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size(), mr.storage_type()) }
    }

    /// Read-only view over the entire block.
    ///
    /// Returns `BlockError::InvalidState` when the layout is not fully
    /// contiguous.
    fn local_block_view(&self) -> BlockResult<view::BlockView<S>> {
        if self.is_fully_contiguous() {
            // The first region marks the start of the block.
            let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
            let offset = mr.addr();
            // A region is addressed by (layer, outer), so the whole block spans
            // one region per layer *and* per outer dimension. Omitting the
            // outer-dimension factor truncates the view whenever outer_dim > 1.
            let size = mr.size() * self.num_layers() * self.num_outer_dims();
            let storage_type = mr.storage_type();
            // SAFETY: for a fully contiguous layout the regions of a block are
            // laid out back to back starting at `offset`, so `offset..offset+size`
            // stays inside this block; the view borrows `self`, keeping the
            // layout alive.
            unsafe { view::BlockView::new(self, offset, size, storage_type) }
        } else {
            Err(BlockError::InvalidState(
                "Block is not fully contiguous".to_string(),
            ))
        }
    }

    /// Mutable view over the entire block.
    ///
    /// Returns `BlockError::InvalidState` when the layout is not fully
    /// contiguous.
    fn local_block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
        if self.is_fully_contiguous() {
            // The first region marks the start of the block.
            let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
            let offset = mr.addr();
            // See `local_block_view`: one region per layer and per outer dimension.
            let size = mr.size() * self.num_layers() * self.num_outer_dims();
            let storage_type = mr.storage_type();
            // SAFETY: same extent argument as `local_block_view`; the view holds
            // the exclusive borrow of `self`, so no other view can be created
            // through this handle while it is alive.
            unsafe { view::BlockViewMut::new(self, offset, size, storage_type) }
        } else {
            Err(BlockError::InvalidState(
                "Block is not fully contiguous".to_string(),
            ))
        }
    }
}
// The storage type of local block data is its own storage parameter.
impl<S: Storage> StorageTypeProvider for LocalBlockData<S> {
type StorageType = S;
}
// Local block data is its own block-data provider with `Local` locality.
impl<S: Storage> BlockDataProvider for LocalBlockData<S> {
type Locality = locality::Local;
fn block_data(&self) -> &impl BlockDataExt<Self::StorageType> {
self
}
}
// Mutable counterpart of the provider impl above; also returns `self`.
impl<S: Storage> BlockDataProviderMut for LocalBlockData<S> {
type Locality = locality::Local;
fn block_data_mut(&mut self) -> &mut impl BlockDataExt<Self::StorageType> {
self
}
}
// Marker impl: tags local block data as `Local`.
impl<S: Storage> Local for LocalBlockData<S> {}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
pub mod distributed_leader_worker;
pub mod null;
use crate::block_manager::block::{
transfer::{TransferContext, TransferError, WriteToStrategy},
BlockDataProvider, ReadableBlock, WritableBlock,
};
use crate::block_manager::locality::Logical;
use crate::block_manager::storage::{self, nixl::NixlDescriptor};
use tokio::sync::oneshot;
/// Kinds of logical block arrangements.
// NOTE(review): the variants are not used elsewhere in this part of the file;
// their exact semantics should be confirmed against the callers.
pub enum LogicalKinds {
Simple,
Sharded,
}
/// Resources attached to logical blocks that know how to transfer data
/// between blocks of the `Logical<Self>` locality.
pub trait LogicalResources: Clone + Send + Sync + 'static + std::fmt::Debug {
/// Start a transfer from `sources` into `targets`.
///
/// On success returns a oneshot receiver for the transfer notification
/// (presumably resolved when the transfer completes - confirm per
/// implementation); otherwise a `TransferError`.
fn handle_transfer<RB, WB>(
&self,
sources: &[RB],
targets: &mut [WB],
ctx: Arc<TransferContext>,
) -> Result<oneshot::Receiver<()>, TransferError>
where
RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
<RB as StorageTypeProvider>::StorageType: NixlDescriptor,
<WB as StorageTypeProvider>::StorageType: NixlDescriptor,
RB: BlockDataProvider<Locality = Logical<Self>>,
WB: WritableBlock + BlockDataProviderMut<Locality = Logical<Self>>;
}
/// Individual block storage - cannot be cloned to ensure uniqueness
///
/// A logical block carries identifiers and shared resources only; it does not
/// hold a layout and reports no local views.
#[derive(Debug)]
pub struct LogicalBlockData<S: Storage, R: LogicalResources> {
// Identifier of the block within its block set.
block_id: BlockId,
// Identifier of the block set within the worker.
block_set_id: usize,
// Identifier of the worker that owns the logical block.
worker_id: WorkerID,
// Shared resources used to carry out transfers for this block.
resources: Arc<R>,
// Storage type reported by `storage_type()`.
storage_type: StorageType,
// Ties the block to storage type `S` without storing a value of it.
storage: std::marker::PhantomData<S>,
// Page size reported by `page_size()`.
page_size: usize,
}
impl<S: Storage, R: LogicalResources> LogicalBlockData<S, R> {
    /// Build logical block data from its identifiers, the shared resources,
    /// the storage type it should report and its page size.
    pub fn new(
        block_id: BlockId,
        block_set_id: usize,
        worker_id: WorkerID,
        resources: Arc<R>,
        storage_type: StorageType,
        page_size: usize,
    ) -> Self {
        // Zero-sized marker binding the block to storage type `S`.
        let storage = std::marker::PhantomData;
        Self {
            block_id,
            block_set_id,
            worker_id,
            resources,
            storage_type,
            storage,
            page_size,
        }
    }

    /// A new handle to the shared resources of this block.
    pub fn resources(&self) -> Arc<R> {
        Arc::clone(&self.resources)
    }
}
/// Logical blocks answer identity, storage-type and page-size queries from
/// their fields. Layout-shape queries are not implemented and panic, and no
/// local views are available.
impl<S: Storage, R: LogicalResources> BlockDataExt<S> for LogicalBlockData<S, R> {
fn block_id(&self) -> BlockId {
self.block_id
}
fn block_set_id(&self) -> usize {
self.block_set_id
}
fn worker_id(&self) -> WorkerID {
self.worker_id
}
fn storage_type(&self) -> &StorageType {
&self.storage_type
}
/// # Panics
/// Always panics: not implemented for logical blocks.
fn is_fully_contiguous(&self) -> bool {
unimplemented!()
}
/// # Panics
/// Always panics: not implemented for logical blocks.
fn num_layers(&self) -> usize {
unimplemented!()
}
/// Even though the block is logical, we still need to know this for the token block stuff.
fn page_size(&self) -> usize {
self.page_size
}
/// # Panics
/// Always panics: not implemented for logical blocks.
fn num_outer_dims(&self) -> usize {
unimplemented!()
}
/// # Panics
/// Always panics: not implemented for logical blocks.
fn num_inner_dims(&self) -> usize {
unimplemented!()
}
/// Always `None`, so the `BlockDataExt` view accessors return
/// `BlockError::ViewsNotAvailableOnLogicalBlocks`.
fn is_local(&self) -> Option<&dyn BlockDataViews<S>> {
None
}
/// Always `None`, so the `BlockDataExt` mutable view accessors return
/// `BlockError::ViewsNotAvailableOnLogicalBlocks`.
fn is_local_mut(&mut self) -> Option<&mut dyn BlockDataViews<S>> {
None
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::block_manager::distributed::{BlockTransferPool, BlockTransferRequest, KvbmLeader};
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
/// A block transfer request paired with the sender that the worker uses to
/// notify the requester once the transfer has finished.
type TransferRequest = (BlockTransferRequest, oneshot::Sender<()>);
/// Logical resources that forward block transfers to a background worker task
/// through an unbounded channel. Clones share the same channel sender.
#[derive(Clone)]
pub struct DistributedLeaderWorkerResources {
/// Make this an option to make testing easier.
/// `None` means no worker was started and transfers are disabled.
// TODO(jthomson04): We should be using NullResources for this.
transfer_tx: Option<mpsc::UnboundedSender<TransferRequest>>,
}
/// Debug output is just the type name; the channel sender is not shown.
impl std::fmt::Debug for DistributedLeaderWorkerResources {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct(..).finish()` emits only the name, so
        // writing the name directly produces the same text.
        f.write_str("DistributedLeaderWorkerResources")
    }
}
impl DistributedLeaderWorkerResources {
    /// Create the resources.
    ///
    /// With `Some(leader)` a background worker task is started (and detached)
    /// that forwards queued transfer requests to the leader until
    /// `cancel_token` is cancelled. With `None` no task is started and
    /// transfers are disabled.
    ///
    /// # Errors
    /// Returns an error when the worker task cannot be created.
    pub fn new(
        leader: Option<Arc<KvbmLeader>>,
        cancel_token: CancellationToken,
    ) -> anyhow::Result<Self> {
        let Some(leader) = leader else {
            return Ok(Self { transfer_tx: None });
        };

        let (transfer_tx, transfer_rx) = mpsc::unbounded_channel();
        CriticalTaskExecutionHandle::new(
            move |cancel_token| async move {
                Self::worker(leader, transfer_rx, cancel_token).await
            },
            cancel_token,
            "DistributedLeaderWorkerResources",
        )
        .map_err(|e| anyhow::anyhow!("Failed to create DistributedLeaderWorkerResources: {}", e))?
        .detach();

        Ok(Self {
            transfer_tx: Some(transfer_tx),
        })
    }

    /// Map a block's storage type onto the transfer pool it belongs to.
    ///
    /// # Panics
    /// Panics for storage types other than device, pinned and disk.
    fn get_pool<S: Storage>(data: &impl BlockDataExt<S>) -> BlockTransferPool {
        match data.storage_type() {
            StorageType::Device(_) => BlockTransferPool::Device,
            StorageType::Pinned => BlockTransferPool::Host,
            StorageType::Disk(_) => BlockTransferPool::Disk,
            _ => panic!("Invalid storage type"),
        }
    }

    /// Worker loop: receives transfer requests, submits them to the leader and
    /// notifies the requester once the leader reports completion.
    ///
    /// The loop ends when `cancel_token` is cancelled. An error from submitting
    /// a request to the leader ends the loop with that error.
    async fn worker(
        leader: Arc<KvbmLeader>,
        mut transfer_rx: mpsc::UnboundedReceiver<TransferRequest>,
        cancel_token: CancellationToken,
    ) -> anyhow::Result<()> {
        loop {
            tokio::select! {
                Some(request) = transfer_rx.recv() => {
                    let (request, notify_tx) = request;
                    let rx = leader.transfer_blocks_request(request).await?;
                    // Wait for completion off the loop so further requests can
                    // be submitted in the meantime.
                    tokio::spawn(async move {
                        // If the completion signal is lost, do not panic inside
                        // the task: dropping `notify_tx` without sending makes
                        // the requester's receiver resolve with an error, which
                        // is how the failure is reported.
                        if rx.await.is_ok() {
                            // The requester may have stopped waiting; that is fine.
                            let _ = notify_tx.send(());
                        }
                    });
                }
                _ = cancel_token.cancelled() => {
                    break;
                }
            }
        }
        Ok(())
    }
}
impl LogicalResources for DistributedLeaderWorkerResources {
fn handle_transfer<RB, WB>(
&self,
sources: &[RB],
targets: &mut [WB],
// TODO: This transfer context is only ever used in the `Local` locality.
_ctx: Arc<TransferContext>,
) -> Result<oneshot::Receiver<()>, TransferError>
where
RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
<RB as StorageTypeProvider>::StorageType: NixlDescriptor,
<WB as StorageTypeProvider>::StorageType: NixlDescriptor,
RB: BlockDataProvider<Locality = Logical<Self>>,
WB: WritableBlock + BlockDataProviderMut<Locality = Logical<Self>>,
{
if let Some(transfer_tx) = &self.transfer_tx {
let source_pool = Self::get_pool(sources[0].block_data());
let target_pool = Self::get_pool(targets[0].block_data());
let source_idxs = sources.iter().map(|source| source.block_data().block_id());
let target_idxs = targets.iter().map(|target| target.block_data().block_id());
let request = BlockTransferRequest::new(
source_pool,
target_pool,
source_idxs.zip(target_idxs).collect(),
);
let (tx, rx) = oneshot::channel();
transfer_tx.send((request, tx)).unwrap();
Ok(rx)
} else {
panic!("Block transfer functionality is disabled.");
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
/// Placeholder resources for logical blocks that never perform transfers.
#[derive(Debug, Clone)]
pub struct NullResources;
impl LogicalResources for NullResources {
/// Transfers are not supported.
///
/// # Panics
/// Always panics.
fn handle_transfer<RB, WB>(
&self,
_sources: &[RB],
_targets: &mut [WB],
_ctx: Arc<TransferContext>,
) -> Result<oneshot::Receiver<()>, TransferError>
where
RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
<RB as StorageTypeProvider>::StorageType: NixlDescriptor,
<WB as StorageTypeProvider>::StorageType: NixlDescriptor,
RB: BlockDataProvider<Locality = Logical<Self>>,
WB: WritableBlock + BlockDataProviderMut<Locality = Logical<Self>>,
{
panic!("Null resources cannot be used for transfers");
}
}
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
//! and their storage. It handles the relationship between storage, layout, //! and their storage. It handles the relationship between storage, layout,
//! and individual blocks. //! and individual blocks.
use super::{BlockData, BlockError, Storage}; use super::{BlockDataExt, BlockError, Storage};
use crate::block_manager::storage::StorageType;
pub trait Kind: std::marker::Sized + std::fmt::Debug + Clone + Copy + Send + Sync {} pub trait Kind: std::marker::Sized + std::fmt::Debug + Clone + Copy + Send + Sync {}
...@@ -40,9 +41,10 @@ pub type LayerViewMut<'a, S> = MemoryViewMut<'a, S, LayerKind>; ...@@ -40,9 +41,10 @@ pub type LayerViewMut<'a, S> = MemoryViewMut<'a, S, LayerKind>;
/// Storage view that provides safe access to a region of storage /// Storage view that provides safe access to a region of storage
#[derive(Debug)] #[derive(Debug)]
pub struct MemoryView<'a, S: Storage, K: Kind> { pub struct MemoryView<'a, S: Storage, K: Kind> {
_block_data: &'a BlockData<S>, _block_data: &'a dyn BlockDataExt<S>,
addr: usize, addr: usize,
size: usize, size: usize,
storage_type: StorageType,
kind: std::marker::PhantomData<K>, kind: std::marker::PhantomData<K>,
} }
...@@ -58,14 +60,16 @@ where ...@@ -58,14 +60,16 @@ where
/// - addr + size <= storage.size() /// - addr + size <= storage.size()
/// - The view does not outlive the storage /// - The view does not outlive the storage
pub(crate) unsafe fn new( pub(crate) unsafe fn new(
_block_data: &'a BlockData<S>, _block_data: &'a dyn BlockDataExt<S>,
addr: usize, addr: usize,
size: usize, size: usize,
storage_type: StorageType,
) -> Result<Self, BlockError> { ) -> Result<Self, BlockError> {
Ok(Self { Ok(Self {
_block_data, _block_data,
addr, addr,
size, size,
storage_type,
kind: std::marker::PhantomData, kind: std::marker::PhantomData,
}) })
} }
...@@ -89,9 +93,10 @@ where ...@@ -89,9 +93,10 @@ where
/// Mutable storage view that provides exclusive access to a region of storage /// Mutable storage view that provides exclusive access to a region of storage
#[derive(Debug)] #[derive(Debug)]
pub struct MemoryViewMut<'a, S: Storage, K: Kind> { pub struct MemoryViewMut<'a, S: Storage, K: Kind> {
_block_data: &'a mut BlockData<S>, _block_data: &'a mut dyn BlockDataExt<S>,
addr: usize, addr: usize,
size: usize, size: usize,
storage_type: StorageType,
kind: std::marker::PhantomData<K>, kind: std::marker::PhantomData<K>,
} }
...@@ -104,14 +109,16 @@ impl<'a, S: Storage, K: Kind> MemoryViewMut<'a, S, K> { ...@@ -104,14 +109,16 @@ impl<'a, S: Storage, K: Kind> MemoryViewMut<'a, S, K> {
/// - The view does not outlive the storage /// - The view does not outlive the storage
/// - No other views exist for this region /// - No other views exist for this region
pub(crate) unsafe fn new( pub(crate) unsafe fn new(
_block_data: &'a mut BlockData<S>, _block_data: &'a mut dyn BlockDataExt<S>,
addr: usize, addr: usize,
size: usize, size: usize,
storage_type: StorageType,
) -> Result<Self, BlockError> { ) -> Result<Self, BlockError> {
Ok(Self { Ok(Self {
_block_data, _block_data,
addr, addr,
size, size,
storage_type,
kind: std::marker::PhantomData, kind: std::marker::PhantomData,
}) })
} }
...@@ -138,6 +145,7 @@ mod nixl { ...@@ -138,6 +145,7 @@ mod nixl {
use super::super::nixl::*; use super::super::nixl::*;
pub use crate::block_manager::storage::StorageType;
pub use nixl_sys::{MemType, MemoryRegion, NixlDescriptor}; pub use nixl_sys::{MemType, MemoryRegion, NixlDescriptor};
impl<S: Storage, K: Kind> MemoryRegion for MemoryView<'_, S, K> { impl<S: Storage, K: Kind> MemoryRegion for MemoryView<'_, S, K> {
...@@ -156,17 +164,16 @@ mod nixl { ...@@ -156,17 +164,16 @@ mod nixl {
K: Kind, K: Kind,
{ {
fn mem_type(&self) -> MemType { fn mem_type(&self) -> MemType {
self._block_data.layout.storage_type().nixl_mem_type() self._block_data.storage_type().nixl_mem_type()
} }
fn device_id(&self) -> u64 { fn device_id(&self) -> u64 {
self._block_data match self.storage_type {
.layout StorageType::System | StorageType::Pinned => 0,
.storage() StorageType::Device(device_id) => device_id as u64,
.into_iter() StorageType::Disk(fd) => fd,
.next() _ => panic!("Invalid storage type"),
.unwrap() }
.device_id()
} }
} }
...@@ -186,17 +193,16 @@ mod nixl { ...@@ -186,17 +193,16 @@ mod nixl {
K: Kind, K: Kind,
{ {
fn mem_type(&self) -> MemType { fn mem_type(&self) -> MemType {
self._block_data.layout.storage_type().nixl_mem_type() self._block_data.storage_type().nixl_mem_type()
} }
fn device_id(&self) -> u64 { fn device_id(&self) -> u64 {
self._block_data match self.storage_type {
.layout StorageType::System | StorageType::Pinned => 0,
.storage() StorageType::Device(device_id) => device_id as u64,
.into_iter() StorageType::Disk(fd) => fd,
.next() _ => panic!("Invalid storage type"),
.unwrap() }
.device_id()
} }
} }
...@@ -208,10 +214,10 @@ mod nixl { ...@@ -208,10 +214,10 @@ mod nixl {
/// Creates an immutable NIXL memory descriptor from this view. /// Creates an immutable NIXL memory descriptor from this view.
pub fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'a, K, IsImmutable> { pub fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'a, K, IsImmutable> {
NixlMemoryDescriptor::new( NixlMemoryDescriptor::new(
self.addr as u64, // Address from the view self.addr as u64, // Address from the view
self.size(), // Size from the view self.size(), // Size from the view
NixlDescriptor::mem_type(self), // Delegate to self's NixlDescriptor impl self.mem_type(),
NixlDescriptor::device_id(self), // Delegate to self's NixlDescriptor impl self.device_id(),
) )
} }
} }
...@@ -228,8 +234,8 @@ mod nixl { ...@@ -228,8 +234,8 @@ mod nixl {
NixlMemoryDescriptor::new( NixlMemoryDescriptor::new(
self.addr as u64, self.addr as u64,
self.size(), self.size(),
NixlDescriptor::mem_type(self), // Delegate to self's NixlDescriptor impl self.mem_type(),
NixlDescriptor::device_id(self), // Delegate to self's NixlDescriptor impl self.device_id(),
) )
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod local;
pub mod logical;
pub use local::LocalBlockDataFactory;
use crate::block_manager::LayoutConfig;
use super::*;
use derive_getters::Dissolve;
/// Core trait for block factories that can create blocks with specific locality and storage
///
/// A factory produces the locality-specific block data for a block id and, via
/// the provided methods, wraps that data into a [`Block`] with metadata.
pub trait BlockFactory<S: Storage, L: LocalityProvider> {
    /// Create block data for a specific block ID.
    ///
    /// Takes `&self`, so it can be called any number of times.
    fn create_block_data(&self, block_id: BlockId) -> BlockResult<L::BlockData<S>>;

    /// Create a single block carrying `M::default()` as metadata.
    ///
    /// Takes `&self`, so it can be called any number of times.
    fn create_block<M: BlockMetadata + Default>(
        &self,
        block_id: BlockId,
    ) -> BlockResult<Block<S, L, M>> {
        self.create_block_data(block_id)
            .and_then(|data| Block::new(data, M::default()))
    }

    /// Create a single block carrying the supplied metadata.
    ///
    /// Takes `&self`, so it can be called any number of times.
    fn create_block_with_metadata<M: BlockMetadata>(
        &self,
        block_id: BlockId,
        metadata: M,
    ) -> BlockResult<Block<S, L, M>> {
        self.create_block_data(block_id)
            .and_then(|data| Block::new(data, metadata))
    }

    /// Get the number of blocks this factory can create
    fn num_blocks(&self) -> usize;

    /// Get the layout configuration information
    fn layout_config(&self) -> &LayoutConfig;
}
/// Extension trait for factories that can produce all blocks at once
pub trait IntoBlocks<S: Storage, L: LocalityProvider>: BlockFactory<S, L> + Sized {
    /// Consume the factory and create every block, ids `0..num_blocks()`, with
    /// default metadata. Stops at the first block that fails to build.
    fn into_blocks<M: BlockMetadata + Default>(self) -> BlockResult<Vec<Block<S, L, M>>> {
        (0..self.num_blocks())
            .map(|block_idx| self.create_block::<M>(block_idx))
            .collect()
    }

    /// Consume the factory and create every block, ids `0..num_blocks()`, each
    /// with its own clone of `metadata`. Stops at the first block that fails
    /// to build.
    fn into_blocks_with_metadata<M: BlockMetadata + Clone>(
        self,
        metadata: M,
    ) -> BlockResult<Vec<Block<S, L, M>>> {
        (0..self.num_blocks())
            .map(|block_idx| self.create_block_with_metadata(block_idx, metadata.clone()))
            .collect()
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
/// Factory producing local block data for the blocks of a shared layout.
#[derive(Debug, Clone, Dissolve)]
pub struct LocalBlockDataFactory<S: Storage> {
// Layout shared by every block this factory creates.
layout: Arc<dyn BlockLayout<StorageType = S>>,
// Block set identifier stamped on every created block.
block_set_idx: usize,
// Worker identifier stamped on every created block.
worker_id: WorkerID,
}
impl<S: Storage> LocalBlockDataFactory<S> {
/// Create a factory over `layout` for the given block set and worker.
pub fn new(
layout: Arc<dyn BlockLayout<StorageType = S>>,
block_set_idx: usize,
worker_id: WorkerID,
) -> Self {
Self {
layout,
block_set_idx,
worker_id,
}
}
}
impl<S: Storage> BlockFactory<S, locality::Local> for LocalBlockDataFactory<S> {
fn create_block_data(&self, block_idx: BlockId) -> BlockResult<BlockData<S>> {
if block_idx >= self.layout.num_blocks() {
return Err(BlockError::InvalidBlockID(block_idx));
}
let data = BlockData::new(
self.layout.clone(),
block_idx,
self.block_set_idx,
self.worker_id,
);
Ok(data)
}
fn num_blocks(&self) -> usize {
self.layout.num_blocks()
}
fn layout_config(&self) -> &LayoutConfig {
self.layout.config()
}
}
impl<S: Storage> IntoBlocks<S, locality::Local> for LocalBlockDataFactory<S> {}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::block_manager::locality::{Logical, LogicalBlockData, LogicalResources};
/// Factory producing logical block data from a layout configuration and a
/// set of shared resources; no layout is allocated or held.
#[derive(Debug)]
pub struct LogicalBlockFactory<S: Storage, R: LogicalResources> {
// Supplies the block count and page size for created blocks.
layout_config: Arc<LayoutConfig>,
// Block set identifier stamped on every created block.
block_set_idx: usize,
// Worker identifier stamped on every created block.
worker_id: WorkerID,
// Resources shared by every created block.
resources: Arc<R>,
// Storage type every created block reports.
storage_type: StorageType,
// Ties the factory to storage type `S` without storing a value of it.
storage: std::marker::PhantomData<S>,
}
impl<S: Storage, R: LogicalResources> LogicalBlockFactory<S, R> {
/// Create a factory for the given configuration, block set, worker,
/// shared resources and reported storage type.
pub fn new(
layout_config: Arc<LayoutConfig>,
block_set_idx: usize,
worker_id: WorkerID,
resources: Arc<R>,
storage_type: StorageType,
) -> Self {
Self {
layout_config,
block_set_idx,
worker_id,
resources,
storage_type,
storage: std::marker::PhantomData,
}
}
}
impl<S: Storage, R: LogicalResources> BlockFactory<S, Logical<R>> for LogicalBlockFactory<S, R> {
fn create_block_data(&self, block_idx: BlockId) -> BlockResult<LogicalBlockData<S, R>> {
if block_idx >= self.num_blocks() {
return Err(BlockError::InvalidBlockID(block_idx));
}
let data = LogicalBlockData::new(
block_idx,
self.block_set_idx,
self.worker_id,
self.resources.clone(),
self.storage_type,
self.layout_config.page_size,
);
Ok(data)
}
fn num_blocks(&self) -> usize {
self.layout_config.num_blocks
}
fn layout_config(&self) -> &LayoutConfig {
&self.layout_config
}
}
impl<S: Storage, R: LogicalResources> IntoBlocks<S, Logical<R>> for LogicalBlockFactory<S, R> {}
#[cfg(test)]
mod tests {
use crate::block_manager::block::data::logical::null::NullResources;
use crate::block_manager::{ManagedBlockPool, PinnedStorage};
use super::*;
// Arbitrary identifiers; the test checks they are carried onto created blocks.
const TEST_BLOCK_SET_ID: usize = 42;
const TEST_WORKER_ID: WorkerID = 1337;
/// Creates logical block data from a factory, checks its identity fields,
/// then turns the factory into blocks and builds a pool from them.
#[tokio::test]
async fn test_logical_block_factory() {
let layout_config = LayoutConfig::builder()
.num_blocks(10)
.page_size(16)
.num_layers(3)
.outer_dim(2)
.inner_dim(8192)
.dtype_width_bytes(2)
.build()
.unwrap();
let factory = LogicalBlockFactory::<PinnedStorage, NullResources>::new(
Arc::new(layout_config),
TEST_BLOCK_SET_ID,
TEST_WORKER_ID,
Arc::new(NullResources),
StorageType::Pinned,
);
// A single block data object carries the factory's identifiers.
let block_data = factory.create_block_data(0).unwrap();
assert_eq!(block_data.block_id(), 0);
assert_eq!(block_data.block_set_id(), TEST_BLOCK_SET_ID);
assert_eq!(block_data.worker_id(), TEST_WORKER_ID);
assert_eq!(block_data.storage_type(), &StorageType::Pinned);
let _resources = block_data.resources();
// The whole set of blocks can be handed to a pool builder.
let blocks = factory
.into_blocks_with_metadata(BasicMetadata::default())
.unwrap();
ManagedBlockPool::builder().blocks(blocks).build().unwrap();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// todo: move this up one level to be on par with state and block
// locality is primarily focused on the locality of the block data; however,
// the choice of locality permeates the entire block manager.
//
// by moving up a level, it will make more sense to use a kvbm level config object
// and kvbm state resources object to construct a locality aware block factory
//
// note: a block factory is also a block data factory
//
// factories can be turned into pools to implement the block pool and kvbm top-level
// interface; however, it can also be used to directly construct block data objects
// which can be used by leader-driven workers which do not have full block pools.
use super::*;
use crate::block_manager::block::transfer::{
handle_local_transfer, TransferContext, TransferError, WriteToStrategy,
};
use crate::block_manager::storage::{self, nixl::NixlDescriptor};
use std::any::Any;
use tokio::sync::oneshot;
/// Describes where block data lives and how blocks of that locality are
/// transferred.
pub trait LocalityProvider: Send + Sync + 'static + std::fmt::Debug {
// type Disk: BlockDataExt<DiskStorage>;
// type Host: BlockDataExt<PinnedStorage>;
// type Device: BlockDataExt<DeviceStorage>;
/// The block data type this locality uses for storage type `S`.
type BlockData<S: Storage>: BlockDataExt<S>;
/// Start a transfer from `_sources` into `_targets`, where both sides have
/// this locality.
///
/// On success returns a oneshot receiver for the transfer notification;
/// otherwise a `TransferError`.
fn handle_transfer<RB, WB>(
_sources: &[RB],
_targets: &mut [WB],
_ctx: Arc<TransferContext>,
) -> Result<oneshot::Receiver<()>, TransferError>
where
RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
<RB as StorageTypeProvider>::StorageType: NixlDescriptor,
<WB as StorageTypeProvider>::StorageType: NixlDescriptor,
RB: BlockDataProvider<Locality = Self>,
WB: WritableBlock + BlockDataProviderMut<Locality = Self>;
}
/// Local locality provider for direct memory access
#[derive(Debug)]
pub struct Local;
impl LocalityProvider for Local {
// Local blocks use `BlockData` (the local block data type) directly.
type BlockData<S: Storage> = BlockData<S>;
/// Delegates the transfer to `handle_local_transfer`.
fn handle_transfer<RB, WB>(
sources: &[RB],
targets: &mut [WB],
ctx: Arc<TransferContext>,
) -> Result<oneshot::Receiver<()>, TransferError>
where
RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
<RB as StorageTypeProvider>::StorageType: NixlDescriptor,
<WB as StorageTypeProvider>::StorageType: NixlDescriptor,
RB: BlockDataProvider<Locality = Self>,
WB: WritableBlock + BlockDataProviderMut<Locality = Self>,
{
handle_local_transfer(sources, targets, ctx)
}
}
pub use crate::block_manager::block::data::logical::{LogicalBlockData, LogicalResources};
/// General logical locality for future RPC-based transfers
///
/// Parameterised by the resources type `R` that carries out transfers; the
/// struct itself stores no data.
#[derive(Debug)]
pub struct Logical<R: LogicalResources> {
// Marker only: ties the locality to `R` without holding a value of it.
_resources: std::marker::PhantomData<R>,
}
impl<R: LogicalResources> Logical<R> {
    // TODO(jthomson04): Refactor these???
    /// Collect the shared resources handle of every block in `blocks`.
    ///
    /// # Panics
    /// Panics if a block's data is not a `LogicalBlockData` for this storage
    /// and resources type.
    fn load_resources<B: BlockDataProvider<Locality = Logical<R>>>(blocks: &[B]) -> Vec<Arc<R>> {
        let mut handles = Vec::with_capacity(blocks.len());
        for block in blocks {
            let erased: &dyn Any = block.block_data();
            // TODO: Downcasting and unwrapping like this is atrocious...
            let logical = erased
                .downcast_ref::<LogicalBlockData<<B as StorageTypeProvider>::StorageType, R>>()
                .unwrap();
            handles.push(logical.resources());
        }
        handles
    }

    /// Same as `load_resources`, but reads the blocks through their mutable
    /// data accessor.
    ///
    /// # Panics
    /// Panics if a block's data is not a `LogicalBlockData` for this storage
    /// and resources type.
    fn load_resources_mut<B: BlockDataProviderMut<Locality = Logical<R>>>(
        blocks: &mut [B],
    ) -> Vec<Arc<R>> {
        let mut handles = Vec::with_capacity(blocks.len());
        for block in blocks.iter_mut() {
            let erased: &mut dyn Any = block.block_data_mut();
            let logical = erased
                .downcast_mut::<LogicalBlockData<<B as StorageTypeProvider>::StorageType, R>>()
                .unwrap();
            handles.push(logical.resources());
        }
        handles
    }
}
impl<R: LogicalResources> LocalityProvider for Logical<R> {
    /// Logical blocks carry `LogicalBlockData`, which exposes the `R` handle.
    type BlockData<S: Storage> = LogicalBlockData<S, R>;

    /// Starts a transfer by dispatching to the resources object shared by all
    /// participating blocks.
    ///
    /// The resources handle of every source and target block is collected, and
    /// all of them must point to the same allocation (`Arc::ptr_eq`). The
    /// transfer is then forwarded to that object's `handle_transfer`.
    ///
    /// # Errors
    ///
    /// Returns an error if
    /// * `sources` and `targets` are both empty, so there is no resources
    ///   object to dispatch to, or
    /// * the blocks do not all share the same resources object,
    ///
    /// and propagates whatever the resources object's `handle_transfer` returns.
    fn handle_transfer<RB, WB>(
        sources: &[RB],
        targets: &mut [WB],
        ctx: Arc<TransferContext>,
    ) -> Result<oneshot::Receiver<()>, TransferError>
    where
        RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
        <RB as StorageTypeProvider>::StorageType: NixlDescriptor,
        <WB as StorageTypeProvider>::StorageType: NixlDescriptor,
        RB: BlockDataProvider<Locality = Self>,
        WB: WritableBlock + BlockDataProviderMut<Locality = Self>,
    {
        let source_resources = Self::load_resources(sources);
        let target_resources = Self::load_resources_mut(targets);

        let mut all_resources = source_resources.into_iter().chain(target_resources);

        // Without at least one block there is nothing to dispatch to. Indexing
        // the first element unconditionally would panic in that case.
        let Some(common_resource) = all_resources.next() else {
            return Err(anyhow::anyhow!("Cannot transfer an empty set of blocks!").into());
        };

        // For now, require that all resources between the source and target are the same
        if !all_resources.all(|resource| Arc::ptr_eq(&resource, &common_resource)) {
            return Err(anyhow::anyhow!("Resources used in a transfer must be the same!").into());
        }

        common_resource.handle_transfer(sources, targets, ctx)
    }
}
...@@ -93,6 +93,9 @@ impl BlockState { ...@@ -93,6 +93,9 @@ impl BlockState {
} }
} }
/// Apply an entry [TokenBlock] to the block.
/// The block must be in the reset state on entry. The block will transition to
/// the completed state after this call.
pub fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> { pub fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
match self { match self {
BlockState::Reset => { BlockState::Reset => {
......
...@@ -19,7 +19,6 @@ mod memcpy; ...@@ -19,7 +19,6 @@ mod memcpy;
mod nixl; mod nixl;
mod strategy; mod strategy;
use super::nixl::{IsMutable, NixlBlockDataImmutable, NixlBlockDataMutable, RemoteBlock};
use super::*; use super::*;
use crate::block_manager::storage::{ use crate::block_manager::storage::{
...@@ -29,6 +28,7 @@ use crate::block_manager::storage::{ ...@@ -29,6 +28,7 @@ use crate::block_manager::storage::{
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use nixl_sys::NixlDescriptor;
use nixl_sys::XferOp::{Read, Write}; use nixl_sys::XferOp::{Read, Write};
use std::ops::Range; use std::ops::Range;
use tokio::sync::oneshot; use tokio::sync::oneshot;
...@@ -125,20 +125,21 @@ pub trait ReadFromStrategy<Source> { ...@@ -125,20 +125,21 @@ pub trait ReadFromStrategy<Source> {
impl<RB: ReadableBlock, WB: WritableBlock> WriteToStrategy<WB> for RB impl<RB: ReadableBlock, WB: WritableBlock> WriteToStrategy<WB> for RB
where where
<RB as ReadableBlock>::StorageType: Local + WriteToStrategy<<WB as WritableBlock>::StorageType>, <RB as StorageTypeProvider>::StorageType:
Local + WriteToStrategy<<WB as StorageTypeProvider>::StorageType>,
{ {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
<<RB as ReadableBlock>::StorageType as WriteToStrategy< <<RB as StorageTypeProvider>::StorageType as WriteToStrategy<
<WB as WritableBlock>::StorageType, <WB as StorageTypeProvider>::StorageType,
>>::write_to_strategy() >>::write_to_strategy()
} }
} }
impl<WB: WritableBlock, RB: ReadableBlock> ReadFromStrategy<RB> for WB impl<WB: WritableBlock, RB: ReadableBlock> ReadFromStrategy<RB> for WB
where where
<RB as ReadableBlock>::StorageType: Remote, <RB as StorageTypeProvider>::StorageType: Remote,
<WB as WritableBlock>::StorageType: NixlRegisterableStorage, <WB as StorageTypeProvider>::StorageType: NixlRegisterableStorage,
{ {
#[inline(always)] #[inline(always)]
fn read_from_strategy() -> TransferStrategy { fn read_from_strategy() -> TransferStrategy {
...@@ -146,478 +147,81 @@ where ...@@ -146,478 +147,81 @@ where
} }
} }
/// Copies `sources` into `targets` using the strategy selected by
/// `RB::write_to_strategy()` and returns a receiver that is signalled when the
/// transfer is finished.
///
/// * `Memcpy`: every pair is copied synchronously before returning, so the
///   receiver is already signalled when it is handed back.
/// * `CudaAsyncH2D` / `CudaAsyncD2H` / `CudaAsyncD2D`: every pair is submitted
///   on `ctx.stream()` and the sender is handed to `ctx.cuda_event`.
///   NOTE(review): presumably signalled once the stream reaches that event; confirm.
/// * `Nixl`: the transfer future is spawned on `ctx.async_rt_handle()` and the
///   receiver is signalled after it resolves.
///
/// For the memcpy and CUDA paths blocks are paired positionally with `zip`, so
/// if the slices differ in length the surplus blocks are silently skipped.
///
/// # Errors
///
/// Propagates any error from the per-block copy routines, `ctx.cuda_event`, or
/// `nixl::write_blocks_to`, and returns `TransferError::IncompatibleTypes` for
/// any other strategy.
pub fn handle_local_transfer<RB, WB>(
    sources: &[RB],
    targets: &mut [WB],
    ctx: Arc<TransferContext>,
) -> Result<oneshot::Receiver<()>, TransferError>
where
    RB: ReadableBlock + WriteToStrategy<WB> + Local,
    WB: WritableBlock,
    <RB as StorageTypeProvider>::StorageType: NixlDescriptor,
    <WB as StorageTypeProvider>::StorageType: NixlDescriptor,
{
    let (tx, rx) = oneshot::channel();

    match RB::write_to_strategy() {
        TransferStrategy::Memcpy => {
            for (src, dst) in sources.iter().zip(targets.iter_mut()) {
                // TODO: Unlike all other transfer strategies, this is fully blocking.
                // We probably want some sort of thread pool to handle these.
                memcpy::copy_block(src, dst)?;
            }

            // `rx` is still held by this function, so the send cannot fail here.
            let _ = tx.send(());
            Ok(rx)
        }
        TransferStrategy::CudaAsyncH2D
        | TransferStrategy::CudaAsyncD2H
        | TransferStrategy::CudaAsyncD2D => {
            for (src, dst) in sources.iter().zip(targets.iter_mut()) {
                cuda::copy_block(src, dst, ctx.stream().as_ref(), RB::write_to_strategy())?;
            }

            ctx.cuda_event(tx)?;
            Ok(rx)
        }
        TransferStrategy::Nixl(transfer_type) => {
            let transfer_fut = nixl::write_blocks_to(sources, targets, &ctx, transfer_type)?;

            ctx.async_rt_handle().spawn(async move {
                transfer_fut.await;
                // The caller may have dropped the receiver because it does not
                // care about completion; that must not panic the spawned task.
                let _ = tx.send(());
            });
            Ok(rx)
        }
        _ => Err(TransferError::IncompatibleTypes(format!(
            "Unsupported copy strategy: {:?}",
            RB::write_to_strategy()
        ))),
    }
}
pub trait WriteTo<Target> { pub trait WriteTo<Target> {
fn write_to( fn write_to(
&self, &self,
dst: &mut Vec<Target>, dst: &mut Vec<Target>,
notify: bool,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<Option<oneshot::Receiver<()>>, TransferError>; ) -> Result<oneshot::Receiver<()>, TransferError>;
} }
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>> impl<RB, WB, L: LocalityProvider> WriteTo<WB> for Vec<RB>
where where
RB: WriteToStrategy<WB> + Local, RB: ReadableBlock + WriteToStrategy<WB> + Local,
<RB as StorageTypeProvider>::StorageType: NixlDescriptor,
<WB as StorageTypeProvider>::StorageType: NixlDescriptor,
RB: BlockDataProvider<Locality = L>,
WB: WritableBlock + BlockDataProviderMut<Locality = L>,
{ {
fn write_to( fn write_to(
&self, &self,
dst: &mut Vec<WB>, dst: &mut Vec<WB>,
notify: bool,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<Option<oneshot::Receiver<()>>, TransferError> { ) -> Result<oneshot::Receiver<()>, TransferError> {
let (tx, rx) = oneshot::channel(); L::handle_transfer(self, dst, ctx)
match RB::write_to_strategy() {
TransferStrategy::Memcpy => {
for (src, dst) in self.iter().zip(dst.iter_mut()) {
// TODO: Unlike all other transfer strategies, this is fully blocking.
// We probably want some sort of thread pool to handle these.
memcpy::copy_block(src.as_ref(), dst)?;
}
if notify {
tx.send(()).unwrap();
Ok(Some(rx))
} else {
Ok(None)
}
}
TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H
| TransferStrategy::CudaAsyncD2D => {
for (src, dst) in self.iter().zip(dst.iter_mut()) {
cuda::copy_block(
src.as_ref(),
dst,
ctx.stream().as_ref(),
RB::write_to_strategy(),
)?;
}
if notify {
let (tx, rx) = oneshot::channel();
ctx.cuda_event(tx)?;
Ok(Some(rx))
} else {
Ok(None)
}
}
TransferStrategy::Nixl(transfer_type) => {
let transfer_fut = nixl::write_blocks_to(self, dst, &ctx, transfer_type)?;
if notify {
ctx.async_rt_handle().spawn(async move {
transfer_fut.await;
tx.send(()).unwrap();
});
Ok(Some(rx))
} else {
Ok(None)
}
}
_ => Err(TransferError::IncompatibleTypes(format!(
"Unsupported copy strategy: {:?}",
RB::write_to_strategy()
))),
}
} }
} }
/// Builder skeleton for a transfer whose target blocks must be `Local`.
///
/// Only the optional source/target slices are held; both fields are
/// underscore-prefixed and the builder methods that would populate them are
/// currently commented out.
/// NOTE(review): the name suggests a GET-style (pull into local blocks)
/// request; confirm intended semantics.
#[derive(Default)]
pub struct GetXferRequestBuilder<
'xfer,
Source: BlockDataProvider,
Target: BlockDataProviderMut + Local,
> {
/// Source blocks, if set.
_src: Option<&'xfer [Source]>,
/// Target blocks, if set.
_dst: Option<&'xfer [Target]>,
}
// impl<'xfer, Source: BlockDataProvider, Target: BlockDataProviderMut + Local>
// GetXferRequestBuilder<'xfer, Source, Target>
// {
// fn new(state: Arc<BlockTransferEngineState>) -> Self {
// Self {
// src: None,
// dst: None,
// }
// }
// pub fn from(&mut self, local_or_remote_blocks: &'xfer [Target]) -> &mut Self {
// self.dst = Some(local_or_remote_blocks);
// self
// }
// pub fn to(&mut self, local_mutable_blocks: &'xfer [Source]) -> &mut Self {
// self.src = Some(local_mutable_blocks);
// self
// }
// }
/// Builder skeleton for a transfer whose source blocks must be `Local`.
///
/// Only the optional source/target slices are held; both fields are
/// underscore-prefixed and the builder methods that would populate them are
/// currently commented out.
/// NOTE(review): the name suggests a PUT-style (push from local blocks)
/// request; confirm intended semantics.
pub struct PutXferRequestBuilder<
'xfer,
Source: BlockDataProvider + Local,
Target: BlockDataProviderMut,
> {
/// Source blocks, if set.
_src: Option<&'xfer [Source]>,
/// Target blocks, if set.
_dst: Option<&'xfer [Target]>,
}
// impl<'xfer, Source: BlockDataProvider + Local, Target: BlockDataProviderMut>
// PutXferRequestBuilder<'xfer, Source, Target>
// {
// fn new(state: Arc<BlockTransferEngineState>) -> Self {
// Self {
// src: None,
// dst: None,
// }
// }
// pub fn from(&mut self, local_blocks: &'xfer [Source]) -> &mut Self {
// self.src = Some(local_blocks);
// self
// }
// pub fn to(&mut self, local_or_remote: &'xfer [Target]) -> &mut Self {
// self.dst = Some(local_or_remote);
// self
// }
// }
// #[async_trait]
// impl<'xfer, Target: BlockDataProviderMut + Local>
// AsyncBlockTransferEngine<RemoteBlock<IsImmutable>, Target>
// for GetXferRequestBuilder<'xfer, RemoteBlock<IsImmutable>, Target>
// where
// Target: BlockDataProviderMut + Local + Send + Sync,
// {
// async fn execute(self) -> Result<()> {
// unimplemented!()
// }
// }
// #[async_trait]
// impl<'xfer, Source, Target> AsyncBlockTransferEngine<Source, Target>
// for GetXferRequestBuilder<'xfer, Source, Target>
// where
// Source: BlockDataProvider + Local + Send + Sync,
// Target: BlockDataProviderMut + Local + Send + Sync,
// {
// async fn execute(self) -> Result<()> {
// unimplemented!()
// }
// }
// pub trait BlockCopyTo<Target:BlockDataProviderMut + Local>: BlockDataProvider + Local {
// fn copy_blocks
/// Asynchronous transfer engine from `Source` blocks into `Local` `Target`
/// blocks.
#[async_trait]
pub trait AsyncBlockTransferEngine<Source: BlockDataProvider, Target: BlockDataProviderMut + Local>
{
/// Consumes the engine and runs the transfer, returning an `anyhow` error
/// on failure.
async fn execute(self) -> anyhow::Result<()>;
}
/// Synchronous transfer engine from `Source` blocks into `Target` blocks.
pub trait BlockTransferEngineV1<Source: BlockDataProvider, Target: BlockDataProviderMut> {
/// Optional preparation hook; the default implementation does nothing and
/// returns `Ok(())`.
fn prepare(&mut self) -> Result<(), TransferError> {
Ok(())
}
/// Consumes the engine and performs the transfer.
fn execute(self) -> Result<(), TransferError>;
}
// memcpy transfer engine
// - System -> System
// - Pinned -> Pinned
// cuda memcpy transfer engine
// - Pinned -> Device
// - Device -> Pinned
// - Device -> Device
// nixl memcpy transfer engine
// - NixlRegisterableStorage -> Nixl
// - Nixl -> NixlRegisterableStorage
// where System, Pinned, Device are NixlRegisterableStorage
// Placeholder for the actual transfer plan
/// A PUT request pairing a slice of `Local` source blocks with a slice of
/// destination blocks that will be written to.
#[derive(Debug)]
pub struct TransferRequestPut<
'a,
Source: BlockDataProvider + Local,
Destination: BlockDataProviderMut,
> {
/// Blocks to read from.
sources: &'a [Source],
/// Blocks to write into; borrowed mutably for the lifetime of the request.
destinations: &'a mut [Destination],
}
// --- NIXL PUT Transfer Implementation ---
impl<Source> BlockTransferEngineV1<Source, RemoteBlock<IsMutable>>
    for TransferRequestPut<'_, Source, RemoteBlock<IsMutable>>
where
    Source: BlockDataProvider + Local, // + NixlBlockDataMutable<Source::StorageType>,
    Source::StorageType: NixlRegisterableStorage,
{
    /// Walks the source/destination pairs of a PUT towards remote blocks.
    ///
    /// The request counts are re-validated, then a block descriptor is built
    /// for each side of every pair and logged at trace level. The actual NIXL
    /// PUT is still a TODO, so no data is moved.
    ///
    /// # Errors
    ///
    /// Fails if the counts are invalid or a block descriptor cannot be built.
    fn execute(self) -> Result<(), TransferError> {
        self.validate_counts()?;
        tracing::info!("Executing NIXL PUT transfer request");

        // TODO: Get NixlAgent handle

        let Self {
            sources,
            destinations,
        } = self;

        for (source, destination) in sources.iter().zip(destinations.iter_mut()) {
            let source_data = source.block_data(private::PrivateToken);
            let source_descriptor = source_data.as_block_descriptor()?;

            let destination_data = destination.block_data_mut(private::PrivateToken);
            let destination_descriptor = destination_data.as_block_descriptor_mut()?;

            // TODO: Perform NIXL PUT operation
            // tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "NIXL PUT block");
            tracing::trace!(src_desc = ?source_descriptor, dst_desc = ?destination_descriptor, "NIXL PUT block");
        }

        Ok(())
    }
}
impl<'a, Source, Destination> TransferRequestPut<'a, Source, Destination>
where
    Source: BlockDataProvider + Local,
    Destination: BlockDataProviderMut,
{
    /// Creates a PUT request over `sources` and `destinations`.
    ///
    /// # Errors
    ///
    /// Returns `TransferError::CountMismatch` if the slices differ in length
    /// and `TransferError::BuilderError` if they are empty.
    pub fn new(
        sources: &'a [Source],
        destinations: &'a mut [Destination],
    ) -> Result<Self, TransferError> {
        let transfer_request = Self {
            sources,
            destinations,
        };
        transfer_request.validate_counts()?;
        Ok(transfer_request)
    }

    /// Validate blocks
    ///
    /// For a put, we can have duplicate blocks on the source side, but all destinations must be unique.
    /// For all transfers, the source and destination block sets must be disjoint.
    ///
    /// Blocks are identified by `(block_set_idx, block_idx, worker_id)`.
    ///
    /// # Errors
    ///
    /// Returns `TransferError::BuilderError` if a destination block is listed
    /// more than once or if a block appears on both sides.
    pub fn validate_blocks(&self) -> Result<(), TransferError> {
        let pair_count = self.sources.len().min(self.destinations.len());
        let mut src_set = std::collections::HashSet::with_capacity(pair_count);
        let mut dst_set = std::collections::HashSet::with_capacity(pair_count);

        for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter()) {
            let src_data = src_block.block_data(private::PrivateToken);
            let dst_data = dst_block.block_data(private::PrivateToken);

            src_set.insert((
                src_data.block_set_idx,
                src_data.block_idx,
                src_data.worker_id,
            ));
            dst_set.insert((
                dst_data.block_set_idx,
                dst_data.block_idx,
                dst_data.worker_id,
            ));
        }

        // A set smaller than the list means at least one destination repeated.
        if dst_set.len() != self.destinations.len() {
            return Err(TransferError::BuilderError(
                "Duplicate destination blocks".to_string(),
            ));
        }

        // the intersection of src_set and dst_set must be empty
        if !src_set.is_disjoint(&dst_set) {
            return Err(TransferError::BuilderError(
                "One or more blocks appear in both the source and destination lists".to_string(),
            ));
        }

        Ok(())
    }

    /// Common validation for all PUT requests: both slices must have the same,
    /// non-zero length.
    fn validate_counts(&self) -> Result<(), TransferError> {
        let source_count = self.sources.len();
        let destination_count = self.destinations.len();

        if source_count != destination_count {
            return Err(TransferError::CountMismatch(
                source_count,
                destination_count,
            ));
        }

        // The counts are equal here, so checking one side covers both.
        if source_count == 0 {
            return Err(TransferError::BuilderError(
                "Sources cannot be empty".to_string(),
            ));
        }

        Ok(())
    }
}
// // --- Local Transfer Implementations ---
// // Local Pinned -> Pinned
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<PinnedStorage, MSource>,
// MutableBlock<PinnedStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Pinned -> Pinned");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy H2H or std::ptr::copy
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Pinned -> Device
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<PinnedStorage, MSource>,
// MutableBlock<DeviceStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Pinned -> Device");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy H2D
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Device -> Pinned
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<DeviceStorage, MSource>,
// MutableBlock<PinnedStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Device -> Pinned");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy D2H
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Device -> Device
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<DeviceStorage, MSource>,
// MutableBlock<DeviceStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Device -> Device");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy D2D
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// pub fn dispatch_copy_to<RB, WB>(
// src: &RB,
// dst: &mut WB,
// ctx: &TransferContext,
// ) -> Result<(), TransferError>
// where
// RB: ReadableBlock,
// WB: WritableBlock,
// // Ensure the necessary capability traits are implemented for the storage types
// // Note: These bounds aren't strictly *required* for the TypeId check,
// // but help ensure the backend calls will compile if a match occurs.
// // RB::Storage: SystemAccessible + CudaAccessible, // Might be too restrictive, apply within match arms
// // WB::Storage: SystemAccessible + CudaAccessible,
// {
// let src_type = src.storage_type_id();
// let dst_type = dst.storage_type_id();
// match (src_type, dst_type) {
// // === Memcpy Cases ===
// (s, d)
// if (s == TypeId::of::<SystemStorage>() && d == TypeId::of::<SystemStorage>())
// || (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<SystemStorage>())
// || (s == TypeId::of::<SystemStorage>() && d == TypeId::of::<PinnedStorage>())
// || (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<PinnedStorage>()) =>
// {
// memcpy::memcpy_block(src, dst)
// }
// // === CUDA Cases ===
// (s, d)
// if (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<DeviceStorage>())
// || (s == TypeId::of::<DeviceStorage>() && d == TypeId::of::<PinnedStorage>())
// || (s == TypeId::of::<DeviceStorage>() && d == TypeId::of::<DeviceStorage>()) =>
// {
// cuda::cuda_memcpy_block(src, dst, ctx.stream().as_ref())
// // let stream = stream.ok_or_else(|| {
// // TransferError::BuilderError("CUDA stream required for this transfer".into())
// // })?;
// // if is_cuda_compatible::<RB, WB>() {
// // tracing::debug!("Dispatching copy using CUDA");
// // cuda::cuda_memcpy_block(src_provider, dst_provider, stream) // Assumes cuda_memcpy_block is generic
// // } else {
// // Err(TransferError::IncompatibleTypes(
// // "CUDA copy requires CudaAccessible storage".into(),
// // ))
// // }
// }
// // === NIXL Cases ===
// (s, d)
// if d == TypeId::of::<NixlStorage>()
// && (s == TypeId::of::<SystemStorage>()
// || s == TypeId::of::<PinnedStorage>()
// || s == TypeId::of::<DeviceStorage>()) =>
// {
// unimplemented!()
// // tracing::debug!("Dispatching copy using NIXL PUT");
// // // TODO: Implement NIXL PUT logic
// // // You might need a specific NIXL transfer function here.
// // // Example: nixl::nixl_put_block(src_provider, dst_provider)
// // Err(TransferError::ExecutionError(
// // "NIXL PUT not yet implemented".into(),
// // ))
// }
// // TODO: Add NIXL GET cases (Nixl -> System/Pinned/Device)
// // === Error Case ===
// _ => Err(TransferError::IncompatibleTypes(format!(
// "Unsupported storage combination for copy: {:?} -> {:?}",
// std::any::type_name::<<RB as ReadableBlock>::StorageType>(), // Requires nightly or use debug print
// std::any::type_name::<<WB as WritableBlock>::StorageType>()
// ))),
// }
// }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -50,8 +50,8 @@ where ...@@ -50,8 +50,8 @@ where
Source: BlockDataProvider, Source: BlockDataProvider,
Destination: BlockDataProviderMut, Destination: BlockDataProviderMut,
{ {
let src_data = sources.block_data(private::PrivateToken); let src_data = sources.block_data();
let dst_data = destinations.block_data_mut(private::PrivateToken); let dst_data = destinations.block_data_mut();
let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?; let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
...@@ -100,8 +100,8 @@ where ...@@ -100,8 +100,8 @@ where
Source: BlockDataProvider, Source: BlockDataProvider,
Destination: BlockDataProviderMut, Destination: BlockDataProviderMut,
{ {
let src_data = sources.block_data(private::PrivateToken); let src_data = sources.block_data();
let dst_data = destinations.block_data_mut(private::PrivateToken); let dst_data = destinations.block_data_mut();
let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?; let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment