Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
80256acf
Unverified
Commit
80256acf
authored
May 20, 2025
by
Ryan Olson
Committed by
GitHub
May 20, 2025
Browse files
feat: adding outer dimension to isolate k/v blocks (#1126)
parent
7e452a2e
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
325 additions
and
191 deletions
+325
-191
lib/bindings/python/rust/llm/block_manager.rs
lib/bindings/python/rust/llm/block_manager.rs
+3
-1
lib/bindings/python/tests/test_block_manager.py
lib/bindings/python/tests/test_block_manager.py
+36
-51
lib/llm/Cargo.toml
lib/llm/Cargo.toml
+3
-0
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+3
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+86
-33
lib/llm/src/block_manager/block/transfer/cuda.rs
lib/llm/src/block_manager/block/transfer/cuda.rs
+14
-12
lib/llm/src/block_manager/block/transfer/memcpy.rs
lib/llm/src/block_manager/block/transfer/memcpy.rs
+7
-5
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+22
-20
lib/llm/src/block_manager/config.rs
lib/llm/src/block_manager/config.rs
+3
-0
lib/llm/src/block_manager/layout.rs
lib/llm/src/block_manager/layout.rs
+144
-69
lib/llm/src/block_manager/layout/nixl.rs
lib/llm/src/block_manager/layout/nixl.rs
+1
-0
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+1
-0
lib/llm/src/block_manager/pool/inactive.rs
lib/llm/src/block_manager/pool/inactive.rs
+1
-0
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+1
-0
No files found.
lib/bindings/python/rust/llm/block_manager.rs
View file @
80256acf
...
...
@@ -46,10 +46,11 @@ pub struct BlockManager {
#[pymethods]
impl
BlockManager
{
#[new]
#[pyo3(signature
=
(worker_id,
num_layer,
page_size,
inner_dim,
dtype=None,
host_num_blocks=None,
device_num_blocks=None,
device_id=
0
))]
#[pyo3(signature
=
(worker_id,
num_layer,
outer_dim,
page_size,
inner_dim,
dtype=None,
host_num_blocks=None,
device_num_blocks=None,
device_id=
0
))]
fn
new
(
worker_id
:
u64
,
num_layer
:
usize
,
outer_dim
:
usize
,
page_size
:
usize
,
inner_dim
:
usize
,
dtype
:
Option
<
String
>
,
...
...
@@ -65,6 +66,7 @@ impl BlockManager {
);
let
mut
model_config
=
dynamo_llm
::
block_manager
::
KvManagerModelConfig
::
builder
()
.num_layers
(
num_layer
)
.outer_dim
(
outer_dim
)
.page_size
(
page_size
)
.inner_dim
(
inner_dim
);
let
mut
dtype_
=
dynamo_llm
::
common
::
dtype
::
DType
::
FP16
;
// Default in block_manager config
...
...
lib/bindings/python/tests/test_block_manager.py
View file @
80256acf
...
...
@@ -26,6 +26,7 @@ pytestmark = pytest.mark.pre_merge
WORKER_ID
=
0
NUM_LAYER
=
5
OUTER_DIM
=
2
PAGE_SIZE
=
4
INNER_DIM
=
13
DTYPE
,
TORCH_DTYPE
=
"FP32"
,
torch
.
float32
...
...
@@ -34,16 +35,35 @@ DEVICE_NUM_BLOCKS = 16
DEVICE_ID
=
0
@
pytest
.
fixture
def
block_manager
():
"""Pytest fixture for creating a BlockManager instance."""
return
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_manager_initialization
():
# Python should drop the BlockManager instance as soon as it goes out of scope, but
# it may not be garbage collected immediately, depending on the garbage collector.
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
...
...
@@ -52,6 +72,7 @@ async def test_block_manager_initialization():
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
...
...
@@ -61,6 +82,7 @@ async def test_block_manager_initialization():
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
...
...
@@ -70,6 +92,7 @@ async def test_block_manager_initialization():
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
...
...
@@ -80,17 +103,7 @@ async def test_block_manager_initialization():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_cpu_block_access
():
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
async
def
test_cpu_block_access
(
block_manager
:
BlockManager
):
block_count
=
2
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
py_blocks
=
block_list
.
to_list
()
...
...
@@ -117,17 +130,7 @@ async def test_cpu_block_access():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_gpu_block_access
():
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
async
def
test_gpu_block_access
(
block_manager
:
BlockManager
):
block_count
=
6
block_list
=
block_manager
.
allocate_device_blocks_blocking
(
block_count
)
py_blocks
=
block_list
.
to_list
()
...
...
@@ -154,17 +157,7 @@ async def test_gpu_block_access():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_list_iteration
():
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
async
def
test_block_list_iteration
(
block_manager
:
BlockManager
):
block_count
=
4
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
# Test __len__()
...
...
@@ -192,17 +185,7 @@ async def test_block_list_iteration():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_copy_g1_g2
():
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
async
def
test_block_copy_g1_g2
(
block_manager
:
BlockManager
):
# Allocate device (G1) and host (G2) block
host_block_list
=
block_manager
.
allocate_host_blocks_blocking
(
1
)
device_block_list
=
block_manager
.
allocate_device_blocks_blocking
(
1
)
...
...
@@ -243,10 +226,12 @@ async def test_block_copy_g1_g2():
async
def
main
():
await
test_block_manager_initialization
()
await
test_cpu_block_access
()
await
test_gpu_block_access
()
await
test_block_list_iteration
()
await
test_block_copy_g1_g2
()
# todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v)
# await test_cpu_block_access()
# await test_gpu_block_access()
# await test_block_list_iteration()
# await test_block_copy_g1_g2()
if
__name__
==
"__main__"
:
...
...
lib/llm/Cargo.toml
View file @
80256acf
...
...
@@ -27,6 +27,9 @@ description = "Dynamo LLM Library"
[features]
default
=
[]
# todo: enable this as default
# default = ["block-manager", "testing-full"]
testing-full
=
[
"testing-cuda"
,
"testing-nixl"
]
testing-cuda
=
["dep:cudarc"]
testing-nixl
=
["dep:nixl-sys"]
...
...
lib/llm/src/block_manager.rs
View file @
80256acf
...
...
@@ -203,6 +203,7 @@ mod tests {
.model
(
KvManagerModelConfig
::
builder
()
.num_layers
(
3
)
.outer_dim
(
2
)
.page_size
(
4
)
.inner_dim
(
16
)
.build
()
...
...
@@ -241,6 +242,8 @@ mod tests {
let
_
block_manager
=
create_reference_block_manager
();
}
// todo: solve the async runtime issue
#[ignore]
#[test]
fn
test_reference_block_manager_blocking
()
{
dynamo_runtime
::
logging
::
init
();
...
...
lib/llm/src/block_manager/block.rs
View file @
80256acf
...
...
@@ -393,11 +393,18 @@ pub trait BlockDataExt<S: Storage + NixlDescriptor> {
/// Returns the number of layers in the block
fn
num_layers
(
&
self
)
->
usize
;
/// Returns the number of outer dimensions in the block
fn
num_outer_dims
(
&
self
)
->
usize
;
/// Get a read-only view of this block's storage for a layer
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
;
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
;
/// Get a mutable view of this block's storage for a layer
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
;
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
;
/// Get a read-only view of this block's storage
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
;
...
...
@@ -451,21 +458,34 @@ where
self
.layout
.num_layers
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
layer_idx
)
?
;
unsafe
{
view
::
LayerView
::
new
(
self
,
offset
as
usize
,
self
.layout
.memory_region_size
())
}
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.layout
.outer_dim
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
layer_idx
,
outer_idx
)
?
;
unsafe
{
view
::
LayerView
::
new
(
self
,
mr
.addr
(),
mr
.size
())
}
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
layer_idx
)
?
;
unsafe
{
view
::
LayerViewMut
::
new
(
self
,
offset
as
usize
,
self
.layout
.memory_region_size
())
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
layer_idx
,
outer_idx
)
?
;
unsafe
{
view
::
LayerViewMut
::
new
(
self
,
mr
.addr
(),
mr
.size
())
}
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
if
self
.is_fully_contiguous
()
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
0
)
?
;
let
size
=
self
.layout
.memory_region_size
()
*
self
.layout
.num_layers
();
unsafe
{
view
::
BlockView
::
new
(
self
,
offset
as
usize
,
size
)
}
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
0
,
0
)
?
;
let
offset
=
mr
.addr
();
let
size
=
mr
.size
()
*
self
.num_layers
();
unsafe
{
view
::
BlockView
::
new
(
self
,
offset
,
size
)
}
}
else
{
Err
(
BlockError
::
InvalidState
(
"Block is not fully contiguous"
.to_string
(),
...
...
@@ -475,9 +495,10 @@ where
fn
block_view_mut
(
&
mut
self
)
->
BlockResult
<
view
::
BlockViewMut
<
S
>>
{
if
self
.is_fully_contiguous
()
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
0
)
?
;
let
size
=
self
.layout
.memory_region_size
()
*
self
.layout
.num_layers
();
unsafe
{
view
::
BlockViewMut
::
new
(
self
,
offset
as
usize
,
size
)
}
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
0
,
0
)
?
;
let
offset
=
mr
.addr
();
let
size
=
mr
.size
()
*
self
.num_layers
();
unsafe
{
view
::
BlockViewMut
::
new
(
self
,
offset
,
size
)
}
}
else
{
Err
(
BlockError
::
InvalidState
(
"Block is not fully contiguous"
.to_string
(),
...
...
@@ -626,12 +647,20 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableB
self
.data
.num_layers
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
self
.data
.layer_view
(
layer_idx
)
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.data
.num_outer_dims
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
self
.data
.layer_view
(
layer_idx
,
outer_idx
)
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
self
.data
.layer_view_mut
(
layer_idx
)
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
self
.data
.layer_view_mut
(
layer_idx
,
outer_idx
)
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
...
...
@@ -755,11 +784,15 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for Immutabl
self
.block
.num_layers
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
self
.block
.
layer_view
(
layer_idx
)
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.block
.
num_outer_dims
(
)
}
fn
layer_view_mut
(
&
mut
self
,
_
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
self
.block
.layer_view
(
layer_idx
,
outer_idx
)
}
fn
layer_view_mut
(
&
mut
self
,
_
:
usize
,
_
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
// This should never be called since ImmutableBlock is immutable,
// but we need to implement the full trait
Err
(
BlockError
::
InvalidState
(
...
...
@@ -946,6 +979,7 @@ pub mod nixl {
fn
as_layer_descriptor
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
;
}
...
...
@@ -961,6 +995,7 @@ pub mod nixl {
fn
as_layer_descriptor_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
;
}
...
...
@@ -974,8 +1009,9 @@ pub mod nixl {
fn
as_layer_descriptor
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
Ok
(
self
.layer_view
(
layer_idx
)
?
.as_nixl_descriptor
())
Ok
(
self
.layer_view
(
layer_idx
,
outer_idx
)
?
.as_nixl_descriptor
())
}
}
...
...
@@ -989,8 +1025,11 @@ pub mod nixl {
fn
as_layer_descriptor_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
Ok
(
self
.layer_view_mut
(
layer_idx
)
?
.as_nixl_descriptor_mut
())
Ok
(
self
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
.as_nixl_descriptor_mut
())
}
}
...
...
@@ -1188,15 +1227,24 @@ pub mod nixl {
self
.data
.num_layers
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
NixlStorage
>>
{
self
.data
.layer_view
(
layer_idx
)
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.data
.num_outer_dims
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerView
<
NixlStorage
>>
{
self
.data
.layer_view
(
layer_idx
,
outer_idx
)
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
NixlStorage
>>
{
self
.data
.layer_view_mut
(
layer_idx
)
self
.data
.layer_view_mut
(
layer_idx
,
outer_idx
)
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
NixlStorage
>>
{
...
...
@@ -1224,8 +1272,9 @@ pub mod nixl {
fn
as_layer_descriptor
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
self
.data
.as_layer_descriptor
(
layer_idx
)
self
.data
.as_layer_descriptor
(
layer_idx
,
outer_idx
)
}
}
...
...
@@ -1244,8 +1293,9 @@ pub mod nixl {
fn
as_layer_descriptor_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
self
.data
.as_layer_descriptor_mut
(
layer_idx
)
self
.data
.as_layer_descriptor_mut
(
layer_idx
,
outer_idx
)
}
}
...
...
@@ -1733,7 +1783,8 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_layers
(
2
)
.num_layers
(
3
)
.outer_dim
(
2
)
.page_size
(
4
)
.inner_dim
(
13
)
.build
()
...
...
@@ -1780,6 +1831,7 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_layers
(
2
)
.outer_dim
(
1
)
.page_size
(
4
)
.inner_dim
(
13
)
.build
()
...
...
@@ -1803,12 +1855,12 @@ mod tests {
assert_eq!
(
mutable_block
.num_layers
(),
2
);
// Test layer_view()
let
layer_view
=
mutable_block
.layer_view
(
0
)
.unwrap
();
let
layer_view
=
mutable_block
.layer_view
(
0
,
0
)
.unwrap
();
assert_eq!
(
layer_view
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert
!
(
!
unsafe
{
layer_view
.as_ptr
()
}
.is_null
());
// Test layer_view_mut()
let
mut
layer_view_mut
=
mutable_block
.layer_view_mut
(
1
)
.unwrap
();
let
mut
layer_view_mut
=
mutable_block
.layer_view_mut
(
1
,
0
)
.unwrap
();
assert_eq!
(
layer_view_mut
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert
!
(
!
unsafe
{
layer_view_mut
.as_mut_ptr
()
}
.is_null
());
...
...
@@ -1833,6 +1885,7 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_layers
(
2
)
.outer_dim
(
1
)
.page_size
(
4
)
.inner_dim
(
13
)
.build
()
...
...
@@ -1860,7 +1913,7 @@ mod tests {
assert_eq!
(
immutable_block
.num_layers
(),
2
);
// Test layer_view()
let
layer_view
=
immutable_block
.layer_view
(
0
)
.unwrap
();
let
layer_view
=
immutable_block
.layer_view
(
0
,
0
)
.unwrap
();
assert_eq!
(
layer_view
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert
!
(
!
unsafe
{
layer_view
.as_ptr
()
}
.is_null
());
...
...
@@ -1872,7 +1925,7 @@ mod tests {
// Test that mutable methods return errors
let
mut
mut_immutable_block
=
immutable_block
;
// We need a mutable reference for these tests
let
layer_view_mut_res
=
mut_immutable_block
.layer_view_mut
(
0
);
let
layer_view_mut_res
=
mut_immutable_block
.layer_view_mut
(
0
,
0
);
assert
!
(
layer_view_mut_res
.is_err
());
if
let
Err
(
BlockError
::
InvalidState
(
msg
))
=
layer_view_mut_res
{
assert
!
(
msg
.contains
(
"immutable block"
));
...
...
lib/llm/src/block_manager/block/transfer/cuda.rs
View file @
80256acf
...
...
@@ -112,18 +112,20 @@ where
}
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy_fn
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
(),
stream
,
)
?
;
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy_fn
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
(),
stream
,
)
?
;
}
}
}
Ok
(())
...
...
lib/llm/src/block_manager/block/transfer/memcpy.rs
View file @
80256acf
...
...
@@ -57,12 +57,14 @@ where
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
}
}
}
Ok
(())
...
...
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
80256acf
...
...
@@ -132,26 +132,28 @@ where
// }
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
let
src_desc
=
src_view
.as_nixl_descriptor
();
let
dst_desc
=
dst_view
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
let
src_desc
=
src_view
.as_nixl_descriptor
();
let
dst_desc
=
dst_view
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
}
}
...
...
lib/llm/src/block_manager/config.rs
View file @
80256acf
...
...
@@ -71,6 +71,9 @@ pub struct KvManagerModelConfig {
#[validate(range(min
=
1
))]
pub
num_layers
:
usize
,
#[validate(range(min
=
1
,
max
=
2
))]
pub
outer_dim
:
usize
,
#[validate(range(min
=
1
))]
pub
page_size
:
usize
,
...
...
lib/llm/src/block_manager/layout.rs
View file @
80256acf
...
...
@@ -84,6 +84,7 @@
//! let config = LayoutConfig::builder()
//! .num_blocks(10)
//! .num_layers(4)
//! .outer_dim(1)
//! .page_size(16)
//! .inner_dim(128)
//! .dtype(DType::FP16)
...
...
@@ -109,8 +110,12 @@
//! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling
//! layouts to be registered and serialized for use in distributed environments.
// todo: coming soon...
// pub mod distributed;
pub
mod
nixl
;
use
derive_getters
::
Getters
;
use
thiserror
::
Error
;
use
crate
::
block_manager
::
storage
::{
Storage
,
StorageAllocator
};
...
...
@@ -138,6 +143,9 @@ pub enum LayoutError {
#[error(
"Invalid layer index: {0}"
)]
InvalidLayerIndex
(
usize
),
#[error(
"Invalid outer index: {0}"
)]
InvalidOuterIndex
(
usize
),
#[error(
"Operation failed: {0}"
)]
OperationFailed
(
String
),
...
...
@@ -165,10 +173,18 @@ pub enum LayoutType {
// Null,
}
/// Local Memory Region
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
Serialize,
Deserialize,
Getters)]
pub
struct
LocalMemoryRegion
{
#[getter(copy)]
addr
:
usize
,
#[getter(copy)]
size
:
usize
,
}
/// Core trait for block layouts
pub
trait
BlockLayout
:
BlockLayoutConfig
+
BlockLayoutLookup
+
Send
+
Sync
+
std
::
fmt
::
Debug
{
pub
trait
BlockLayout
:
BlockLayoutConfig
+
Send
+
Sync
+
std
::
fmt
::
Debug
{
/// The type of storage this layout uses
type
StorageType
:
Storage
;
...
...
@@ -180,6 +196,21 @@ pub trait BlockLayout:
/// Storage type for the layout
fn
storage_type
(
&
self
)
->
StorageType
;
/// Get the memory region for a specific page [page_size, inner_dim]
///
/// # Arguments
///
/// * `block_idx` - The index of the block
/// * `layer_idx` - The index of the layer
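/// * `outer_idx` - The index of the outer dimension, e.g. if K and V are indexed separately (`outer_dim == 2`), this selects which of the two regions is addressed; must be less than `outer_dim`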
///
fn
memory_region
(
&
self
,
block_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
Result
<
LocalMemoryRegion
,
LayoutError
>
;
}
/// Configuration for block layouts
...
...
@@ -193,6 +224,12 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
/// Returns the number of layers per block
fn
num_layers
(
&
self
)
->
usize
;
/// Returns the number of outer dimensions per block
/// In some cases, K and V might be indexed separately, so in that example one might have 2 outer dimensions
/// For MLA, this is 1.
/// The location of the outer dimension in the shape of the tensor layout is defined by the layout type.
fn
outer_dim
(
&
self
)
->
usize
;
/// Returns the page size, i.e. the first dimension of a `[page_size, inner_dim]` page (a dimension, not a byte count)
fn
page_size
(
&
self
)
->
usize
;
...
...
@@ -200,15 +237,6 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
fn
inner_dim
(
&
self
)
->
usize
;
}
/// Trait for looking up memory regions in a block layout
pub
trait
BlockLayoutLookup
{
/// Get the memory region for a specific page [page_size, inner_dim]
fn
memory_region_addr
(
&
self
,
block_idx
:
usize
,
layer_idx
:
usize
)
->
Result
<
u64
,
LayoutError
>
;
/// Get the memory region for a specific page [page_size, inner_dim]
fn
memory_region_size
(
&
self
)
->
usize
;
}
/// Configuration for block layouts
#[derive(Debug,
Clone,
Builder,
Validate,
Serialize,
Deserialize)]
pub
struct
LayoutConfig
{
...
...
@@ -220,6 +248,10 @@ pub struct LayoutConfig {
#[validate(range(min
=
1
))]
pub
num_layers
:
usize
,
/// Number of outer dimensions
#[validate(range(min
=
1
,
max
=
2
))]
pub
outer_dim
:
usize
,
/// Page size
#[validate(range(min
=
1
))]
pub
page_size
:
usize
,
...
...
@@ -268,11 +300,25 @@ fn align_up(value: usize, alignment: usize) -> usize {
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
(
crate
)
struct
FullyContiguousConfig
{
inner
:
LayoutConfig
,
/// Minimum contiguous memory region size
/// Inner dimension * page size * dtype size
memory_region_size
:
usize
,
/// Stride between outer dimensions
outer_dim_stride_in_bytes
:
usize
,
/// Stride between layers
layer_stride_in_bytes
:
usize
,
/// Natural block stride
natural_block_stride
:
usize
,
/// Block stride in bytes
block_stride_in_bytes
:
usize
,
// Aligned if necessary
layout_data_bytes
:
usize
,
// Size of the layout data itself (post base offset)
/// Size of the layout data itself (post base offset)
layout_data_bytes
:
usize
,
// Size of the layout data itself (post base offset)
}
impl
FullyContiguousConfig
{
...
...
@@ -284,7 +330,8 @@ impl FullyContiguousConfig {
let
alignment
=
config
.alignment
;
let
memory_region_size
=
config
.page_size
*
config
.inner_dim
*
config
.dtype
.size_in_bytes
();
let
layer_stride_in_bytes
=
memory_region_size
;
let
outer_dim_stride_in_bytes
=
memory_region_size
;
let
layer_stride_in_bytes
=
outer_dim_stride_in_bytes
*
config
.outer_dim
;
let
natural_block_stride
=
config
.num_layers
*
layer_stride_in_bytes
;
let
block_stride_in_bytes
=
if
alignment
>
1
{
...
...
@@ -299,6 +346,7 @@ impl FullyContiguousConfig {
Ok
(
Self
{
inner
:
config
,
memory_region_size
,
outer_dim_stride_in_bytes
,
layer_stride_in_bytes
,
natural_block_stride
,
block_stride_in_bytes
,
...
...
@@ -331,6 +379,10 @@ impl BlockLayoutConfig for FullyContiguousConfig {
self
.inner.num_layers
}
fn
outer_dim
(
&
self
)
->
usize
{
self
.inner.outer_dim
}
fn
page_size
(
&
self
)
->
usize
{
self
.inner.page_size
}
...
...
@@ -508,6 +560,39 @@ impl<S: Storage> BlockLayout for FullyContiguous<S> {
fn
storage_type
(
&
self
)
->
StorageType
{
self
.storage_type
.clone
()
}
fn
memory_region
(
&
self
,
block_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
Result
<
LocalMemoryRegion
,
LayoutError
>
{
if
block_idx
>=
self
.num_blocks
()
{
return
Err
(
LayoutError
::
InvalidBlockIndex
(
block_idx
));
}
if
layer_idx
>=
self
.num_layers
()
{
return
Err
(
LayoutError
::
InvalidLayerIndex
(
layer_idx
));
}
if
outer_idx
>=
self
.outer_dim
()
{
return
Err
(
LayoutError
::
InvalidOuterIndex
(
outer_idx
));
}
// Start from the aligned base address
let
aligned_start_addr
=
self
.storage
.addr
()
as
usize
+
self
.base_offset
;
// Calculate offset relative to the aligned start using stored config
let
block_offset
=
block_idx
*
self
.config.block_stride_in_bytes
;
let
layer_offset
=
layer_idx
*
self
.config.layer_stride_in_bytes
;
let
outer_offset
=
outer_idx
*
self
.config.outer_dim_stride_in_bytes
;
let
final_addr
=
aligned_start_addr
+
block_offset
+
layer_offset
+
outer_offset
;
Ok
(
LocalMemoryRegion
{
addr
:
final_addr
,
size
:
self
.config.memory_region_size
,
})
}
}
impl
<
S
:
Storage
>
BlockLayoutConfig
for
FullyContiguous
<
S
>
{
...
...
@@ -523,6 +608,10 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
self
.config.inner.num_layers
}
fn
outer_dim
(
&
self
)
->
usize
{
self
.config.inner.outer_dim
}
fn
page_size
(
&
self
)
->
usize
{
self
.config.inner.page_size
}
...
...
@@ -532,33 +621,6 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
}
}
impl
<
S
:
Storage
>
BlockLayoutLookup
for
FullyContiguous
<
S
>
{
fn
memory_region_addr
(
&
self
,
block_idx
:
usize
,
layer_idx
:
usize
)
->
Result
<
u64
,
LayoutError
>
{
if
block_idx
>=
self
.num_blocks
()
{
return
Err
(
LayoutError
::
InvalidBlockIndex
(
block_idx
));
}
if
layer_idx
>=
self
.num_layers
()
{
return
Err
(
LayoutError
::
InvalidLayerIndex
(
layer_idx
));
}
// Start from the aligned base address
let
aligned_start_addr
=
self
.storage
.addr
()
+
self
.base_offset
as
u64
;
// Calculate offset relative to the aligned start using stored config
let
block_offset
=
block_idx
*
self
.config.block_stride_in_bytes
;
let
layer_offset
=
layer_idx
*
self
.config.layer_stride_in_bytes
;
let
final_addr
=
aligned_start_addr
+
block_offset
as
u64
+
layer_offset
as
u64
;
Ok
(
final_addr
)
}
fn
memory_region_size
(
&
self
)
->
usize
{
// Access via stored dims
self
.config.memory_region_size
}
}
#[allow(missing_docs)]
#[cfg(test)]
pub
mod
tests
{
...
...
@@ -570,6 +632,7 @@ pub mod tests {
const
NUM_BLOCKS
:
usize
=
7
;
const
NUM_LAYERS
:
usize
=
5
;
const
OUTER_DIM
:
usize
=
2
;
const
PAGE_SIZE
:
usize
=
4
;
const
INNER_DIM
:
usize
=
13
;
const
DTYPE
:
DType
=
DType
::
FP32
;
// Example dtype
...
...
@@ -592,6 +655,7 @@ pub mod tests {
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
alignment
:
alignment
.unwrap_or
(
1
),
...
...
@@ -606,6 +670,7 @@ pub mod tests {
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
NUM_BLOCKS
)
.num_layers
(
NUM_LAYERS
)
.outer_dim
(
OUTER_DIM
)
.page_size
(
PAGE_SIZE
)
.inner_dim
(
INNER_DIM
)
.alignment
(
3
)
...
...
@@ -632,6 +697,7 @@ pub mod tests {
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
alignment
:
1
,
...
...
@@ -656,17 +722,11 @@ pub mod tests {
assert_eq!
(
layout
.num_blocks
(),
NUM_BLOCKS
);
assert_eq!
(
layout
.num_layers
(),
NUM_LAYERS
);
assert_eq!
(
layout
.outer_dim
(),
OUTER_DIM
);
assert_eq!
(
layout
.page_size
(),
PAGE_SIZE
);
assert_eq!
(
layout
.inner_dim
(),
INNER_DIM
);
}
#[test]
fn
test_fc_memory_region_size
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
expected_region_size
=
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
();
assert_eq!
(
layout
.memory_region_size
(),
expected_region_size
);
}
#[test]
fn
test_fc_offset_calculation
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
...
...
@@ -680,7 +740,7 @@ pub mod tests {
let
expected_offset_0_0
=
calculate_expected_offset
(
base_addr
,
0
,
0
,
block_stride
,
layer_stride
);
assert_eq!
(
layout
.memory_region
_addr
(
0
,
0
)
.unwrap
(),
layout
.memory_region
(
0
,
0
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_0_0
);
...
...
@@ -689,7 +749,7 @@ pub mod tests {
let
expected_offset_0_last
=
calculate_expected_offset
(
base_addr
,
0
,
last_layer_idx
,
block_stride
,
layer_stride
);
assert_eq!
(
layout
.memory_region
_addr
(
0
,
last_layer_idx
)
.unwrap
(),
layout
.memory_region
(
0
,
last_layer_idx
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_0_last
);
...
...
@@ -698,7 +758,7 @@ pub mod tests {
let
expected_offset_last_0
=
calculate_expected_offset
(
base_addr
,
last_block_idx
,
0
,
block_stride
,
layer_stride
);
assert_eq!
(
layout
.memory_region
_addr
(
last_block_idx
,
0
)
.unwrap
(),
layout
.memory_region
(
last_block_idx
,
0
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_last_0
);
...
...
@@ -712,8 +772,9 @@ pub mod tests {
);
assert_eq!
(
layout
.memory_region_addr
(
last_block_idx
,
last_layer_idx
)
.unwrap
(),
.memory_region
(
last_block_idx
,
last_layer_idx
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_last_last
);
...
...
@@ -729,8 +790,9 @@ pub mod tests {
);
assert_eq!
(
layout
.memory_region_addr
(
mid_block_idx
,
mid_layer_idx
)
.unwrap
(),
.memory_region
(
mid_block_idx
,
mid_layer_idx
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_mid_mid
);
}
...
...
@@ -738,7 +800,7 @@ pub mod tests {
#[test]
fn
test_fc_invalid_block_index
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
result
=
layout
.memory_region
_addr
(
NUM_BLOCKS
,
0
);
// Index == num_blocks (out of bounds)
let
result
=
layout
.memory_region
(
NUM_BLOCKS
,
0
,
0
);
// Index == num_blocks (out of bounds)
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.err
()
.unwrap
(),
...
...
@@ -749,7 +811,7 @@ pub mod tests {
#[test]
fn
test_fc_invalid_layer_index
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
result
=
layout
.memory_region
_addr
(
0
,
NUM_LAYERS
);
// Index == num_layers (out of bounds)
let
result
=
layout
.memory_region
(
0
,
NUM_LAYERS
,
0
);
// Index == num_layers (out of bounds)
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
result
.err
()
.unwrap
(),
...
...
@@ -757,12 +819,24 @@ pub mod tests {
));
}
/// Checks that `memory_region` rejects an outer index equal to `OUTER_DIM`.
///
/// The lookup must fail, and the error must be
/// `LayoutError::InvalidOuterIndex` carrying the offending index.
#[test]
fn test_fc_invalid_outer_index() {
    let layout = setup_layout(None).expect("Layout setup failed");

    // Block 0 / layer 0 are used so that only the outer index
    // (== OUTER_DIM, one past the last valid value) is out of bounds.
    let lookup = layout.memory_region(0, 0, OUTER_DIM);
    assert!(lookup.is_err());

    // `OUTER_DIM` is a constant pattern here, so the payload must equal it.
    let failure = lookup.err().unwrap();
    assert!(matches!(
        failure,
        LayoutError::InvalidOuterIndex(OUTER_DIM)
    ));
}
#[test]
fn
test_fc_allocation_system
()
{
init_logging
();
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
alignment
:
1
,
...
...
@@ -788,7 +862,7 @@ pub mod tests {
assert_eq!
(
layout
.storage
.size
(),
NUM_BLOCKS
*
NUM_LAYERS
*
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
()
NUM_BLOCKS
*
NUM_LAYERS
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
()
);
}
...
...
@@ -800,6 +874,7 @@ pub mod tests {
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
alignment
:
ALIGNMENT
,
...
...
@@ -810,11 +885,11 @@ pub mod tests {
let
memory_region_size
=
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
();
assert_eq!
(
memory_region_size
,
208
);
let
natural_block_stride
=
NUM_LAYERS
*
memory_region_size
;
assert_eq!
(
natural_block_stride
,
104
0
);
let
natural_block_stride
=
OUTER_DIM
*
NUM_LAYERS
*
memory_region_size
;
assert_eq!
(
natural_block_stride
,
208
0
);
let
aligned_block_stride
=
align_up
(
natural_block_stride
,
ALIGNMENT
);
assert_eq!
(
aligned_block_stride
,
1280
);
assert_eq!
(
aligned_block_stride
,
2304
);
// Calculate the expected *allocated* size (data + initial padding)
let
fc_config
=
FullyContiguousConfig
::
new
(
config
.clone
())
.unwrap
();
...
...
@@ -844,40 +919,40 @@ pub mod tests {
// Check alignment of block starts
let
addr_block_0
=
layout
.memory_region
_addr
(
0
,
0
)
.memory_region
(
0
,
0
,
0
)
.expect
(
"Failed to get addr block 0"
);
let
addr_block_1
=
layout
.memory_region
_addr
(
1
,
0
)
.memory_region
(
1
,
0
,
0
)
.expect
(
"Failed to get addr block 1"
);
let
addr_block_2
=
layout
.memory_region
_addr
(
2
,
0
)
.memory_region
(
2
,
0
,
0
)
.expect
(
"Failed to get addr block 2"
);
// All blocks should now be aligned due to base_offset adjustment
assert_eq!
(
addr_block_0
%
ALIGNMENT
as
u64
,
addr_block_0
.addr
as
u64
%
ALIGNMENT
as
u64
,
0
,
"Block 0 start address is not aligned"
);
assert_eq!
(
addr_block_1
%
ALIGNMENT
as
u64
,
addr_block_1
.addr
as
u64
%
ALIGNMENT
as
u64
,
0
,
"Block 1 start address is not aligned"
);
assert_eq!
(
addr_block_2
%
ALIGNMENT
as
u64
,
addr_block_2
.addr
as
u64
%
ALIGNMENT
as
u64
,
0
,
"Block 2 start address is not aligned"
);
// Verify the difference matches the aligned stride
assert_eq!
(
addr_block_1
-
addr_block_0
,
addr_block_1
.addr
as
u64
-
addr_block_0
.addr
as
u64
,
aligned_block_stride
as
u64
,
"Stride between block 0 and 1 mismatch"
);
assert_eq!
(
addr_block_2
-
addr_block_1
,
addr_block_2
.addr
as
u64
-
addr_block_1
.addr
as
u64
,
aligned_block_stride
as
u64
,
"Stride between block 1 and 2 mismatch"
);
...
...
lib/llm/src/block_manager/layout/nixl.rs
View file @
80256acf
...
...
@@ -332,6 +332,7 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_layers
(
2
)
.outer_dim
(
2
)
.page_size
(
4
)
.inner_dim
(
13
)
.build
()
...
...
lib/llm/src/block_manager/offload.rs
View file @
80256acf
...
...
@@ -506,6 +506,7 @@ mod tests {
let
mut
config
=
LayoutConfig
{
num_blocks
:
device_blocks
,
num_layers
:
8
,
outer_dim
:
1
,
page_size
:
BLOCK_SIZE
,
inner_dim
:
1024
,
alignment
:
1
,
...
...
lib/llm/src/block_manager/pool/inactive.rs
View file @
80256acf
...
...
@@ -582,6 +582,7 @@ pub(crate) mod tests {
let
config
=
LayoutConfigBuilder
::
default
()
.num_blocks
(
num_blocks
)
.num_layers
(
61
)
.outer_dim
(
1
)
.page_size
(
16
)
.inner_dim
(
576
)
.build
()
...
...
lib/llm/src/block_manager/state.rs
View file @
80256acf
...
...
@@ -115,6 +115,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
layout_builder
.num_layers
(
model
.num_layers
)
.outer_dim
(
model
.outer_dim
)
.page_size
(
model
.page_size
)
.inner_dim
(
model
.inner_dim
)
.dtype
(
model
.dtype
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment