Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
976bb70a
Unverified
Commit
976bb70a
authored
Feb 17, 2026
by
Ryan Olson
Committed by
GitHub
Feb 18, 2026
Browse files
feat: add KVBM memory management enhancements (DIS-1311) (#5532)
parent
57bdfea9
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
805 additions
and
31 deletions
+805
-31
lib/memory/src/tensor.rs
lib/memory/src/tensor.rs
+623
-0
lib/memory/src/tests.rs
lib/memory/src/tests.rs
+182
-2
lib/memory/src/torch.rs
lib/memory/src/torch.rs
+0
-29
No files found.
lib/memory/src/tensor.rs
0 → 100644
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tensor abstraction built on top of MemoryDescriptor.
//!
//! A tensor is memory with shape, stride, and element size metadata.
//! The underlying memory could be externally owned, self-owned, or a view.
use
super
::
nixl
::{
self
,
NixlDescriptor
};
use
super
::{
MemoryDescriptor
,
StorageKind
};
use
std
::
any
::
Any
;
use
std
::
sync
::
Arc
;
/// A tensor is memory with shape, stride, and element size metadata.
///
/// This trait extends [`MemoryDescriptor`] with tensor-specific metadata.
/// The underlying memory could be externally owned, self-owned, or a view.
///
/// # Shape and Stride
///
/// - `shape()` returns the number of elements in each dimension
/// - `stride()` returns the number of elements to skip when incrementing each dimension
/// - `element_size()` returns the number of bytes per element
///
/// For a contiguous tensor with shape `[2, 3, 4]`:
/// - stride would be `[12, 4, 1]` (row-major/C order)
/// - total elements = 2 * 3 * 4 = 24
/// - total bytes = 24 * element_size()
pub
trait
TensorDescriptor
:
MemoryDescriptor
{
/// Shape of the tensor (number of elements per dimension).
fn
shape
(
&
self
)
->
&
[
usize
];
/// Stride of the tensor (elements to skip per dimension).
///
/// `stride[i]` indicates how many elements to skip when incrementing dimension `i`.
fn
stride
(
&
self
)
->
&
[
usize
];
/// Number of bytes per element.
fn
element_size
(
&
self
)
->
usize
;
}
// =============================================================================
// Helper methods for TensorDescriptor
// =============================================================================
/// Extension trait providing helper methods for tensor descriptors.
pub
trait
TensorDescriptorExt
:
TensorDescriptor
{
/// Total number of elements in the tensor (product of shape).
fn
numel
(
&
self
)
->
usize
{
self
.shape
()
.iter
()
.product
()
}
/// Number of dimensions (rank).
fn
ndim
(
&
self
)
->
usize
{
self
.shape
()
.len
()
}
/// Check if tensor is contiguous in memory (row-major/C order).
///
/// A tensor is contiguous if its strides follow the pattern where
/// the last dimension has stride 1, and each preceding dimension
/// has stride equal to the product of all following dimensions.
fn
is_contiguous
(
&
self
)
->
bool
{
let
shape
=
self
.shape
();
let
stride
=
self
.stride
();
if
shape
.is_empty
()
{
return
true
;
}
let
mut
expected_stride
=
1
;
for
i
in
(
0
..
shape
.len
())
.rev
()
{
if
stride
[
i
]
!=
expected_stride
{
return
false
;
}
expected_stride
*=
shape
[
i
];
}
true
}
/// Compute the contiguous stride for the current shape.
///
/// Returns the stride that would make this tensor contiguous
/// (row-major/C order).
fn
contiguous_stride
(
&
self
)
->
Vec
<
usize
>
{
let
shape
=
self
.shape
();
if
shape
.is_empty
()
{
return
vec!
[];
}
let
mut
stride
=
vec!
[
1
;
shape
.len
()];
for
i
in
(
0
..
shape
.len
()
-
1
)
.rev
()
{
stride
[
i
]
=
stride
[
i
+
1
]
*
shape
[
i
+
1
];
}
stride
}
/// Returns the CUDA device ID if the tensor is on a CUDA device.
fn
cuda_device_id
(
&
self
)
->
Option
<
usize
>
{
match
self
.storage_kind
()
{
StorageKind
::
Device
(
idx
)
=>
Some
(
idx
as
usize
),
_
=>
None
,
}
}
}
// Blanket impl for all TensorDescriptor types
impl
<
T
:
TensorDescriptor
+
?
Sized
>
TensorDescriptorExt
for
T
{}
// =============================================================================
// Arc<dyn TensorDescriptor> support for NixlRegisterExt
// =============================================================================
impl
nixl
::
NixlCompatible
for
Arc
<
dyn
TensorDescriptor
>
{
fn
nixl_params
(
&
self
)
->
(
*
const
u8
,
usize
,
nixl
::
MemType
,
u64
)
{
let
storage
=
self
.storage_kind
();
let
(
mem_type
,
device_id
)
=
match
storage
{
StorageKind
::
Device
(
idx
)
=>
(
nixl
::
MemType
::
Vram
,
idx
as
u64
),
StorageKind
::
System
=>
(
nixl
::
MemType
::
Dram
,
0
),
StorageKind
::
Pinned
=>
(
nixl
::
MemType
::
Dram
,
0
),
StorageKind
::
Disk
(
fd
)
=>
(
nixl
::
MemType
::
File
,
fd
),
};
(
self
.addr
()
as
*
const
u8
,
self
.size
(),
mem_type
,
device_id
)
}
}
impl
MemoryDescriptor
for
Arc
<
dyn
TensorDescriptor
>
{
fn
addr
(
&
self
)
->
usize
{
(
**
self
)
.addr
()
}
fn
size
(
&
self
)
->
usize
{
(
**
self
)
.size
()
}
fn
storage_kind
(
&
self
)
->
StorageKind
{
(
**
self
)
.storage_kind
()
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
None
}
}
impl
TensorDescriptor
for
Arc
<
dyn
TensorDescriptor
>
{
fn
shape
(
&
self
)
->
&
[
usize
]
{
(
**
self
)
.shape
()
}
fn
stride
(
&
self
)
->
&
[
usize
]
{
(
**
self
)
.stride
()
}
fn
element_size
(
&
self
)
->
usize
{
(
**
self
)
.element_size
()
}
}
// =============================================================================
// Arc<dyn TensorDescriptor + Send + Sync> support
// =============================================================================
impl
nixl
::
NixlCompatible
for
Arc
<
dyn
TensorDescriptor
+
Send
+
Sync
>
{
fn
nixl_params
(
&
self
)
->
(
*
const
u8
,
usize
,
nixl
::
MemType
,
u64
)
{
let
storage
=
self
.storage_kind
();
let
(
mem_type
,
device_id
)
=
match
storage
{
StorageKind
::
Device
(
idx
)
=>
(
nixl
::
MemType
::
Vram
,
idx
as
u64
),
StorageKind
::
System
=>
(
nixl
::
MemType
::
Dram
,
0
),
StorageKind
::
Pinned
=>
(
nixl
::
MemType
::
Dram
,
0
),
StorageKind
::
Disk
(
fd
)
=>
(
nixl
::
MemType
::
File
,
fd
),
};
(
self
.addr
()
as
*
const
u8
,
self
.size
(),
mem_type
,
device_id
)
}
}
impl
MemoryDescriptor
for
Arc
<
dyn
TensorDescriptor
+
Send
+
Sync
>
{
fn
addr
(
&
self
)
->
usize
{
(
**
self
)
.addr
()
}
fn
size
(
&
self
)
->
usize
{
(
**
self
)
.size
()
}
fn
storage_kind
(
&
self
)
->
StorageKind
{
(
**
self
)
.storage_kind
()
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
None
}
}
impl
TensorDescriptor
for
Arc
<
dyn
TensorDescriptor
+
Send
+
Sync
>
{
fn
shape
(
&
self
)
->
&
[
usize
]
{
(
**
self
)
.shape
()
}
fn
stride
(
&
self
)
->
&
[
usize
]
{
(
**
self
)
.stride
()
}
fn
element_size
(
&
self
)
->
usize
{
(
**
self
)
.element_size
()
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
/// Simple test tensor for unit tests
#[derive(Debug)]
struct
TestTensor
{
addr
:
usize
,
size
:
usize
,
shape
:
Vec
<
usize
>
,
stride
:
Vec
<
usize
>
,
element_size
:
usize
,
}
impl
MemoryDescriptor
for
TestTensor
{
fn
addr
(
&
self
)
->
usize
{
self
.addr
}
fn
size
(
&
self
)
->
usize
{
self
.size
}
fn
storage_kind
(
&
self
)
->
StorageKind
{
StorageKind
::
System
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
None
}
}
impl
TensorDescriptor
for
TestTensor
{
fn
shape
(
&
self
)
->
&
[
usize
]
{
&
self
.shape
}
fn
stride
(
&
self
)
->
&
[
usize
]
{
&
self
.stride
}
fn
element_size
(
&
self
)
->
usize
{
self
.element_size
}
}
#[test]
fn
test_numel
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
24
*
4
,
// 24 elements * 4 bytes
shape
:
vec!
[
2
,
3
,
4
],
stride
:
vec!
[
12
,
4
,
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
24
);
}
#[test]
fn
test_ndim
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
24
*
4
,
shape
:
vec!
[
2
,
3
,
4
],
stride
:
vec!
[
12
,
4
,
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.ndim
(),
3
);
}
#[test]
fn
test_is_contiguous_true
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
24
*
4
,
shape
:
vec!
[
2
,
3
,
4
],
stride
:
vec!
[
12
,
4
,
1
],
// Contiguous stride
element_size
:
4
,
};
assert
!
(
tensor
.is_contiguous
());
}
#[test]
fn
test_is_contiguous_false
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
24
*
4
,
shape
:
vec!
[
2
,
3
,
4
],
stride
:
vec!
[
24
,
4
,
1
],
// Non-contiguous (gap between first dim)
element_size
:
4
,
};
assert
!
(
!
tensor
.is_contiguous
());
}
#[test]
fn
test_contiguous_stride
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
24
*
4
,
shape
:
vec!
[
2
,
3
,
4
],
stride
:
vec!
[
24
,
4
,
1
],
// Non-contiguous
element_size
:
4
,
};
assert_eq!
(
tensor
.contiguous_stride
(),
vec!
[
12
,
4
,
1
]);
}
#[test]
fn
test_empty_tensor
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
0
,
shape
:
vec!
[],
stride
:
vec!
[],
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
1
);
// Empty product is 1
assert_eq!
(
tensor
.ndim
(),
0
);
assert
!
(
tensor
.is_contiguous
());
}
#[test]
fn
test_1d_tensor_contiguous
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
10
*
4
,
shape
:
vec!
[
10
],
stride
:
vec!
[
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
10
);
assert_eq!
(
tensor
.ndim
(),
1
);
assert
!
(
tensor
.is_contiguous
());
assert_eq!
(
tensor
.contiguous_stride
(),
vec!
[
1
]);
}
#[test]
fn
test_1d_tensor_non_contiguous
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
10
*
4
,
shape
:
vec!
[
10
],
stride
:
vec!
[
2
],
// Strided access (every other element)
element_size
:
4
,
};
assert
!
(
!
tensor
.is_contiguous
());
}
#[test]
fn
test_2d_tensor
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
6
*
4
,
shape
:
vec!
[
2
,
3
],
stride
:
vec!
[
3
,
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
6
);
assert_eq!
(
tensor
.ndim
(),
2
);
assert
!
(
tensor
.is_contiguous
());
}
#[test]
fn
test_high_dimensional_tensor
()
{
// 5D tensor: [2, 3, 4, 5, 6]
let
shape
=
vec!
[
2
,
3
,
4
,
5
,
6
];
// Contiguous stride: [360, 120, 30, 6, 1]
let
stride
=
vec!
[
360
,
120
,
30
,
6
,
1
];
let
numel
:
usize
=
shape
.iter
()
.product
();
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
numel
*
4
,
shape
,
stride
,
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
720
);
assert_eq!
(
tensor
.ndim
(),
5
);
assert
!
(
tensor
.is_contiguous
());
assert_eq!
(
tensor
.contiguous_stride
(),
vec!
[
360
,
120
,
30
,
6
,
1
]);
}
#[test]
fn
test_tensor_with_size_1_dimensions
()
{
// Shape with singleton dimensions: [1, 3, 1, 4]
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
12
*
4
,
shape
:
vec!
[
1
,
3
,
1
,
4
],
stride
:
vec!
[
12
,
4
,
4
,
1
],
// Contiguous for this shape
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
12
);
assert_eq!
(
tensor
.ndim
(),
4
);
assert
!
(
tensor
.is_contiguous
());
}
#[test]
fn
test_contiguous_stride_empty
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
0
,
shape
:
vec!
[],
stride
:
vec!
[],
element_size
:
4
,
};
assert
!
(
tensor
.contiguous_stride
()
.is_empty
());
}
#[test]
fn
test_contiguous_stride_1d
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
5
*
4
,
shape
:
vec!
[
5
],
stride
:
vec!
[
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.contiguous_stride
(),
vec!
[
1
]);
}
#[test]
fn
test_cuda_device_id_system
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
100
,
shape
:
vec!
[
10
],
stride
:
vec!
[
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.cuda_device_id
(),
None
);
}
/// Test tensor that reports Device storage kind
#[derive(Debug)]
struct
DeviceTensor
{
addr
:
usize
,
size
:
usize
,
shape
:
Vec
<
usize
>
,
stride
:
Vec
<
usize
>
,
element_size
:
usize
,
device_id
:
u32
,
}
impl
MemoryDescriptor
for
DeviceTensor
{
fn
addr
(
&
self
)
->
usize
{
self
.addr
}
fn
size
(
&
self
)
->
usize
{
self
.size
}
fn
storage_kind
(
&
self
)
->
StorageKind
{
StorageKind
::
Device
(
self
.device_id
)
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
None
}
}
impl
TensorDescriptor
for
DeviceTensor
{
fn
shape
(
&
self
)
->
&
[
usize
]
{
&
self
.shape
}
fn
stride
(
&
self
)
->
&
[
usize
]
{
&
self
.stride
}
fn
element_size
(
&
self
)
->
usize
{
self
.element_size
}
}
#[test]
fn
test_cuda_device_id_device
()
{
let
tensor
=
DeviceTensor
{
addr
:
0x1000
,
size
:
100
,
shape
:
vec!
[
10
],
stride
:
vec!
[
1
],
element_size
:
4
,
device_id
:
2
,
};
assert_eq!
(
tensor
.cuda_device_id
(),
Some
(
2
));
}
#[test]
fn
test_arc_tensor_descriptor
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
24
*
4
,
shape
:
vec!
[
2
,
3
,
4
],
stride
:
vec!
[
12
,
4
,
1
],
element_size
:
4
,
};
let
arc
:
Arc
<
dyn
TensorDescriptor
>
=
Arc
::
new
(
tensor
);
assert_eq!
(
arc
.addr
(),
0x1000
);
assert_eq!
(
arc
.size
(),
24
*
4
);
assert_eq!
(
arc
.shape
(),
&
[
2
,
3
,
4
]);
assert_eq!
(
arc
.stride
(),
&
[
12
,
4
,
1
]);
assert_eq!
(
arc
.element_size
(),
4
);
assert_eq!
(
arc
.storage_kind
(),
StorageKind
::
System
);
assert
!
(
arc
.nixl_descriptor
()
.is_none
());
}
#[test]
fn
test_arc_tensor_send_sync
()
{
// TestTensor doesn't impl Send+Sync, so we need a type that does
struct
SendSyncTensor
{
addr
:
usize
,
size
:
usize
,
shape
:
Vec
<
usize
>
,
stride
:
Vec
<
usize
>
,
element_size
:
usize
,
}
unsafe
impl
Send
for
SendSyncTensor
{}
unsafe
impl
Sync
for
SendSyncTensor
{}
impl
std
::
fmt
::
Debug
for
SendSyncTensor
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
f
.debug_struct
(
"SendSyncTensor"
)
.finish
()
}
}
impl
MemoryDescriptor
for
SendSyncTensor
{
fn
addr
(
&
self
)
->
usize
{
self
.addr
}
fn
size
(
&
self
)
->
usize
{
self
.size
}
fn
storage_kind
(
&
self
)
->
StorageKind
{
StorageKind
::
System
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
None
}
}
impl
TensorDescriptor
for
SendSyncTensor
{
fn
shape
(
&
self
)
->
&
[
usize
]
{
&
self
.shape
}
fn
stride
(
&
self
)
->
&
[
usize
]
{
&
self
.stride
}
fn
element_size
(
&
self
)
->
usize
{
self
.element_size
}
}
let
tensor
=
SendSyncTensor
{
addr
:
0x2000
,
size
:
100
,
shape
:
vec!
[
10
],
stride
:
vec!
[
1
],
element_size
:
4
,
};
let
arc
:
Arc
<
dyn
TensorDescriptor
+
Send
+
Sync
>
=
Arc
::
new
(
tensor
);
assert_eq!
(
arc
.addr
(),
0x2000
);
assert_eq!
(
arc
.size
(),
100
);
assert_eq!
(
arc
.shape
(),
&
[
10
]);
assert_eq!
(
arc
.stride
(),
&
[
1
]);
assert_eq!
(
arc
.element_size
(),
4
);
}
#[test]
fn
test_tensor_shape_stride_element_size
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
48
,
shape
:
vec!
[
3
,
4
],
stride
:
vec!
[
4
,
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.shape
(),
&
[
3
,
4
]);
assert_eq!
(
tensor
.stride
(),
&
[
4
,
1
]);
assert_eq!
(
tensor
.element_size
(),
4
);
}
#[test]
fn
test_tensor_numel_single_element
()
{
let
tensor
=
TestTensor
{
addr
:
0x1000
,
size
:
4
,
shape
:
vec!
[
1
,
1
,
1
],
stride
:
vec!
[
1
,
1
,
1
],
element_size
:
4
,
};
assert_eq!
(
tensor
.numel
(),
1
);
}
}
lib/memory/src/tests.rs
View file @
976bb70a
...
...
@@ -7,13 +7,13 @@ use super::*;
/// Helper function to validate NIXL descriptor consistency.
///
/// For any MemoryDescript
ion
that returns Some from nixl_descriptor(),
/// For any MemoryDescript
or
that returns Some from nixl_descriptor(),
/// this validates that the descriptor's addr and size match the memory region's addr and size.
///
/// # Panics
/// Panics if descriptor values don't match memory region values.
#[allow(dead_code)]
fn
validate_nixl_descriptor
<
M
:
MemoryDescript
ion
>
(
memory
:
&
M
)
{
fn
validate_nixl_descriptor
<
M
:
MemoryDescript
or
>
(
memory
:
&
M
)
{
if
let
Some
(
desc
)
=
memory
.nixl_descriptor
()
{
assert_eq!
(
desc
.addr
as
usize
,
...
...
@@ -32,6 +32,186 @@ fn validate_nixl_descriptor<M: MemoryDescription>(memory: &M) {
}
}
// ========== StorageKind tests ==========
#[test]
fn
test_storage_kind_cuda_device_index_device
()
{
let
kind
=
StorageKind
::
Device
(
3
);
assert_eq!
(
kind
.cuda_device_index
(),
Some
(
3
));
}
#[test]
fn
test_storage_kind_cuda_device_index_system
()
{
let
kind
=
StorageKind
::
System
;
assert_eq!
(
kind
.cuda_device_index
(),
None
);
}
#[test]
fn
test_storage_kind_cuda_device_index_pinned
()
{
let
kind
=
StorageKind
::
Pinned
;
assert_eq!
(
kind
.cuda_device_index
(),
None
);
}
#[test]
fn
test_storage_kind_cuda_device_index_disk
()
{
let
kind
=
StorageKind
::
Disk
(
123
);
assert_eq!
(
kind
.cuda_device_index
(),
None
);
}
#[test]
fn
test_storage_kind_is_cuda
()
{
assert
!
(
StorageKind
::
Device
(
0
)
.is_cuda
());
assert
!
(
!
StorageKind
::
System
.is_cuda
());
assert
!
(
!
StorageKind
::
Pinned
.is_cuda
());
assert
!
(
!
StorageKind
::
Disk
(
1
)
.is_cuda
());
}
#[test]
fn
test_storage_kind_is_system
()
{
assert
!
(
StorageKind
::
System
.is_system
());
assert
!
(
!
StorageKind
::
Device
(
0
)
.is_system
());
assert
!
(
!
StorageKind
::
Pinned
.is_system
());
assert
!
(
!
StorageKind
::
Disk
(
1
)
.is_system
());
}
#[test]
fn
test_storage_kind_is_pinned
()
{
assert
!
(
StorageKind
::
Pinned
.is_pinned
());
assert
!
(
!
StorageKind
::
System
.is_pinned
());
assert
!
(
!
StorageKind
::
Device
(
0
)
.is_pinned
());
assert
!
(
!
StorageKind
::
Disk
(
1
)
.is_pinned
());
}
#[test]
fn
test_storage_kind_is_disk
()
{
assert
!
(
StorageKind
::
Disk
(
1
)
.is_disk
());
assert
!
(
!
StorageKind
::
System
.is_disk
());
assert
!
(
!
StorageKind
::
Pinned
.is_disk
());
assert
!
(
!
StorageKind
::
Device
(
0
)
.is_disk
());
}
// ========== Buffer tests ==========
#[test]
fn
test_buffer_new
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
buffer
=
Buffer
::
new
(
storage
);
assert_eq!
(
buffer
.size
(),
1024
);
assert_eq!
(
buffer
.storage_kind
(),
StorageKind
::
System
);
}
#[test]
fn
test_buffer_from_arc
()
{
use
std
::
sync
::
Arc
;
let
storage
=
SystemStorage
::
new
(
2048
)
.unwrap
();
let
arc
:
Arc
<
dyn
MemoryDescriptor
>
=
Arc
::
new
(
storage
);
let
buffer
=
Buffer
::
from_arc
(
arc
);
assert_eq!
(
buffer
.size
(),
2048
);
}
#[test]
fn
test_buffer_from_impl
()
{
use
std
::
sync
::
Arc
;
let
storage
=
SystemStorage
::
new
(
512
)
.unwrap
();
let
arc
:
Arc
<
dyn
MemoryDescriptor
>
=
Arc
::
new
(
storage
);
let
buffer
:
Buffer
=
arc
.into
();
assert_eq!
(
buffer
.size
(),
512
);
}
#[test]
fn
test_buffer_deref
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
buffer
=
Buffer
::
new
(
storage
);
// Deref allows calling MemoryDescriptor methods directly
let
size
=
buffer
.size
();
assert_eq!
(
size
,
1024
);
}
#[test]
fn
test_buffer_debug
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
buffer
=
Buffer
::
new
(
storage
);
let
debug_str
=
format!
(
"{:?}"
,
buffer
);
assert
!
(
debug_str
.contains
(
"Buffer"
));
assert
!
(
debug_str
.contains
(
"size"
));
assert
!
(
debug_str
.contains
(
"addr"
));
}
#[test]
fn
test_buffer_clone
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
buffer
=
Buffer
::
new
(
storage
);
let
cloned
=
buffer
.clone
();
assert_eq!
(
buffer
.addr
(),
cloned
.addr
());
assert_eq!
(
buffer
.size
(),
cloned
.size
());
}
// ========== MemoryRegion tests ==========
#[test]
fn
test_memory_region_new
()
{
let
region
=
MemoryRegion
::
new
(
0x1000
,
4096
);
assert_eq!
(
region
.addr
,
0x1000
);
assert_eq!
(
region
.size
,
4096
);
}
#[test]
fn
test_memory_region_accessors
()
{
let
region
=
MemoryRegion
::
new
(
0x2000
,
8192
);
assert_eq!
(
region
.addr
(),
0x2000
);
assert_eq!
(
region
.size
(),
8192
);
}
#[test]
fn
test_memory_region_zero_address
()
{
let
region
=
MemoryRegion
::
new
(
0
,
1024
);
assert_eq!
(
region
.addr
(),
0
);
assert_eq!
(
region
.size
(),
1024
);
}
#[test]
fn
test_memory_region_zero_size
()
{
let
region
=
MemoryRegion
::
new
(
0x1000
,
0
);
assert_eq!
(
region
.addr
(),
0x1000
);
assert_eq!
(
region
.size
(),
0
);
}
#[test]
fn
test_memory_region_clone
()
{
let
region
=
MemoryRegion
::
new
(
0x3000
,
2048
);
let
cloned
=
region
;
assert_eq!
(
region
.addr
(),
cloned
.addr
());
assert_eq!
(
region
.size
(),
cloned
.size
());
}
#[test]
fn
test_memory_region_eq
()
{
let
region1
=
MemoryRegion
::
new
(
0x1000
,
4096
);
let
region2
=
MemoryRegion
::
new
(
0x1000
,
4096
);
let
region3
=
MemoryRegion
::
new
(
0x2000
,
4096
);
assert_eq!
(
region1
,
region2
);
assert_ne!
(
region1
,
region3
);
}
#[test]
fn
test_memory_region_debug
()
{
let
region
=
MemoryRegion
::
new
(
0x1000
,
4096
);
let
debug_str
=
format!
(
"{:?}"
,
region
);
assert
!
(
debug_str
.contains
(
"MemoryRegion"
));
}
// ========== create_buffer helper tests ==========
#[test]
fn
test_create_buffer_helper
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
buffer
=
create_buffer
(
storage
);
assert_eq!
(
buffer
.size
(),
1024
);
assert_eq!
(
buffer
.storage_kind
(),
StorageKind
::
System
);
}
// ========== Original tests ==========
#[test]
fn
test_system_storage
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
...
...
lib/memory/src/torch.rs
deleted
100644 → 0
View file @
57bdfea9
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#[derive(Clone,
Debug,
PartialEq,
Eq)]
pub
enum
TorchDevice
{
Cuda
(
usize
),
Other
(
String
),
}
impl
TorchDevice
{
pub
fn
is_cuda
(
&
self
)
->
bool
{
matches!
(
self
,
TorchDevice
::
Cuda
(
_
))
}
pub
fn
cuda_device_index
(
&
self
)
->
Option
<
usize
>
{
match
self
{
TorchDevice
::
Cuda
(
index
)
=>
Some
(
*
index
),
TorchDevice
::
Other
(
_
)
=>
None
,
}
}
}
pub
trait
TorchTensor
:
std
::
fmt
::
Debug
+
Send
+
Sync
{
fn
device
(
&
self
)
->
TorchDevice
;
fn
data_ptr
(
&
self
)
->
u64
;
fn
size_bytes
(
&
self
)
->
usize
;
fn
shape
(
&
self
)
->
Vec
<
usize
>
;
fn
stride
(
&
self
)
->
Vec
<
usize
>
;
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment