Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
80256acf
"lib/vscode:/vscode.git/clone" did not exist on "1208f017f96060cdab69668d05605bb50ad8ede7"
Unverified
Commit
80256acf
authored
May 20, 2025
by
Ryan Olson
Committed by
GitHub
May 20, 2025
Browse files
feat: adding outer dimension to isolate k/v blocks (#1126)
parent
7e452a2e
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
325 additions
and
191 deletions
+325
-191
lib/bindings/python/rust/llm/block_manager.rs
lib/bindings/python/rust/llm/block_manager.rs
+3
-1
lib/bindings/python/tests/test_block_manager.py
lib/bindings/python/tests/test_block_manager.py
+36
-51
lib/llm/Cargo.toml
lib/llm/Cargo.toml
+3
-0
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+3
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+86
-33
lib/llm/src/block_manager/block/transfer/cuda.rs
lib/llm/src/block_manager/block/transfer/cuda.rs
+14
-12
lib/llm/src/block_manager/block/transfer/memcpy.rs
lib/llm/src/block_manager/block/transfer/memcpy.rs
+7
-5
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+22
-20
lib/llm/src/block_manager/config.rs
lib/llm/src/block_manager/config.rs
+3
-0
lib/llm/src/block_manager/layout.rs
lib/llm/src/block_manager/layout.rs
+144
-69
lib/llm/src/block_manager/layout/nixl.rs
lib/llm/src/block_manager/layout/nixl.rs
+1
-0
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+1
-0
lib/llm/src/block_manager/pool/inactive.rs
lib/llm/src/block_manager/pool/inactive.rs
+1
-0
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+1
-0
No files found.
lib/bindings/python/rust/llm/block_manager.rs
View file @
80256acf
...
@@ -46,10 +46,11 @@ pub struct BlockManager {
...
@@ -46,10 +46,11 @@ pub struct BlockManager {
#[pymethods]
#[pymethods]
impl
BlockManager
{
impl
BlockManager
{
#[new]
#[new]
#[pyo3(signature
=
(worker_id,
num_layer,
page_size,
inner_dim,
dtype=None,
host_num_blocks=None,
device_num_blocks=None,
device_id=
0
))]
#[pyo3(signature
=
(worker_id,
num_layer,
outer_dim,
page_size,
inner_dim,
dtype=None,
host_num_blocks=None,
device_num_blocks=None,
device_id=
0
))]
fn
new
(
fn
new
(
worker_id
:
u64
,
worker_id
:
u64
,
num_layer
:
usize
,
num_layer
:
usize
,
outer_dim
:
usize
,
page_size
:
usize
,
page_size
:
usize
,
inner_dim
:
usize
,
inner_dim
:
usize
,
dtype
:
Option
<
String
>
,
dtype
:
Option
<
String
>
,
...
@@ -65,6 +66,7 @@ impl BlockManager {
...
@@ -65,6 +66,7 @@ impl BlockManager {
);
);
let
mut
model_config
=
dynamo_llm
::
block_manager
::
KvManagerModelConfig
::
builder
()
let
mut
model_config
=
dynamo_llm
::
block_manager
::
KvManagerModelConfig
::
builder
()
.num_layers
(
num_layer
)
.num_layers
(
num_layer
)
.outer_dim
(
outer_dim
)
.page_size
(
page_size
)
.page_size
(
page_size
)
.inner_dim
(
inner_dim
);
.inner_dim
(
inner_dim
);
let
mut
dtype_
=
dynamo_llm
::
common
::
dtype
::
DType
::
FP16
;
// Default in block_manager config
let
mut
dtype_
=
dynamo_llm
::
common
::
dtype
::
DType
::
FP16
;
// Default in block_manager config
...
...
lib/bindings/python/tests/test_block_manager.py
View file @
80256acf
...
@@ -26,6 +26,7 @@ pytestmark = pytest.mark.pre_merge
...
@@ -26,6 +26,7 @@ pytestmark = pytest.mark.pre_merge
WORKER_ID
=
0
WORKER_ID
=
0
NUM_LAYER
=
5
NUM_LAYER
=
5
OUTER_DIM
=
2
PAGE_SIZE
=
4
PAGE_SIZE
=
4
INNER_DIM
=
13
INNER_DIM
=
13
DTYPE
,
TORCH_DTYPE
=
"FP32"
,
torch
.
float32
DTYPE
,
TORCH_DTYPE
=
"FP32"
,
torch
.
float32
...
@@ -34,16 +35,35 @@ DEVICE_NUM_BLOCKS = 16
...
@@ -34,16 +35,35 @@ DEVICE_NUM_BLOCKS = 16
DEVICE_ID
=
0
DEVICE_ID
=
0
@
pytest
.
fixture
def
block_manager
():
"""Pytest fixture for creating a BlockManager instance."""
return
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_manager_initialization
():
async
def
test_block_manager_initialization
():
# Python should drop the BlockManager instance as soon as it goes out of scope, but
# Python should drop the BlockManager instance as soon as it goes out of scope, but
# it may not be garbage collected immediately, depending on the garbage collector.
# it may not be garbage collected immediately, depending on the garbage collector.
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
)
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
)
BlockManager
(
BlockManager
(
WORKER_ID
,
WORKER_ID
,
NUM_LAYER
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
PAGE_SIZE
,
INNER_DIM
,
INNER_DIM
,
DTYPE
,
DTYPE
,
...
@@ -52,6 +72,7 @@ async def test_block_manager_initialization():
...
@@ -52,6 +72,7 @@ async def test_block_manager_initialization():
BlockManager
(
BlockManager
(
WORKER_ID
,
WORKER_ID
,
NUM_LAYER
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
PAGE_SIZE
,
INNER_DIM
,
INNER_DIM
,
DTYPE
,
DTYPE
,
...
@@ -61,6 +82,7 @@ async def test_block_manager_initialization():
...
@@ -61,6 +82,7 @@ async def test_block_manager_initialization():
BlockManager
(
BlockManager
(
WORKER_ID
,
WORKER_ID
,
NUM_LAYER
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
PAGE_SIZE
,
INNER_DIM
,
INNER_DIM
,
DTYPE
,
DTYPE
,
...
@@ -70,6 +92,7 @@ async def test_block_manager_initialization():
...
@@ -70,6 +92,7 @@ async def test_block_manager_initialization():
BlockManager
(
BlockManager
(
WORKER_ID
,
WORKER_ID
,
NUM_LAYER
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
PAGE_SIZE
,
INNER_DIM
,
INNER_DIM
,
DTYPE
,
DTYPE
,
...
@@ -80,17 +103,7 @@ async def test_block_manager_initialization():
...
@@ -80,17 +103,7 @@ async def test_block_manager_initialization():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_cpu_block_access
():
async
def
test_cpu_block_access
(
block_manager
:
BlockManager
):
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
block_count
=
2
block_count
=
2
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
py_blocks
=
block_list
.
to_list
()
py_blocks
=
block_list
.
to_list
()
...
@@ -117,17 +130,7 @@ async def test_cpu_block_access():
...
@@ -117,17 +130,7 @@ async def test_cpu_block_access():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_gpu_block_access
():
async
def
test_gpu_block_access
(
block_manager
:
BlockManager
):
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
block_count
=
6
block_count
=
6
block_list
=
block_manager
.
allocate_device_blocks_blocking
(
block_count
)
block_list
=
block_manager
.
allocate_device_blocks_blocking
(
block_count
)
py_blocks
=
block_list
.
to_list
()
py_blocks
=
block_list
.
to_list
()
...
@@ -154,17 +157,7 @@ async def test_gpu_block_access():
...
@@ -154,17 +157,7 @@ async def test_gpu_block_access():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_list_iteration
():
async
def
test_block_list_iteration
(
block_manager
:
BlockManager
):
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
block_count
=
4
block_count
=
4
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
# Test __len__()
# Test __len__()
...
@@ -192,17 +185,7 @@ async def test_block_list_iteration():
...
@@ -192,17 +185,7 @@ async def test_block_list_iteration():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_copy_g1_g2
():
async
def
test_block_copy_g1_g2
(
block_manager
:
BlockManager
):
block_manager
=
BlockManager
(
WORKER_ID
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
,
DTYPE
,
HOST_NUM_BLOCKS
,
DEVICE_NUM_BLOCKS
,
DEVICE_ID
,
)
# Allocate device (G1) and host (G2) block
# Allocate device (G1) and host (G2) block
host_block_list
=
block_manager
.
allocate_host_blocks_blocking
(
1
)
host_block_list
=
block_manager
.
allocate_host_blocks_blocking
(
1
)
device_block_list
=
block_manager
.
allocate_device_blocks_blocking
(
1
)
device_block_list
=
block_manager
.
allocate_device_blocks_blocking
(
1
)
...
@@ -243,10 +226,12 @@ async def test_block_copy_g1_g2():
...
@@ -243,10 +226,12 @@ async def test_block_copy_g1_g2():
async
def
main
():
async
def
main
():
await
test_block_manager_initialization
()
await
test_block_manager_initialization
()
await
test_cpu_block_access
()
await
test_gpu_block_access
()
# todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v)
await
test_block_list_iteration
()
# await test_cpu_block_access()
await
test_block_copy_g1_g2
()
# await test_gpu_block_access()
# await test_block_list_iteration()
# await test_block_copy_g1_g2()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
lib/llm/Cargo.toml
View file @
80256acf
...
@@ -27,6 +27,9 @@ description = "Dynamo LLM Library"
...
@@ -27,6 +27,9 @@ description = "Dynamo LLM Library"
[features]
[features]
default
=
[]
default
=
[]
# todo: enable this as default
# default = ["block-manager", "testing-full"]
testing-full
=
[
"testing-cuda"
,
"testing-nixl"
]
testing-full
=
[
"testing-cuda"
,
"testing-nixl"
]
testing-cuda
=
["dep:cudarc"]
testing-cuda
=
["dep:cudarc"]
testing-nixl
=
["dep:nixl-sys"]
testing-nixl
=
["dep:nixl-sys"]
...
...
lib/llm/src/block_manager.rs
View file @
80256acf
...
@@ -203,6 +203,7 @@ mod tests {
...
@@ -203,6 +203,7 @@ mod tests {
.model
(
.model
(
KvManagerModelConfig
::
builder
()
KvManagerModelConfig
::
builder
()
.num_layers
(
3
)
.num_layers
(
3
)
.outer_dim
(
2
)
.page_size
(
4
)
.page_size
(
4
)
.inner_dim
(
16
)
.inner_dim
(
16
)
.build
()
.build
()
...
@@ -241,6 +242,8 @@ mod tests {
...
@@ -241,6 +242,8 @@ mod tests {
let
_
block_manager
=
create_reference_block_manager
();
let
_
block_manager
=
create_reference_block_manager
();
}
}
// todo: solve the async runtime issue
#[ignore]
#[test]
#[test]
fn
test_reference_block_manager_blocking
()
{
fn
test_reference_block_manager_blocking
()
{
dynamo_runtime
::
logging
::
init
();
dynamo_runtime
::
logging
::
init
();
...
...
lib/llm/src/block_manager/block.rs
View file @
80256acf
...
@@ -393,11 +393,18 @@ pub trait BlockDataExt<S: Storage + NixlDescriptor> {
...
@@ -393,11 +393,18 @@ pub trait BlockDataExt<S: Storage + NixlDescriptor> {
/// Returns the number of layers in the block
/// Returns the number of layers in the block
fn
num_layers
(
&
self
)
->
usize
;
fn
num_layers
(
&
self
)
->
usize
;
/// Returns the number of outer dimensions in the block
fn
num_outer_dims
(
&
self
)
->
usize
;
/// Get a read-only view of this block's storage for a layer
/// Get a read-only view of this block's storage for a layer
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
;
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
;
/// Get a mutable view of this block's storage for a layer
/// Get a mutable view of this block's storage for a layer
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
;
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
;
/// Get a read-only view of this block's storage
/// Get a read-only view of this block's storage
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
;
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
;
...
@@ -451,21 +458,34 @@ where
...
@@ -451,21 +458,34 @@ where
self
.layout
.num_layers
()
self
.layout
.num_layers
()
}
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
fn
num_outer_dims
(
&
self
)
->
usize
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
layer_idx
)
?
;
self
.layout
.outer_dim
()
unsafe
{
view
::
LayerView
::
new
(
self
,
offset
as
usize
,
self
.layout
.memory_region_size
())
}
}
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
layer_idx
)
?
;
let
mr
=
self
unsafe
{
view
::
LayerViewMut
::
new
(
self
,
offset
as
usize
,
self
.layout
.memory_region_size
())
}
.layout
.memory_region
(
self
.block_idx
,
layer_idx
,
outer_idx
)
?
;
unsafe
{
view
::
LayerView
::
new
(
self
,
mr
.addr
(),
mr
.size
())
}
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
layer_idx
,
outer_idx
)
?
;
unsafe
{
view
::
LayerViewMut
::
new
(
self
,
mr
.addr
(),
mr
.size
())
}
}
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
if
self
.is_fully_contiguous
()
{
if
self
.is_fully_contiguous
()
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
0
)
?
;
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
0
,
0
)
?
;
let
size
=
self
.layout
.memory_region_size
()
*
self
.layout
.num_layers
();
let
offset
=
mr
.addr
();
unsafe
{
view
::
BlockView
::
new
(
self
,
offset
as
usize
,
size
)
}
let
size
=
mr
.size
()
*
self
.num_layers
();
unsafe
{
view
::
BlockView
::
new
(
self
,
offset
,
size
)
}
}
else
{
}
else
{
Err
(
BlockError
::
InvalidState
(
Err
(
BlockError
::
InvalidState
(
"Block is not fully contiguous"
.to_string
(),
"Block is not fully contiguous"
.to_string
(),
...
@@ -475,9 +495,10 @@ where
...
@@ -475,9 +495,10 @@ where
fn
block_view_mut
(
&
mut
self
)
->
BlockResult
<
view
::
BlockViewMut
<
S
>>
{
fn
block_view_mut
(
&
mut
self
)
->
BlockResult
<
view
::
BlockViewMut
<
S
>>
{
if
self
.is_fully_contiguous
()
{
if
self
.is_fully_contiguous
()
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
0
)
?
;
let
mr
=
self
.layout
.memory_region
(
self
.block_idx
,
0
,
0
)
?
;
let
size
=
self
.layout
.memory_region_size
()
*
self
.layout
.num_layers
();
let
offset
=
mr
.addr
();
unsafe
{
view
::
BlockViewMut
::
new
(
self
,
offset
as
usize
,
size
)
}
let
size
=
mr
.size
()
*
self
.num_layers
();
unsafe
{
view
::
BlockViewMut
::
new
(
self
,
offset
,
size
)
}
}
else
{
}
else
{
Err
(
BlockError
::
InvalidState
(
Err
(
BlockError
::
InvalidState
(
"Block is not fully contiguous"
.to_string
(),
"Block is not fully contiguous"
.to_string
(),
...
@@ -626,12 +647,20 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableB
...
@@ -626,12 +647,20 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableB
self
.data
.num_layers
()
self
.data
.num_layers
()
}
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.data
.layer_view
(
layer_idx
)
self
.data
.num_outer_dims
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
self
.data
.layer_view
(
layer_idx
,
outer_idx
)
}
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
fn
layer_view_mut
(
self
.data
.layer_view_mut
(
layer_idx
)
&
mut
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
self
.data
.layer_view_mut
(
layer_idx
,
outer_idx
)
}
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
...
@@ -755,11 +784,15 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for Immutabl
...
@@ -755,11 +784,15 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for Immutabl
self
.block
.num_layers
()
self
.block
.num_layers
()
}
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.block
.
layer_view
(
layer_idx
)
self
.block
.
num_outer_dims
(
)
}
}
fn
layer_view_mut
(
&
mut
self
,
_
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
self
.block
.layer_view
(
layer_idx
,
outer_idx
)
}
fn
layer_view_mut
(
&
mut
self
,
_
:
usize
,
_
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
// This should never be called since ImmutableBlock is immutable,
// This should never be called since ImmutableBlock is immutable,
// but we need to implement the full trait
// but we need to implement the full trait
Err
(
BlockError
::
InvalidState
(
Err
(
BlockError
::
InvalidState
(
...
@@ -946,6 +979,7 @@ pub mod nixl {
...
@@ -946,6 +979,7 @@ pub mod nixl {
fn
as_layer_descriptor
(
fn
as_layer_descriptor
(
&
self
,
&
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
;
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
;
}
}
...
@@ -961,6 +995,7 @@ pub mod nixl {
...
@@ -961,6 +995,7 @@ pub mod nixl {
fn
as_layer_descriptor_mut
(
fn
as_layer_descriptor_mut
(
&
mut
self
,
&
mut
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
;
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
;
}
}
...
@@ -974,8 +1009,9 @@ pub mod nixl {
...
@@ -974,8 +1009,9 @@ pub mod nixl {
fn
as_layer_descriptor
(
fn
as_layer_descriptor
(
&
self
,
&
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
Ok
(
self
.layer_view
(
layer_idx
)
?
.as_nixl_descriptor
())
Ok
(
self
.layer_view
(
layer_idx
,
outer_idx
)
?
.as_nixl_descriptor
())
}
}
}
}
...
@@ -989,8 +1025,11 @@ pub mod nixl {
...
@@ -989,8 +1025,11 @@ pub mod nixl {
fn
as_layer_descriptor_mut
(
fn
as_layer_descriptor_mut
(
&
mut
self
,
&
mut
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
Ok
(
self
.layer_view_mut
(
layer_idx
)
?
.as_nixl_descriptor_mut
())
Ok
(
self
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
.as_nixl_descriptor_mut
())
}
}
}
}
...
@@ -1188,15 +1227,24 @@ pub mod nixl {
...
@@ -1188,15 +1227,24 @@ pub mod nixl {
self
.data
.num_layers
()
self
.data
.num_layers
()
}
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
NixlStorage
>>
{
fn
num_outer_dims
(
&
self
)
->
usize
{
self
.data
.layer_view
(
layer_idx
)
self
.data
.num_outer_dims
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerView
<
NixlStorage
>>
{
self
.data
.layer_view
(
layer_idx
,
outer_idx
)
}
}
fn
layer_view_mut
(
fn
layer_view_mut
(
&
mut
self
,
&
mut
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
NixlStorage
>>
{
)
->
BlockResult
<
view
::
LayerViewMut
<
NixlStorage
>>
{
self
.data
.layer_view_mut
(
layer_idx
)
self
.data
.layer_view_mut
(
layer_idx
,
outer_idx
)
}
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
NixlStorage
>>
{
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
NixlStorage
>>
{
...
@@ -1224,8 +1272,9 @@ pub mod nixl {
...
@@ -1224,8 +1272,9 @@ pub mod nixl {
fn
as_layer_descriptor
(
fn
as_layer_descriptor
(
&
self
,
&
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
self
.data
.as_layer_descriptor
(
layer_idx
)
self
.data
.as_layer_descriptor
(
layer_idx
,
outer_idx
)
}
}
}
}
...
@@ -1244,8 +1293,9 @@ pub mod nixl {
...
@@ -1244,8 +1293,9 @@ pub mod nixl {
fn
as_layer_descriptor_mut
(
fn
as_layer_descriptor_mut
(
&
mut
self
,
&
mut
self
,
layer_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
self
.data
.as_layer_descriptor_mut
(
layer_idx
)
self
.data
.as_layer_descriptor_mut
(
layer_idx
,
outer_idx
)
}
}
}
}
...
@@ -1733,7 +1783,8 @@ mod tests {
...
@@ -1733,7 +1783,8 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_blocks
(
10
)
.num_layers
(
2
)
.num_layers
(
3
)
.outer_dim
(
2
)
.page_size
(
4
)
.page_size
(
4
)
.inner_dim
(
13
)
.inner_dim
(
13
)
.build
()
.build
()
...
@@ -1780,6 +1831,7 @@ mod tests {
...
@@ -1780,6 +1831,7 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_blocks
(
10
)
.num_layers
(
2
)
.num_layers
(
2
)
.outer_dim
(
1
)
.page_size
(
4
)
.page_size
(
4
)
.inner_dim
(
13
)
.inner_dim
(
13
)
.build
()
.build
()
...
@@ -1803,12 +1855,12 @@ mod tests {
...
@@ -1803,12 +1855,12 @@ mod tests {
assert_eq!
(
mutable_block
.num_layers
(),
2
);
assert_eq!
(
mutable_block
.num_layers
(),
2
);
// Test layer_view()
// Test layer_view()
let
layer_view
=
mutable_block
.layer_view
(
0
)
.unwrap
();
let
layer_view
=
mutable_block
.layer_view
(
0
,
0
)
.unwrap
();
assert_eq!
(
layer_view
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert_eq!
(
layer_view
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert
!
(
!
unsafe
{
layer_view
.as_ptr
()
}
.is_null
());
assert
!
(
!
unsafe
{
layer_view
.as_ptr
()
}
.is_null
());
// Test layer_view_mut()
// Test layer_view_mut()
let
mut
layer_view_mut
=
mutable_block
.layer_view_mut
(
1
)
.unwrap
();
let
mut
layer_view_mut
=
mutable_block
.layer_view_mut
(
1
,
0
)
.unwrap
();
assert_eq!
(
layer_view_mut
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert_eq!
(
layer_view_mut
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert
!
(
!
unsafe
{
layer_view_mut
.as_mut_ptr
()
}
.is_null
());
assert
!
(
!
unsafe
{
layer_view_mut
.as_mut_ptr
()
}
.is_null
());
...
@@ -1833,6 +1885,7 @@ mod tests {
...
@@ -1833,6 +1885,7 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_blocks
(
10
)
.num_layers
(
2
)
.num_layers
(
2
)
.outer_dim
(
1
)
.page_size
(
4
)
.page_size
(
4
)
.inner_dim
(
13
)
.inner_dim
(
13
)
.build
()
.build
()
...
@@ -1860,7 +1913,7 @@ mod tests {
...
@@ -1860,7 +1913,7 @@ mod tests {
assert_eq!
(
immutable_block
.num_layers
(),
2
);
assert_eq!
(
immutable_block
.num_layers
(),
2
);
// Test layer_view()
// Test layer_view()
let
layer_view
=
immutable_block
.layer_view
(
0
)
.unwrap
();
let
layer_view
=
immutable_block
.layer_view
(
0
,
0
)
.unwrap
();
assert_eq!
(
layer_view
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert_eq!
(
layer_view
.size
(),
4
*
13
*
2
);
// page_size x inner_dim x dtype_bytes
assert
!
(
!
unsafe
{
layer_view
.as_ptr
()
}
.is_null
());
assert
!
(
!
unsafe
{
layer_view
.as_ptr
()
}
.is_null
());
...
@@ -1872,7 +1925,7 @@ mod tests {
...
@@ -1872,7 +1925,7 @@ mod tests {
// Test that mutable methods return errors
// Test that mutable methods return errors
let
mut
mut_immutable_block
=
immutable_block
;
// We need a mutable reference for these tests
let
mut
mut_immutable_block
=
immutable_block
;
// We need a mutable reference for these tests
let
layer_view_mut_res
=
mut_immutable_block
.layer_view_mut
(
0
);
let
layer_view_mut_res
=
mut_immutable_block
.layer_view_mut
(
0
,
0
);
assert
!
(
layer_view_mut_res
.is_err
());
assert
!
(
layer_view_mut_res
.is_err
());
if
let
Err
(
BlockError
::
InvalidState
(
msg
))
=
layer_view_mut_res
{
if
let
Err
(
BlockError
::
InvalidState
(
msg
))
=
layer_view_mut_res
{
assert
!
(
msg
.contains
(
"immutable block"
));
assert
!
(
msg
.contains
(
"immutable block"
));
...
...
lib/llm/src/block_manager/block/transfer/cuda.rs
View file @
80256acf
...
@@ -112,8 +112,9 @@ where
...
@@ -112,8 +112,9 @@ where
}
}
for
layer_idx
in
layer_range
{
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
...
@@ -126,6 +127,7 @@ where
...
@@ -126,6 +127,7 @@ where
)
?
;
)
?
;
}
}
}
}
}
Ok
(())
Ok
(())
}
}
...
...
lib/llm/src/block_manager/block/transfer/memcpy.rs
View file @
80256acf
...
@@ -57,14 +57,16 @@ where
...
@@ -57,14 +57,16 @@ where
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
for
layer_idx
in
layer_range
{
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
}
}
}
}
}
Ok
(())
Ok
(())
}
}
...
...
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
80256acf
...
@@ -132,8 +132,9 @@ where
...
@@ -132,8 +132,9 @@ where
// }
// }
for
layer_idx
in
layer_range
{
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
...
@@ -154,6 +155,7 @@ where
...
@@ -154,6 +155,7 @@ where
)
?
;
)
?
;
}
}
}
}
}
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
...
...
lib/llm/src/block_manager/config.rs
View file @
80256acf
...
@@ -71,6 +71,9 @@ pub struct KvManagerModelConfig {
...
@@ -71,6 +71,9 @@ pub struct KvManagerModelConfig {
#[validate(range(min
=
1
))]
#[validate(range(min
=
1
))]
pub
num_layers
:
usize
,
pub
num_layers
:
usize
,
#[validate(range(min
=
1
,
max
=
2
))]
pub
outer_dim
:
usize
,
#[validate(range(min
=
1
))]
#[validate(range(min
=
1
))]
pub
page_size
:
usize
,
pub
page_size
:
usize
,
...
...
lib/llm/src/block_manager/layout.rs
View file @
80256acf
...
@@ -84,6 +84,7 @@
...
@@ -84,6 +84,7 @@
//! let config = LayoutConfig::builder()
//! let config = LayoutConfig::builder()
//! .num_blocks(10)
//! .num_blocks(10)
//! .num_layers(4)
//! .num_layers(4)
//! .outer_dim(1)
//! .page_size(16)
//! .page_size(16)
//! .inner_dim(128)
//! .inner_dim(128)
//! .dtype(DType::FP16)
//! .dtype(DType::FP16)
...
@@ -109,8 +110,12 @@
...
@@ -109,8 +110,12 @@
//! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling
//! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling
//! layouts to be registered and serialized for use in distributed environments.
//! layouts to be registered and serialized for use in distributed environments.
// todo: coming soon...
// pub mod distributed;
pub
mod
nixl
;
pub
mod
nixl
;
use
derive_getters
::
Getters
;
use
thiserror
::
Error
;
use
thiserror
::
Error
;
use
crate
::
block_manager
::
storage
::{
Storage
,
StorageAllocator
};
use
crate
::
block_manager
::
storage
::{
Storage
,
StorageAllocator
};
...
@@ -138,6 +143,9 @@ pub enum LayoutError {
...
@@ -138,6 +143,9 @@ pub enum LayoutError {
#[error(
"Invalid layer index: {0}"
)]
#[error(
"Invalid layer index: {0}"
)]
InvalidLayerIndex
(
usize
),
InvalidLayerIndex
(
usize
),
#[error(
"Invalid outer index: {0}"
)]
InvalidOuterIndex
(
usize
),
#[error(
"Operation failed: {0}"
)]
#[error(
"Operation failed: {0}"
)]
OperationFailed
(
String
),
OperationFailed
(
String
),
...
@@ -165,10 +173,18 @@ pub enum LayoutType {
...
@@ -165,10 +173,18 @@ pub enum LayoutType {
// Null,
// Null,
}
}
/// Local Memory Region
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
Serialize,
Deserialize,
Getters)]
pub
struct
LocalMemoryRegion
{
#[getter(copy)]
addr
:
usize
,
#[getter(copy)]
size
:
usize
,
}
/// Core trait for block layouts
/// Core trait for block layouts
pub
trait
BlockLayout
:
pub
trait
BlockLayout
:
BlockLayoutConfig
+
Send
+
Sync
+
std
::
fmt
::
Debug
{
BlockLayoutConfig
+
BlockLayoutLookup
+
Send
+
Sync
+
std
::
fmt
::
Debug
{
/// The type of storage this layout uses
/// The type of storage this layout uses
type
StorageType
:
Storage
;
type
StorageType
:
Storage
;
...
@@ -180,6 +196,21 @@ pub trait BlockLayout:
...
@@ -180,6 +196,21 @@ pub trait BlockLayout:
/// Storage type for the layout
/// Storage type for the layout
fn
storage_type
(
&
self
)
->
StorageType
;
fn
storage_type
(
&
self
)
->
StorageType
;
/// Get the memory region for a specific page [page_size, inner_dim]
///
/// # Arguments
///
/// * `block_idx` - The index of the block
/// * `layer_idx` - The index of the layer
/// * `outer_idx` - The index of the outer dimension, e.g. if K and V are indexed separately there are 2 outer dimensions and `outer_idx` selects between them
///
fn
memory_region
(
&
self
,
block_idx
:
usize
,
layer_idx
:
usize
,
outer_idx
:
usize
,
)
->
Result
<
LocalMemoryRegion
,
LayoutError
>
;
}
}
/// Configuration for block layouts
/// Configuration for block layouts
...
@@ -193,6 +224,12 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
...
@@ -193,6 +224,12 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
/// Returns the number of layers per block
/// Returns the number of layers per block
fn
num_layers
(
&
self
)
->
usize
;
fn
num_layers
(
&
self
)
->
usize
;
/// Returns the number of outer dimensions per block.
/// In some cases K and V are indexed separately; in that case there are 2 outer dimensions.
/// For MLA, this is 1.
/// The location of the outer dimension in the shape of the tensor layout is defined by the layout type.
fn
outer_dim
(
&
self
)
->
usize
;
/// Returns the size of each block in bytes
/// Returns the size of each block in bytes
fn
page_size
(
&
self
)
->
usize
;
fn
page_size
(
&
self
)
->
usize
;
...
@@ -200,15 +237,6 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
...
@@ -200,15 +237,6 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
fn
inner_dim
(
&
self
)
->
usize
;
fn
inner_dim
(
&
self
)
->
usize
;
}
}
/// Lookup of memory-region addresses and sizes within a block layout.
pub trait BlockLayoutLookup {
    /// Returns the start address of the memory region identified by
    /// `block_idx` and `layer_idx`.
    ///
    /// # Errors
    ///
    /// Implementations report a region that cannot be resolved (for example an
    /// out-of-range index) as a [`LayoutError`].
    fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError>;

    /// Returns the size of a single memory region.
    fn memory_region_size(&self) -> usize;
}
/// Configuration for block layouts
/// Configuration for block layouts
#[derive(Debug,
Clone,
Builder,
Validate,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Builder,
Validate,
Serialize,
Deserialize)]
pub
struct
LayoutConfig
{
pub
struct
LayoutConfig
{
...
@@ -220,6 +248,10 @@ pub struct LayoutConfig {
...
@@ -220,6 +248,10 @@ pub struct LayoutConfig {
#[validate(range(min
=
1
))]
#[validate(range(min
=
1
))]
pub
num_layers
:
usize
,
pub
num_layers
:
usize
,
/// Number of outer dimensions
#[validate(range(min
=
1
,
max
=
2
))]
pub
outer_dim
:
usize
,
/// Page size
/// Page size
#[validate(range(min
=
1
))]
#[validate(range(min
=
1
))]
pub
page_size
:
usize
,
pub
page_size
:
usize
,
...
@@ -268,10 +300,24 @@ fn align_up(value: usize, alignment: usize) -> usize {
...
@@ -268,10 +300,24 @@ fn align_up(value: usize, alignment: usize) -> usize {
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
(
crate
)
struct
FullyContiguousConfig
{
pub
(
crate
)
struct
FullyContiguousConfig
{
inner
:
LayoutConfig
,
inner
:
LayoutConfig
,
/// Minimum contiguous memory region size
/// Inner dimension * page size * dtype size
memory_region_size
:
usize
,
memory_region_size
:
usize
,
/// Stride between outer dimensions
outer_dim_stride_in_bytes
:
usize
,
/// Stride between layers
layer_stride_in_bytes
:
usize
,
layer_stride_in_bytes
:
usize
,
/// Natural block stride
natural_block_stride
:
usize
,
natural_block_stride
:
usize
,
/// Block stride in bytes
block_stride_in_bytes
:
usize
,
// Aligned if necessary
block_stride_in_bytes
:
usize
,
// Aligned if necessary
/// Size of the layout data itself (post base offset)
layout_data_bytes
:
usize
,
// Size of the layout data itself (post base offset)
layout_data_bytes
:
usize
,
// Size of the layout data itself (post base offset)
}
}
...
@@ -284,7 +330,8 @@ impl FullyContiguousConfig {
...
@@ -284,7 +330,8 @@ impl FullyContiguousConfig {
let
alignment
=
config
.alignment
;
let
alignment
=
config
.alignment
;
let
memory_region_size
=
config
.page_size
*
config
.inner_dim
*
config
.dtype
.size_in_bytes
();
let
memory_region_size
=
config
.page_size
*
config
.inner_dim
*
config
.dtype
.size_in_bytes
();
let
layer_stride_in_bytes
=
memory_region_size
;
let
outer_dim_stride_in_bytes
=
memory_region_size
;
let
layer_stride_in_bytes
=
outer_dim_stride_in_bytes
*
config
.outer_dim
;
let
natural_block_stride
=
config
.num_layers
*
layer_stride_in_bytes
;
let
natural_block_stride
=
config
.num_layers
*
layer_stride_in_bytes
;
let
block_stride_in_bytes
=
if
alignment
>
1
{
let
block_stride_in_bytes
=
if
alignment
>
1
{
...
@@ -299,6 +346,7 @@ impl FullyContiguousConfig {
...
@@ -299,6 +346,7 @@ impl FullyContiguousConfig {
Ok
(
Self
{
Ok
(
Self
{
inner
:
config
,
inner
:
config
,
memory_region_size
,
memory_region_size
,
outer_dim_stride_in_bytes
,
layer_stride_in_bytes
,
layer_stride_in_bytes
,
natural_block_stride
,
natural_block_stride
,
block_stride_in_bytes
,
block_stride_in_bytes
,
...
@@ -331,6 +379,10 @@ impl BlockLayoutConfig for FullyContiguousConfig {
...
@@ -331,6 +379,10 @@ impl BlockLayoutConfig for FullyContiguousConfig {
self
.inner.num_layers
self
.inner.num_layers
}
}
/// Returns the outer dimension stored in the wrapped `LayoutConfig`.
fn outer_dim(&self) -> usize {
    let layout_config = &self.inner;
    layout_config.outer_dim
}
fn
page_size
(
&
self
)
->
usize
{
fn
page_size
(
&
self
)
->
usize
{
self
.inner.page_size
self
.inner.page_size
}
}
...
@@ -508,6 +560,39 @@ impl<S: Storage> BlockLayout for FullyContiguous<S> {
...
@@ -508,6 +560,39 @@ impl<S: Storage> BlockLayout for FullyContiguous<S> {
fn
storage_type
(
&
self
)
->
StorageType
{
fn
storage_type
(
&
self
)
->
StorageType
{
self
.storage_type
.clone
()
self
.storage_type
.clone
()
}
}
/// Resolves the memory region addressed by `(block_idx, layer_idx, outer_idx)`.
///
/// The address is the storage base address plus `base_offset`, advanced by the
/// block, layer and outer-dimension offsets computed from the configured
/// strides. The size is always `config.memory_region_size`.
///
/// # Errors
///
/// Indices are validated in this order; the first failing check is returned:
///
/// * `LayoutError::InvalidBlockIndex` if `block_idx >= self.num_blocks()`
/// * `LayoutError::InvalidLayerIndex` if `layer_idx >= self.num_layers()`
/// * `LayoutError::InvalidOuterIndex` if `outer_idx >= self.outer_dim()`
fn memory_region(
    &self,
    block_idx: usize,
    layer_idx: usize,
    outer_idx: usize,
) -> Result<LocalMemoryRegion, LayoutError> {
    // Bounds checks: block first, then layer, then outer dimension.
    let out_of_range = if block_idx >= self.num_blocks() {
        Some(LayoutError::InvalidBlockIndex(block_idx))
    } else if layer_idx >= self.num_layers() {
        Some(LayoutError::InvalidLayerIndex(layer_idx))
    } else if outer_idx >= self.outer_dim() {
        Some(LayoutError::InvalidOuterIndex(outer_idx))
    } else {
        None
    };
    if let Some(err) = out_of_range {
        return Err(err);
    }

    let cfg = &self.config;

    // Aligned start of the layout data inside the backing storage.
    let region_base = self.storage.addr() as usize + self.base_offset;

    // Per-index byte offsets, relative to the aligned start.
    let offsets = [
        block_idx * cfg.block_stride_in_bytes,
        layer_idx * cfg.layer_stride_in_bytes,
        outer_idx * cfg.outer_dim_stride_in_bytes,
    ];
    let addr = offsets.iter().fold(region_base, |acc, offset| acc + offset);

    Ok(LocalMemoryRegion {
        addr,
        size: cfg.memory_region_size,
    })
}
}
}
impl
<
S
:
Storage
>
BlockLayoutConfig
for
FullyContiguous
<
S
>
{
impl
<
S
:
Storage
>
BlockLayoutConfig
for
FullyContiguous
<
S
>
{
...
@@ -523,6 +608,10 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
...
@@ -523,6 +608,10 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
self
.config.inner.num_layers
self
.config.inner.num_layers
}
}
/// Returns the outer dimension read from `self.config.inner`.
fn outer_dim(&self) -> usize {
    let layout_config = &self.config.inner;
    layout_config.outer_dim
}
fn
page_size
(
&
self
)
->
usize
{
fn
page_size
(
&
self
)
->
usize
{
self
.config.inner.page_size
self
.config.inner.page_size
}
}
...
@@ -532,33 +621,6 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
...
@@ -532,33 +621,6 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
}
}
}
}
impl<S: Storage> BlockLayoutLookup for FullyContiguous<S> {
    /// Returns the start address of the region for `block_idx` / `layer_idx`.
    ///
    /// The address is the storage base address plus `base_offset`, advanced by
    /// the block and layer offsets computed from the configured strides.
    ///
    /// # Errors
    ///
    /// * `LayoutError::InvalidBlockIndex` if `block_idx >= self.num_blocks()`
    /// * `LayoutError::InvalidLayerIndex` if `layer_idx >= self.num_layers()`
    fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError> {
        if block_idx >= self.num_blocks() {
            Err(LayoutError::InvalidBlockIndex(block_idx))
        } else if layer_idx >= self.num_layers() {
            Err(LayoutError::InvalidLayerIndex(layer_idx))
        } else {
            let cfg = &self.config;

            // Aligned start of the layout data inside the backing storage.
            let base = self.storage.addr() + self.base_offset as u64;

            // Byte offsets relative to the aligned start, from the stored strides.
            let offsets = [
                block_idx * cfg.block_stride_in_bytes,
                layer_idx * cfg.layer_stride_in_bytes,
            ];
            Ok(offsets
                .iter()
                .fold(base, |addr, &offset| addr + offset as u64))
        }
    }

    /// Returns the configured size of a single memory region.
    fn memory_region_size(&self) -> usize {
        let cfg = &self.config;
        cfg.memory_region_size
    }
}
#[allow(missing_docs)]
#[allow(missing_docs)]
#[cfg(test)]
#[cfg(test)]
pub
mod
tests
{
pub
mod
tests
{
...
@@ -570,6 +632,7 @@ pub mod tests {
...
@@ -570,6 +632,7 @@ pub mod tests {
const
NUM_BLOCKS
:
usize
=
7
;
const
NUM_BLOCKS
:
usize
=
7
;
const
NUM_LAYERS
:
usize
=
5
;
const
NUM_LAYERS
:
usize
=
5
;
const
OUTER_DIM
:
usize
=
2
;
const
PAGE_SIZE
:
usize
=
4
;
const
PAGE_SIZE
:
usize
=
4
;
const
INNER_DIM
:
usize
=
13
;
const
INNER_DIM
:
usize
=
13
;
const
DTYPE
:
DType
=
DType
::
FP32
;
// Example dtype
const
DTYPE
:
DType
=
DType
::
FP32
;
// Example dtype
...
@@ -592,6 +655,7 @@ pub mod tests {
...
@@ -592,6 +655,7 @@ pub mod tests {
let
config
=
LayoutConfig
{
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
inner_dim
:
INNER_DIM
,
alignment
:
alignment
.unwrap_or
(
1
),
alignment
:
alignment
.unwrap_or
(
1
),
...
@@ -606,6 +670,7 @@ pub mod tests {
...
@@ -606,6 +670,7 @@ pub mod tests {
let
config
=
LayoutConfig
::
builder
()
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
NUM_BLOCKS
)
.num_blocks
(
NUM_BLOCKS
)
.num_layers
(
NUM_LAYERS
)
.num_layers
(
NUM_LAYERS
)
.outer_dim
(
OUTER_DIM
)
.page_size
(
PAGE_SIZE
)
.page_size
(
PAGE_SIZE
)
.inner_dim
(
INNER_DIM
)
.inner_dim
(
INNER_DIM
)
.alignment
(
3
)
.alignment
(
3
)
...
@@ -632,6 +697,7 @@ pub mod tests {
...
@@ -632,6 +697,7 @@ pub mod tests {
let
config
=
LayoutConfig
{
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
inner_dim
:
INNER_DIM
,
alignment
:
1
,
alignment
:
1
,
...
@@ -656,17 +722,11 @@ pub mod tests {
...
@@ -656,17 +722,11 @@ pub mod tests {
assert_eq!
(
layout
.num_blocks
(),
NUM_BLOCKS
);
assert_eq!
(
layout
.num_blocks
(),
NUM_BLOCKS
);
assert_eq!
(
layout
.num_layers
(),
NUM_LAYERS
);
assert_eq!
(
layout
.num_layers
(),
NUM_LAYERS
);
assert_eq!
(
layout
.outer_dim
(),
OUTER_DIM
);
assert_eq!
(
layout
.page_size
(),
PAGE_SIZE
);
assert_eq!
(
layout
.page_size
(),
PAGE_SIZE
);
assert_eq!
(
layout
.inner_dim
(),
INNER_DIM
);
assert_eq!
(
layout
.inner_dim
(),
INNER_DIM
);
}
}
/// The region size reported by the layout must equal
/// `PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes()`.
#[test]
fn test_fc_memory_region_size() {
    let layout = setup_layout(None).expect("Layout setup failed");

    let bytes_per_element = DTYPE.size_in_bytes();
    assert_eq!(
        layout.memory_region_size(),
        PAGE_SIZE * INNER_DIM * bytes_per_element
    );
}
#[test]
#[test]
fn
test_fc_offset_calculation
()
{
fn
test_fc_offset_calculation
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
...
@@ -680,7 +740,7 @@ pub mod tests {
...
@@ -680,7 +740,7 @@ pub mod tests {
let
expected_offset_0_0
=
let
expected_offset_0_0
=
calculate_expected_offset
(
base_addr
,
0
,
0
,
block_stride
,
layer_stride
);
calculate_expected_offset
(
base_addr
,
0
,
0
,
block_stride
,
layer_stride
);
assert_eq!
(
assert_eq!
(
layout
.memory_region
_addr
(
0
,
0
)
.unwrap
(),
layout
.memory_region
(
0
,
0
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_0_0
expected_offset_0_0
);
);
...
@@ -689,7 +749,7 @@ pub mod tests {
...
@@ -689,7 +749,7 @@ pub mod tests {
let
expected_offset_0_last
=
let
expected_offset_0_last
=
calculate_expected_offset
(
base_addr
,
0
,
last_layer_idx
,
block_stride
,
layer_stride
);
calculate_expected_offset
(
base_addr
,
0
,
last_layer_idx
,
block_stride
,
layer_stride
);
assert_eq!
(
assert_eq!
(
layout
.memory_region
_addr
(
0
,
last_layer_idx
)
.unwrap
(),
layout
.memory_region
(
0
,
last_layer_idx
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_0_last
expected_offset_0_last
);
);
...
@@ -698,7 +758,7 @@ pub mod tests {
...
@@ -698,7 +758,7 @@ pub mod tests {
let
expected_offset_last_0
=
let
expected_offset_last_0
=
calculate_expected_offset
(
base_addr
,
last_block_idx
,
0
,
block_stride
,
layer_stride
);
calculate_expected_offset
(
base_addr
,
last_block_idx
,
0
,
block_stride
,
layer_stride
);
assert_eq!
(
assert_eq!
(
layout
.memory_region
_addr
(
last_block_idx
,
0
)
.unwrap
(),
layout
.memory_region
(
last_block_idx
,
0
,
0
)
.unwrap
()
.addr
as
u64
,
expected_offset_last_0
expected_offset_last_0
);
);
...
@@ -712,8 +772,9 @@ pub mod tests {
...
@@ -712,8 +772,9 @@ pub mod tests {
);
);
assert_eq!
(
assert_eq!
(
layout
layout
.memory_region_addr
(
last_block_idx
,
last_layer_idx
)
.memory_region
(
last_block_idx
,
last_layer_idx
,
0
)
.unwrap
(),
.unwrap
()
.addr
as
u64
,
expected_offset_last_last
expected_offset_last_last
);
);
...
@@ -729,8 +790,9 @@ pub mod tests {
...
@@ -729,8 +790,9 @@ pub mod tests {
);
);
assert_eq!
(
assert_eq!
(
layout
layout
.memory_region_addr
(
mid_block_idx
,
mid_layer_idx
)
.memory_region
(
mid_block_idx
,
mid_layer_idx
,
0
)
.unwrap
(),
.unwrap
()
.addr
as
u64
,
expected_offset_mid_mid
expected_offset_mid_mid
);
);
}
}
...
@@ -738,7 +800,7 @@ pub mod tests {
...
@@ -738,7 +800,7 @@ pub mod tests {
#[test]
#[test]
fn
test_fc_invalid_block_index
()
{
fn
test_fc_invalid_block_index
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
result
=
layout
.memory_region
_addr
(
NUM_BLOCKS
,
0
);
// Index == num_blocks (out of bounds)
let
result
=
layout
.memory_region
(
NUM_BLOCKS
,
0
,
0
);
// Index == num_blocks (out of bounds)
assert
!
(
result
.is_err
());
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
assert
!
(
matches!
(
result
.err
()
.unwrap
(),
result
.err
()
.unwrap
(),
...
@@ -749,7 +811,7 @@ pub mod tests {
...
@@ -749,7 +811,7 @@ pub mod tests {
#[test]
#[test]
fn
test_fc_invalid_layer_index
()
{
fn
test_fc_invalid_layer_index
()
{
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
layout
=
setup_layout
(
None
)
.expect
(
"Layout setup failed"
);
let
result
=
layout
.memory_region
_addr
(
0
,
NUM_LAYERS
);
// Index == num_layers (out of bounds)
let
result
=
layout
.memory_region
(
0
,
NUM_LAYERS
,
0
);
// Index == num_layers (out of bounds)
assert
!
(
result
.is_err
());
assert
!
(
result
.is_err
());
assert
!
(
matches!
(
assert
!
(
matches!
(
result
.err
()
.unwrap
(),
result
.err
()
.unwrap
(),
...
@@ -757,12 +819,24 @@ pub mod tests {
...
@@ -757,12 +819,24 @@ pub mod tests {
));
));
}
}
/// An outer index equal to `OUTER_DIM` is out of bounds and must be rejected
/// with `LayoutError::InvalidOuterIndex`.
#[test]
fn test_fc_invalid_outer_index() {
    let layout = setup_layout(None).expect("Layout setup failed");

    // Index == num_outer_dims (out of bounds)
    let error = layout.memory_region(0, 0, OUTER_DIM).err();

    assert!(error.is_some());
    assert!(matches!(
        error,
        Some(LayoutError::InvalidOuterIndex(OUTER_DIM))
    ));
}
#[test]
#[test]
fn
test_fc_allocation_system
()
{
fn
test_fc_allocation_system
()
{
init_logging
();
init_logging
();
let
config
=
LayoutConfig
{
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
inner_dim
:
INNER_DIM
,
alignment
:
1
,
alignment
:
1
,
...
@@ -788,7 +862,7 @@ pub mod tests {
...
@@ -788,7 +862,7 @@ pub mod tests {
assert_eq!
(
assert_eq!
(
layout
.storage
.size
(),
layout
.storage
.size
(),
NUM_BLOCKS
*
NUM_LAYERS
*
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
()
NUM_BLOCKS
*
NUM_LAYERS
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
()
);
);
}
}
...
@@ -800,6 +874,7 @@ pub mod tests {
...
@@ -800,6 +874,7 @@ pub mod tests {
let
config
=
LayoutConfig
{
let
config
=
LayoutConfig
{
num_blocks
:
NUM_BLOCKS
,
num_blocks
:
NUM_BLOCKS
,
num_layers
:
NUM_LAYERS
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
OUTER_DIM
,
page_size
:
PAGE_SIZE
,
page_size
:
PAGE_SIZE
,
inner_dim
:
INNER_DIM
,
inner_dim
:
INNER_DIM
,
alignment
:
ALIGNMENT
,
alignment
:
ALIGNMENT
,
...
@@ -810,11 +885,11 @@ pub mod tests {
...
@@ -810,11 +885,11 @@ pub mod tests {
let
memory_region_size
=
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
();
let
memory_region_size
=
PAGE_SIZE
*
INNER_DIM
*
DTYPE
.size_in_bytes
();
assert_eq!
(
memory_region_size
,
208
);
assert_eq!
(
memory_region_size
,
208
);
let
natural_block_stride
=
NUM_LAYERS
*
memory_region_size
;
let
natural_block_stride
=
OUTER_DIM
*
NUM_LAYERS
*
memory_region_size
;
assert_eq!
(
natural_block_stride
,
104
0
);
assert_eq!
(
natural_block_stride
,
208
0
);
let
aligned_block_stride
=
align_up
(
natural_block_stride
,
ALIGNMENT
);
let
aligned_block_stride
=
align_up
(
natural_block_stride
,
ALIGNMENT
);
assert_eq!
(
aligned_block_stride
,
1280
);
assert_eq!
(
aligned_block_stride
,
2304
);
// Calculate the expected *allocated* size (data + initial padding)
// Calculate the expected *allocated* size (data + initial padding)
let
fc_config
=
FullyContiguousConfig
::
new
(
config
.clone
())
.unwrap
();
let
fc_config
=
FullyContiguousConfig
::
new
(
config
.clone
())
.unwrap
();
...
@@ -844,40 +919,40 @@ pub mod tests {
...
@@ -844,40 +919,40 @@ pub mod tests {
// Check alignment of block starts
// Check alignment of block starts
let
addr_block_0
=
layout
let
addr_block_0
=
layout
.memory_region
_addr
(
0
,
0
)
.memory_region
(
0
,
0
,
0
)
.expect
(
"Failed to get addr block 0"
);
.expect
(
"Failed to get addr block 0"
);
let
addr_block_1
=
layout
let
addr_block_1
=
layout
.memory_region
_addr
(
1
,
0
)
.memory_region
(
1
,
0
,
0
)
.expect
(
"Failed to get addr block 1"
);
.expect
(
"Failed to get addr block 1"
);
let
addr_block_2
=
layout
let
addr_block_2
=
layout
.memory_region
_addr
(
2
,
0
)
.memory_region
(
2
,
0
,
0
)
.expect
(
"Failed to get addr block 2"
);
.expect
(
"Failed to get addr block 2"
);
// All blocks should now be aligned due to base_offset adjustment
// All blocks should now be aligned due to base_offset adjustment
assert_eq!
(
assert_eq!
(
addr_block_0
%
ALIGNMENT
as
u64
,
addr_block_0
.addr
as
u64
%
ALIGNMENT
as
u64
,
0
,
0
,
"Block 0 start address is not aligned"
"Block 0 start address is not aligned"
);
);
assert_eq!
(
assert_eq!
(
addr_block_1
%
ALIGNMENT
as
u64
,
addr_block_1
.addr
as
u64
%
ALIGNMENT
as
u64
,
0
,
0
,
"Block 1 start address is not aligned"
"Block 1 start address is not aligned"
);
);
assert_eq!
(
assert_eq!
(
addr_block_2
%
ALIGNMENT
as
u64
,
addr_block_2
.addr
as
u64
%
ALIGNMENT
as
u64
,
0
,
0
,
"Block 2 start address is not aligned"
"Block 2 start address is not aligned"
);
);
// Verify the difference matches the aligned stride
// Verify the difference matches the aligned stride
assert_eq!
(
assert_eq!
(
addr_block_1
-
addr_block_0
,
addr_block_1
.addr
as
u64
-
addr_block_0
.addr
as
u64
,
aligned_block_stride
as
u64
,
aligned_block_stride
as
u64
,
"Stride between block 0 and 1 mismatch"
"Stride between block 0 and 1 mismatch"
);
);
assert_eq!
(
assert_eq!
(
addr_block_2
-
addr_block_1
,
addr_block_2
.addr
as
u64
-
addr_block_1
.addr
as
u64
,
aligned_block_stride
as
u64
,
aligned_block_stride
as
u64
,
"Stride between block 1 and 2 mismatch"
"Stride between block 1 and 2 mismatch"
);
);
...
...
lib/llm/src/block_manager/layout/nixl.rs
View file @
80256acf
...
@@ -332,6 +332,7 @@ mod tests {
...
@@ -332,6 +332,7 @@ mod tests {
let
config
=
LayoutConfig
::
builder
()
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_blocks
(
10
)
.num_layers
(
2
)
.num_layers
(
2
)
.outer_dim
(
2
)
.page_size
(
4
)
.page_size
(
4
)
.inner_dim
(
13
)
.inner_dim
(
13
)
.build
()
.build
()
...
...
lib/llm/src/block_manager/offload.rs
View file @
80256acf
...
@@ -506,6 +506,7 @@ mod tests {
...
@@ -506,6 +506,7 @@ mod tests {
let
mut
config
=
LayoutConfig
{
let
mut
config
=
LayoutConfig
{
num_blocks
:
device_blocks
,
num_blocks
:
device_blocks
,
num_layers
:
8
,
num_layers
:
8
,
outer_dim
:
1
,
page_size
:
BLOCK_SIZE
,
page_size
:
BLOCK_SIZE
,
inner_dim
:
1024
,
inner_dim
:
1024
,
alignment
:
1
,
alignment
:
1
,
...
...
lib/llm/src/block_manager/pool/inactive.rs
View file @
80256acf
...
@@ -582,6 +582,7 @@ pub(crate) mod tests {
...
@@ -582,6 +582,7 @@ pub(crate) mod tests {
let
config
=
LayoutConfigBuilder
::
default
()
let
config
=
LayoutConfigBuilder
::
default
()
.num_blocks
(
num_blocks
)
.num_blocks
(
num_blocks
)
.num_layers
(
61
)
.num_layers
(
61
)
.outer_dim
(
1
)
.page_size
(
16
)
.page_size
(
16
)
.inner_dim
(
576
)
.inner_dim
(
576
)
.build
()
.build
()
...
...
lib/llm/src/block_manager/state.rs
View file @
80256acf
...
@@ -115,6 +115,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -115,6 +115,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
layout_builder
layout_builder
.num_layers
(
model
.num_layers
)
.num_layers
(
model
.num_layers
)
.outer_dim
(
model
.outer_dim
)
.page_size
(
model
.page_size
)
.page_size
(
model
.page_size
)
.inner_dim
(
model
.inner_dim
)
.inner_dim
(
model
.inner_dim
)
.dtype
(
model
.dtype
);
.dtype
(
model
.dtype
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment