Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f0652d89
Unverified
Commit
f0652d89
authored
Jul 01, 2025
by
Yan Ru Pei
Committed by
GitHub
Jul 01, 2025
Browse files
feat: vllm mocker enhancement (#1236)
parent
0d6cae85
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1654 additions
and
410 deletions
+1654
-410
lib/llm/src/mocker.rs
lib/llm/src/mocker.rs
+1
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+764
-0
lib/llm/src/mocker/evictor.rs
lib/llm/src/mocker/evictor.rs
+96
-105
lib/llm/src/mocker/kv_manager.rs
lib/llm/src/mocker/kv_manager.rs
+166
-70
lib/llm/src/mocker/protocols.rs
lib/llm/src/mocker/protocols.rs
+86
-4
lib/llm/src/mocker/scheduler.rs
lib/llm/src/mocker/scheduler.rs
+424
-172
lib/llm/src/mocker/sequence.rs
lib/llm/src/mocker/sequence.rs
+117
-59
No files found.
lib/llm/src/mocker.rs
View file @
f0652d89
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
pub
mod
engine
;
pub
mod
evictor
;
pub
mod
evictor
;
pub
mod
kv_manager
;
pub
mod
kv_manager
;
pub
mod
protocols
;
pub
mod
protocols
;
...
...
lib/llm/src/mocker/engine.rs
0 → 100644
View file @
f0652d89
This diff is collapsed.
Click to expand it.
lib/llm/src/mocker/evictor.rs
View file @
f0652d89
...
@@ -13,167 +13,158 @@
...
@@ -13,167 +13,158 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
std
::
cmp
::
Eq
;
use
std
::
cmp
::
{
Eq
,
Ordering
}
;
use
std
::
collections
::{
HashMap
,
VecDeque
};
use
std
::
collections
::{
BTreeSet
,
HashMap
};
use
std
::
hash
::
Hash
;
use
std
::
hash
::
Hash
;
use
std
::
time
::
Instant
;
/// A wrapper for (T, counter) that implements Ord based only on counter
#[derive(Debug,
Clone,
Eq,
PartialEq)]
struct
PriorityItem
<
T
>
{
item
:
T
,
counter
:
i64
,
}
impl
<
T
:
Eq
>
Ord
for
PriorityItem
<
T
>
{
fn
cmp
(
&
self
,
other
:
&
Self
)
->
Ordering
{
self
.counter
.cmp
(
&
other
.counter
)
}
}
impl
<
T
:
Eq
>
PartialOrd
for
PriorityItem
<
T
>
{
fn
partial_cmp
(
&
self
,
other
:
&
Self
)
->
Option
<
Ordering
>
{
Some
(
self
.cmp
(
other
))
}
}
/// An LRU evictor that maintains objects and evicts them based on their
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// priority counter. Lower counter values are evicted first.
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. The user must ensure objects are added in correct priority (temporal order)
/// 4. Remove and update operations are lazy - entries remain in the queue until
/// they are either evicted or cleaned up during maintenance
#[derive(Debug)]
#[derive(Debug)]
pub
struct
LRUEvictor
<
T
:
Clone
+
Eq
+
Hash
>
{
pub
struct
LRUEvictor
<
T
:
Clone
+
Eq
+
Hash
>
{
free_table
:
HashMap
<
T
,
f
64
>
,
free_table
:
HashMap
<
T
,
i
64
>
,
priority_queue
:
VecDeque
<
(
T
,
f64
)
>
,
priority_queue
:
BTreeSet
<
PriorityItem
<
T
>
>
,
cleanup_threshold
:
usize
,
positive_counter
:
i64
,
start_time
:
Instant
,
negative_counter
:
i64
,
}
}
impl
<
T
:
Clone
+
Eq
+
Hash
>
Default
for
LRUEvictor
<
T
>
{
impl
<
T
:
Clone
+
Eq
+
Hash
>
Default
for
LRUEvictor
<
T
>
{
fn
default
()
->
Self
{
fn
default
()
->
Self
{
Self
{
Self
{
free_table
:
HashMap
::
new
(),
free_table
:
HashMap
::
new
(),
priority_queue
:
VecDeque
::
new
(),
priority_queue
:
BTreeSet
::
new
(),
cleanup_threshold
:
5
0
,
positive_counter
:
0
,
start_time
:
Instant
::
now
()
,
negative_counter
:
0
,
}
}
}
}
}
}
impl
<
T
:
Clone
+
Eq
+
Hash
>
LRUEvictor
<
T
>
{
impl
<
T
:
Clone
+
Eq
+
Hash
>
LRUEvictor
<
T
>
{
/// Create a new LRUEvictor with the default cleanup threshold
pub
fn
new
(
_
cleanup_threshold
:
usize
)
->
Self
{
pub
fn
new
(
cleanup_threshold
:
usize
)
->
Self
{
Self
::
default
()
Self
{
cleanup_threshold
,
..
Default
::
default
()
}
}
}
/// Get the current timestamp as seconds since initialization
pub
fn
keys
(
&
self
)
->
std
::
collections
::
hash_map
::
Keys
<
'_
,
T
,
i64
>
{
pub
fn
current_timestamp
(
&
self
)
->
f64
{
self
.free_table
.keys
()
self
.start_time
.elapsed
()
.as_secs_f64
()
}
}
/// Get an iterator over the keys in the evictor
fn
update
(
&
mut
self
,
object
:
T
,
counter
:
i64
)
{
pub
fn
keys
(
&
self
)
->
std
::
collections
::
hash_map
::
Keys
<
'_
,
T
,
f64
>
{
self
.free_table
.insert
(
object
.clone
(),
counter
);
self
.free_table
.keys
()
self
.priority_queue
.insert
(
PriorityItem
{
item
:
object
,
counter
,
});
}
}
/// Insert or update an object in the evictor with current timestamp
pub
fn
insert
(
&
mut
self
,
object
:
T
)
{
pub
fn
insert
(
&
mut
self
,
object
:
T
)
{
let
timestamp
=
self
.current_timestamp
();
// Remove old entry if it exists
self
._insert
(
object
,
timestamp
);
if
let
Some
(
&
old_counter
)
=
self
.free_table
.get
(
&
object
)
{
}
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
:
old_counter
,
});
}
/// Check if the evictor contains the given object
// Increment positive counter and insert
pub
fn
contains
(
&
self
,
object
:
&
T
)
->
bool
{
self
.positive_counter
+=
1
;
self
.free_table
.contains_key
(
object
)
let
counter
=
self
.positive_counter
;
self
.update
(
object
,
counter
);
}
}
/// Evict an object based on LRU policy
/// Push an object to the front with negative counter (highest priority for eviction)
/// Returns the evicted object or None if no objects are available
pub
fn
push_front
(
&
mut
self
,
object
:
T
)
{
pub
fn
evict
(
&
mut
self
)
->
Option
<
T
>
{
// Remove old entry if it exists
if
self
.free_table
.is_empty
()
{
if
let
Some
(
&
old_counter
)
=
self
.free_table
.get
(
&
object
)
{
return
None
;
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
:
old_counter
,
});
}
}
while
let
Some
((
object
,
last_accessed
))
=
self
.priority_queue
.pop_front
()
{
// Decrement negative counter and insert
let
Some
(
&
current_last_accessed
)
=
self
.free_table
.get
(
&
object
)
else
{
self
.negative_counter
-=
1
;
continue
;
// entry is already removed
let
counter
=
self
.negative_counter
;
};
if
current_last_accessed
==
last_accessed
{
self
.update
(
object
,
counter
);
self
.free_table
.remove
(
&
object
);
}
return
Some
(
object
);
}
// otherwise entry is stale
}
None
pub
fn
contains
(
&
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.contains_key
(
object
)
}
}
/// Insert or update an object in the evictor
/// Evict an object based on LRU policy (lowest counter value)
fn
_
insert
(
&
mut
self
,
object
:
T
,
last_accessed
:
f64
)
{
/// Returns the evicted object or None if no objects are available
self
.free_table
.insert
(
object
.clone
(),
last_accessed
);
pub
fn
evict
(
&
mut
self
)
->
Option
<
T
>
{
self
.priority_queue
.push_back
((
object
,
last_accessed
));
self
.priority_queue
.pop_first
()
.map
(|
item
|
{
self
.cleanup_if_necessary
();
self
.free_table
.remove
(
&
item
.item
);
item
.item
})
}
}
/// Remove an object from the evictor
/// We don't remove from the priority queue immediately, as that would be inefficient
/// Outdated entries will be filtered out during eviction or cleanup
pub
fn
remove
(
&
mut
self
,
object
:
&
T
)
->
bool
{
pub
fn
remove
(
&
mut
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.remove
(
object
)
.is_some
()
let
Some
(
&
counter
)
=
self
.free_table
.get
(
object
)
else
{
return
false
;
};
self
.free_table
.remove
(
object
);
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
,
});
true
}
}
/// Get the number of objects in the evictor
pub
fn
len
(
&
self
)
->
usize
{
pub
fn
len
(
&
self
)
->
usize
{
self
.free_table
.len
()
self
.free_table
.len
()
}
}
/// Check if the evictor is empty
pub
fn
is_empty
(
&
self
)
->
bool
{
pub
fn
is_empty
(
&
self
)
->
bool
{
self
.free_table
.is_empty
()
self
.free_table
.is_empty
()
}
}
/// Check if cleanup is necessary and perform it if needed
fn
cleanup_if_necessary
(
&
mut
self
)
{
if
self
.priority_queue
.len
()
>
self
.cleanup_threshold
*
self
.free_table
.len
()
{
self
.cleanup
();
}
}
/// Clean up the priority queue by removing outdated entries
fn
cleanup
(
&
mut
self
)
{
let
mut
new_priority_queue
=
VecDeque
::
new
();
for
(
object
,
timestamp
)
in
self
.priority_queue
.drain
(
..
)
{
let
Some
(
&
current_timestamp
)
=
self
.free_table
.get
(
&
object
)
else
{
continue
;
};
if
current_timestamp
==
timestamp
{
new_priority_queue
.push_back
((
object
,
timestamp
));
}
}
self
.priority_queue
=
new_priority_queue
;
}
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
rstest
::
rstest
;
#[rstest]
#[test]
#[case(
1
)]
fn
test_lru_evictor_eviction_order
()
{
#[case(
2
)]
// Create a new LRUEvictor
#[case(
3
)]
let
mut
evictor
=
LRUEvictor
::
<
i32
>
::
new
(
1
);
// threshold value doesn't matter anymore
fn
test_lru_evictor_eviction_order
(
#[case]
threshold
:
usize
)
{
// Create a new LRUEvictor with the given cleanup threshold
let
mut
evictor
=
LRUEvictor
::
<
i32
>
::
new
(
threshold
);
// Add items in the specified order
with small delays between each
// Add items in the specified order
evictor
.insert
(
4
);
evictor
.insert
(
4
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
3
);
evictor
.insert
(
3
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
evictor
.insert
(
2
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
1
);
evictor
.insert
(
1
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
5
);
evictor
.insert
(
5
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
1
);
// Updates counter for 1
evictor
.insert
(
1
);
// Updates timestamp for 1
evictor
.insert
(
4
);
// Updates counter for 4
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
// Updates counter for 2
evictor
.insert
(
4
);
// Updates timestamp for 4
evictor
.push_front
(
4
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
// Updates timestamp for 2
// Verify the eviction order
// Verify the eviction order
println!
(
"Testing with threshold {}"
,
threshold
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
4
);
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
3
);
assert_eq!
(
evicted
,
3
);
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
...
@@ -181,11 +172,11 @@ mod tests {
...
@@ -181,11 +172,11 @@ mod tests {
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
1
);
assert_eq!
(
evicted
,
1
);
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
4
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
2
);
assert_eq!
(
evicted
,
2
);
let
evicted
=
evictor
.evict
();
let
evicted
=
evictor
.evict
();
assert_eq!
(
evicted
,
None
);
assert_eq!
(
evicted
,
None
);
assert_eq!
(
evictor
.len
(),
0
);
assert_eq!
(
evictor
.len
(),
0
);
}
}
// ... existing test_push_front test ...
}
}
lib/llm/src/mocker/kv_manager.rs
View file @
f0652d89
...
@@ -46,10 +46,11 @@
...
@@ -46,10 +46,11 @@
//! implementation of the main block manager.
//! implementation of the main block manager.
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
MoveBlockResponse
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
derive_getters
::
Getters
;
use
derive_getters
::
Getters
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
collections
::{
HashMap
,
HashSet
};
use
tokio
::
sync
::
mpsc
;
#[derive(Getters)]
#[derive(Getters)]
pub
struct
KvManager
{
pub
struct
KvManager
{
...
@@ -57,17 +58,27 @@ pub struct KvManager {
...
@@ -57,17 +58,27 @@ pub struct KvManager {
max_capacity
:
usize
,
max_capacity
:
usize
,
#[getter(copy)]
#[getter(copy)]
block_size
:
u
32
,
block_size
:
u
size
,
active_blocks
:
HashMap
<
UniqueBlock
,
usize
>
,
active_blocks
:
HashMap
<
UniqueBlock
,
usize
>
,
inactive_blocks
:
LRUEvictor
<
UniqueBlock
>
,
inactive_blocks
:
LRUEvictor
<
UniqueBlock
>
,
all_blocks
:
HashSet
<
UniqueBlock
>
,
all_blocks
:
HashSet
<
UniqueBlock
>
,
move_block_response_tx
:
Option
<
mpsc
::
UnboundedSender
<
MoveBlockResponse
>>
,
}
}
impl
KvManager
{
impl
KvManager
{
pub
fn
new
(
max_capacity
:
usize
,
block_size
:
u32
)
->
Self
{
pub
fn
new
(
max_capacity
:
usize
,
block_size
:
usize
)
->
Self
{
Self
::
new_with_sender
(
max_capacity
,
block_size
,
None
)
}
pub
fn
new_with_sender
(
max_capacity
:
usize
,
block_size
:
usize
,
move_block_response_tx
:
Option
<
mpsc
::
UnboundedSender
<
MoveBlockResponse
>>
,
)
->
Self
{
let
active_blocks
=
HashMap
::
new
();
let
active_blocks
=
HashMap
::
new
();
let
inactive_blocks
=
LRUEvictor
::
default
();
let
inactive_blocks
=
LRUEvictor
::
default
();
let
all_blocks
=
HashSet
::
new
();
let
all_blocks
=
HashSet
::
new
();
...
@@ -78,18 +89,46 @@ impl KvManager {
...
@@ -78,18 +89,46 @@ impl KvManager {
active_blocks
,
active_blocks
,
inactive_blocks
,
inactive_blocks
,
all_blocks
,
all_blocks
,
move_block_response_tx
,
}
}
/// Utility method to send block responses with optional reversing
fn
send_block_response
(
&
self
,
mut
blocks
:
Vec
<
u64
>
,
reverse
:
bool
,
store
:
bool
,
parent_hash
:
Option
<
u64
>
,
)
{
if
let
Some
(
ref
tx
)
=
self
.move_block_response_tx
{
if
!
blocks
.is_empty
()
{
if
reverse
{
blocks
.reverse
();
}
let
response
=
if
store
{
MoveBlockResponse
::
Store
(
blocks
,
parent_hash
)
}
else
{
MoveBlockResponse
::
Remove
(
blocks
)
};
tx
.send
(
response
)
.unwrap
();
}
}
}
}
}
/// Process a MoveBlock instruction synchronously
/// Process a MoveBlock instruction synchronously
pub
fn
process
(
&
mut
self
,
event
:
&
MoveBlock
)
->
bool
{
pub
fn
process
(
&
mut
self
,
event
:
&
MoveBlock
)
->
bool
{
match
event
{
match
event
{
MoveBlock
::
Use
(
hashes
,
_
)
=>
{
MoveBlock
::
Use
(
hashes
)
=>
{
let
mut
blocks_stored
=
Vec
::
<
u64
>
::
new
();
let
mut
parent_block
:
Option
<&
UniqueBlock
>
=
None
;
for
hash
in
hashes
{
for
hash
in
hashes
{
// First check if it already exists in active blocks
// First check if it already exists in active blocks
if
let
Some
(
ref_count
)
=
self
.active_blocks
.get_mut
(
hash
)
{
if
let
Some
(
ref_count
)
=
self
.active_blocks
.get_mut
(
hash
)
{
// Block already active, just increment reference count
// Block already active, just increment reference count
*
ref_count
+=
1
;
*
ref_count
+=
1
;
parent_block
=
Some
(
hash
);
continue
;
continue
;
}
}
...
@@ -97,6 +136,7 @@ impl KvManager {
...
@@ -97,6 +136,7 @@ impl KvManager {
if
self
.inactive_blocks
.remove
(
hash
)
{
if
self
.inactive_blocks
.remove
(
hash
)
{
// Insert into active with reference count 1
// Insert into active with reference count 1
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
parent_block
=
Some
(
hash
);
continue
;
continue
;
}
}
...
@@ -106,30 +146,53 @@ impl KvManager {
...
@@ -106,30 +146,53 @@ impl KvManager {
// If at max capacity, evict the oldest entry from inactive blocks
// If at max capacity, evict the oldest entry from inactive blocks
if
active_count
+
inactive_count
>=
self
.max_capacity
{
if
active_count
+
inactive_count
>=
self
.max_capacity
{
if
let
Some
(
evicted
)
=
self
.inactive_blocks
.evict
()
{
let
Some
(
evicted
)
=
self
.inactive_blocks
.evict
()
else
{
// Remove evicted block from all_blocks
self
.all_blocks
.remove
(
&
evicted
);
}
else
{
// Cannot evict block, meaning no free blocks left in inactive pool
// Send a signal, scheduler would expect to handle preemption upon receiving this
return
false
;
return
false
;
};
self
.all_blocks
.remove
(
&
evicted
);
if
let
UniqueBlock
::
FullBlock
(
evicted_full_block
)
=
evicted
{
self
.send_block_response
(
vec!
[
evicted_full_block
],
false
,
false
,
None
);
}
}
}
}
// Now insert the new block in active blocks with reference count 1
// Now insert the new block in active blocks with reference count 1
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
// Add to all_blocks as it's a new block
self
.all_blocks
.insert
(
hash
.clone
());
self
.all_blocks
.insert
(
hash
.clone
());
if
self
.move_block_response_tx
.is_some
()
{
if
let
UniqueBlock
::
FullBlock
(
stored_full_block
)
=
hash
{
blocks_stored
.push
(
*
stored_full_block
);
}
}
}
}
let
parent_hash
=
match
parent_block
{
None
=>
None
,
Some
(
UniqueBlock
::
FullBlock
(
block
))
=>
Some
(
*
block
),
Some
(
UniqueBlock
::
PartialBlock
(
_
))
=>
panic!
(
"parent block cannot be partial"
),
};
self
.send_block_response
(
blocks_stored
,
false
,
true
,
parent_hash
);
}
}
MoveBlock
::
Destroy
(
hashes
)
=>
{
MoveBlock
::
Destroy
(
hashes
)
=>
{
let
mut
blocks_destroyed
=
Vec
::
<
u64
>
::
new
();
// Loop in inverse direction
// Loop in inverse direction
for
hash
in
hashes
.iter
()
.rev
()
{
for
hash
in
hashes
.iter
()
.rev
()
{
self
.active_blocks
.remove
(
hash
)
.unwrap
();
self
.active_blocks
.remove
(
hash
)
.unwrap
();
// Remove from all_blocks when destroyed
// Remove from all_blocks when destroyed
assert
!
(
self
.all_blocks
.remove
(
hash
));
assert
!
(
self
.all_blocks
.remove
(
hash
));
// Track blocks for batch sending
if
self
.move_block_response_tx
.is_some
()
{
if
let
UniqueBlock
::
FullBlock
(
destroyed_full_block
)
=
hash
{
blocks_destroyed
.push
(
*
destroyed_full_block
);
}
}
}
}
self
.send_block_response
(
blocks_destroyed
,
true
,
false
,
None
);
}
}
MoveBlock
::
Deref
(
hashes
)
=>
{
MoveBlock
::
Deref
(
hashes
)
=>
{
// Loop in inverse direction
// Loop in inverse direction
for
hash
in
hashes
.iter
()
.rev
()
{
for
hash
in
hashes
.iter
()
.rev
()
{
...
@@ -149,15 +212,15 @@ impl KvManager {
...
@@ -149,15 +212,15 @@ impl KvManager {
}
}
}
}
}
}
MoveBlock
::
Promote
(
uuid
,
hash
)
=>
{
MoveBlock
::
Promote
(
uuid
,
hash
,
parent_hash
)
=>
{
let
uuid_block
=
UniqueBlock
::
PartialBlock
(
*
uuid
);
let
uuid_block
=
UniqueBlock
::
PartialBlock
(
*
uuid
);
let
hash_block
=
UniqueBlock
::
FullBlock
(
*
hash
);
let
hash_block
=
UniqueBlock
::
FullBlock
(
*
hash
);
let
Some
(
ref_count
)
=
self
.active_blocks
.remove
(
&
uuid_block
)
else
{
let
Some
(
ref_count
)
=
self
.active_blocks
.remove
(
&
uuid_block
)
else
{
let
in_all_blocks
=
self
.all_blocks
.contains
(
&
uuid_block
);
let
in_all_blocks
=
self
.all_blocks
.contains
(
&
uuid_block
);
panic!
(
panic!
(
"Missing active block for promotion: {:?}. Block still exists: {}"
,
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
uuid_block
,
in_all_blocks
);
);
};
};
...
@@ -167,6 +230,7 @@ impl KvManager {
...
@@ -167,6 +230,7 @@ impl KvManager {
// Update all_blocks
// Update all_blocks
assert
!
(
self
.all_blocks
.remove
(
&
uuid_block
));
assert
!
(
self
.all_blocks
.remove
(
&
uuid_block
));
self
.all_blocks
.insert
(
hash_block
);
self
.all_blocks
.insert
(
hash_block
);
self
.send_block_response
(
vec!
[
*
hash
],
false
,
true
,
*
parent_hash
);
}
}
}
}
...
@@ -178,6 +242,7 @@ impl KvManager {
...
@@ -178,6 +242,7 @@ impl KvManager {
pub
fn
probe_new_blocks
(
&
self
,
blocks
:
&
[
UniqueBlock
])
->
usize
{
pub
fn
probe_new_blocks
(
&
self
,
blocks
:
&
[
UniqueBlock
])
->
usize
{
blocks
blocks
.iter
()
.iter
()
// .filter(|&block| !self.active_blocks.contains_key(block))
.filter
(|
&
block
|
!
self
.all_blocks
.contains
(
block
))
.filter
(|
&
block
|
!
self
.all_blocks
.contains
(
block
))
.count
()
.count
()
}
}
...
@@ -200,6 +265,11 @@ impl KvManager {
...
@@ -200,6 +265,11 @@ impl KvManager {
self
.active_blocks
.len
()
self
.active_blocks
.len
()
}
}
/// Get the percentage of active blocks relative to maximum capacity
pub
fn
get_active_perc
(
&
self
)
->
f64
{
self
.active_blocks
.len
()
as
f64
/
self
.max_capacity
as
f64
}
/// Get the number of inactive blocks
/// Get the number of inactive blocks
pub
fn
num_inactive_blocks
(
&
self
)
->
usize
{
pub
fn
num_inactive_blocks
(
&
self
)
->
usize
{
self
.inactive_blocks
.len
()
self
.inactive_blocks
.len
()
...
@@ -216,63 +286,28 @@ impl KvManager {
...
@@ -216,63 +286,28 @@ impl KvManager {
}
}
/// Check if a sequence can be scheduled and calculate cost if possible
/// Check if a sequence can be scheduled and calculate cost if possible
pub
fn
try_schedule
(
pub
fn
get_prefill_cost
(
&
self
,
sequence
:
&
ActiveSequence
)
->
PrefillCost
{
&
self
,
let
seq_blocks
=
sequence
.unique_blocks
();
sequence
:
&
ActiveSequence
,
let
new_blocks
=
self
.probe_new_blocks
(
seq_blocks
);
watermark
:
f64
,
let
overlap_blocks
=
seq_blocks
.len
()
-
new_blocks
;
tokens_budget
:
usize
,
let
new_tokens
=
sequence
.num_input_tokens
()
-
overlap_blocks
*
self
.block_size
;
)
->
Option
<
PrefillCost
>
{
// Return None immediately if tokens_budget is 0
if
tokens_budget
==
0
{
return
None
;
}
// Get unique blocks from the sequence
let
unique_blocks
=
sequence
.unique_blocks
();
// Get the count of new blocks
let
new_blocks
=
self
.probe_new_blocks
(
unique_blocks
);
// Calculate current usage and available capacity
let
active_count
=
self
.active_blocks
.len
();
// Check if we can schedule based on the watermark
if
(
active_count
+
new_blocks
)
as
f64
>
(
1.0
-
watermark
)
*
self
.max_capacity
as
f64
{
return
None
;
}
// Calculate overlap blocks
let
overlap_blocks
=
unique_blocks
.len
()
-
new_blocks
;
// Calculate new tokens
let
new_tokens
=
sequence
.num_input_tokens
()
-
overlap_blocks
*
(
self
.block_size
as
usize
);
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
// new_tokens,
// sequence.num_input_tokens(),
// overlap_blocks,
// self.block_size);
// Return None if new_tokens exceeds tokens_budget
if
new_tokens
>
tokens_budget
{
return
None
;
}
// Calculate prefill compute
// Calculate prefill compute
let
prefill_compute
=
let
prefill_compute
=
new_tokens
as
f64
*
(
new_tokens
+
overlap_blocks
*
(
self
.block_size
as
usize
))
as
f64
;
1.25e-6
*
(
new_tokens
as
f64
)
.powi
(
2
)
+
7.41e-2
*
(
new_tokens
as
f64
)
+
2.62e1
;
Some
(
PrefillCost
{
PrefillCost
{
new_blocks
,
new_tokens
,
new_tokens
,
prefill_compute
,
prefill_compute
,
}
)
}
}
}
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
tokio
::
sync
::
mpsc
;
#[test]
#[test]
fn
test_failure_on_max_capacity
()
{
fn
test_failure_on_max_capacity
()
{
...
@@ -282,7 +317,7 @@ mod tests {
...
@@ -282,7 +317,7 @@ mod tests {
// Helper function to use multiple blocks that returns the response
// Helper function to use multiple blocks that returns the response
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
->
bool
{
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
->
bool
{
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
manager
.process
(
&
MoveBlock
::
Use
(
blocks
,
None
))
manager
.process
(
&
MoveBlock
::
Use
(
blocks
))
}
}
// First use 10 blocks (0 to 9) in a batch
// First use 10 blocks (0 to 9) in a batch
...
@@ -301,15 +336,17 @@ mod tests {
...
@@ -301,15 +336,17 @@ mod tests {
}
}
#[test]
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn
test_block_lifecycle_stringent
()
{
fn
test_block_lifecycle_stringent
()
{
// Create a KvManager with 10 blocks capacity
// Create a channel to listen to block responses
let
mut
manager
=
KvManager
::
new
(
10
,
16
);
let
(
tx
,
mut
rx
)
=
mpsc
::
unbounded_channel
::
<
MoveBlockResponse
>
();
// Create a KvManager with 10 blocks capacity and the response sender
let
mut
manager
=
KvManager
::
new_with_sender
(
10
,
16
,
Some
(
tx
));
// Helper function to use multiple blocks
// Helper function to use multiple blocks
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
{
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
{
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
manager
.process
(
&
MoveBlock
::
Use
(
blocks
,
None
));
manager
.process
(
&
MoveBlock
::
Use
(
blocks
));
}
}
// Helper function to destroy multiple blocks
// Helper function to destroy multiple blocks
...
@@ -324,6 +361,56 @@ mod tests {
...
@@ -324,6 +361,56 @@ mod tests {
manager
.process
(
&
MoveBlock
::
Deref
(
blocks
));
manager
.process
(
&
MoveBlock
::
Deref
(
blocks
));
}
}
// Helper function to assert block responses
fn
assert_block_response
(
rx
:
&
mut
mpsc
::
UnboundedReceiver
<
MoveBlockResponse
>
,
expected_type
:
&
str
,
expected_blocks
:
Vec
<
u64
>
,
description
:
&
str
,
)
{
let
response
=
rx
.try_recv
()
.unwrap_or_else
(|
_
|
panic!
(
"Expected {expected_type} response {description}"
));
match
(
&
response
,
expected_type
)
{
(
MoveBlockResponse
::
Store
(
blocks
,
_
parent_hash
),
"Store"
)
=>
{
assert_eq!
(
blocks
.len
(),
expected_blocks
.len
(),
"Expected {} blocks in Store response {}"
,
expected_blocks
.len
(),
description
);
assert_eq!
(
*
blocks
,
expected_blocks
,
"Store blocks don't match expected {description}"
);
}
(
MoveBlockResponse
::
Remove
(
blocks
),
"Remove"
)
=>
{
assert_eq!
(
blocks
.len
(),
expected_blocks
.len
(),
"Expected {} blocks in Remove response {}"
,
expected_blocks
.len
(),
description
);
assert_eq!
(
*
blocks
,
expected_blocks
,
"Remove blocks don't match expected {description}"
);
}
_
=>
panic!
(
"Expected {expected_type} response, got {response:?} {description}"
),
}
}
// Helper function to assert no response is received
fn
assert_no_response
(
rx
:
&
mut
mpsc
::
UnboundedReceiver
<
MoveBlockResponse
>
,
description
:
&
str
,
)
{
assert
!
(
rx
.try_recv
()
.is_err
(),
"Expected no response {description}"
,);
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn
assert_active_blocks
(
manager
:
&
KvManager
,
expected_blocks
:
&
[(
u64
,
usize
)])
{
fn
assert_active_blocks
(
manager
:
&
KvManager
,
expected_blocks
:
&
[(
u64
,
usize
)])
{
assert_eq!
(
assert_eq!
(
...
@@ -336,14 +423,12 @@ mod tests {
...
@@ -336,14 +423,12 @@ mod tests {
let
block
=
UniqueBlock
::
FullBlock
(
id
);
let
block
=
UniqueBlock
::
FullBlock
(
id
);
assert
!
(
assert
!
(
manager
.active_blocks
()
.contains_key
(
&
block
),
manager
.active_blocks
()
.contains_key
(
&
block
),
"Block {} not found in active blocks"
,
"Block {id} not found in active blocks"
,
id
);
);
assert_eq!
(
assert_eq!
(
manager
.active_blocks
()
.get
(
&
block
),
manager
.active_blocks
()
.get
(
&
block
),
Some
(
&
ref_count
),
Some
(
&
ref_count
),
"Block {} has wrong reference count"
,
"Block {id} has wrong reference count"
,
id
);
);
}
}
}
}
...
@@ -366,17 +451,18 @@ mod tests {
...
@@ -366,17 +451,18 @@ mod tests {
let
block
=
UniqueBlock
::
FullBlock
(
id
);
let
block
=
UniqueBlock
::
FullBlock
(
id
);
assert
!
(
assert
!
(
inactive_blocks
.iter
()
.any
(|
&
b
|
*
b
==
block
),
inactive_blocks
.iter
()
.any
(|
&
b
|
*
b
==
block
),
"Block {} not found in inactive blocks"
,
"Block {id} not found in inactive blocks"
,
id
);
);
}
}
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks
(
&
mut
manager
,
(
0
..
5
)
.collect
());
use_blocks
(
&
mut
manager
,
(
0
..
5
)
.collect
());
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
0
,
1
,
2
,
3
,
4
],
"after first use"
);
// Then use blocks 0, 1, 5, 6 in a batch
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
,
6
]);
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
,
6
]);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
5
,
6
],
"after second use"
);
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks
(
assert_active_blocks
(
...
@@ -386,9 +472,11 @@ mod tests {
...
@@ -386,9 +472,11 @@ mod tests {
// Now destroy block 4
// Now destroy block 4
destroy_blocks
(
&
mut
manager
,
vec!
[
4
]);
destroy_blocks
(
&
mut
manager
,
vec!
[
4
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
4
],
"after destroy block 4"
);
// And deref blocks 3, 2, 1, 0 in this order as a batch
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
3
]);
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
3
]);
assert_no_response
(
&
mut
rx
,
"after deref operation"
);
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
2
]);
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
2
]);
...
@@ -396,6 +484,7 @@ mod tests {
...
@@ -396,6 +484,7 @@ mod tests {
// Now destroy block 6
// Now destroy block 6
destroy_blocks
(
&
mut
manager
,
vec!
[
6
]);
destroy_blocks
(
&
mut
manager
,
vec!
[
6
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
6
],
"after block 6 eviction"
);
// And deref blocks 5, 1, 0 as a batch
// And deref blocks 5, 1, 0 as a batch
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
]);
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
]);
...
@@ -406,6 +495,7 @@ mod tests {
...
@@ -406,6 +495,7 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
7
,
8
,
9
]);
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
7
,
8
,
9
]);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
7
,
8
,
9
],
"after [7, 8, 9] use"
);
// Check that the inactive_blocks is size 2, and contains 3 and 5
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
5
]);
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
5
]);
...
@@ -420,8 +510,14 @@ mod tests {
...
@@ -420,8 +510,14 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch
// Now use blocks 10, 11, 12 as a batch
use_blocks
(
&
mut
manager
,
vec!
[
10
,
11
,
12
]);
use_blocks
(
&
mut
manager
,
vec!
[
10
,
11
,
12
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
3
],
"after block 5 eviction"
);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
10
,
11
,
12
],
"after [10, 11, 12] use"
);
// Check that the inactive_blocks is size 1 and contains only 5
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks
(
&
manager
,
1
,
&
[
5
]);
assert_inactive_blocks
(
&
manager
,
1
,
&
[
5
]);
use_blocks
(
&
mut
manager
,
vec!
[
13
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
5
],
"after block 5 eviction"
);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
13
],
"after block 13 use"
);
}
}
}
}
lib/llm/src/mocker/protocols.rs
View file @
f0652d89
...
@@ -13,12 +13,16 @@
...
@@ -13,12 +13,16 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
derive_builder
::
Builder
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
use
crate
::
kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEventData
,
KvCacheRemoveData
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
};
pub
type
Token
=
u32
;
pub
type
Token
=
u32
;
pub
type
LocalBlockHash
=
u64
;
/// A global hash identifier for blocks
pub
type
GlobalHash
=
u64
;
pub
type
GlobalHash
=
u64
;
pub
type
NumBlocks
=
usize
;
pub
type
NumBlocks
=
usize
;
...
@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
...
@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
}
}
/// Represents different block movement operations in the cache
/// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
pub
enum
MoveBlock
{
pub
enum
MoveBlock
{
Use
(
Vec
<
UniqueBlock
>
,
Option
<
f64
>
),
Use
(
Vec
<
UniqueBlock
>
),
Destroy
(
Vec
<
UniqueBlock
>
),
Destroy
(
Vec
<
UniqueBlock
>
),
Deref
(
Vec
<
UniqueBlock
>
),
Deref
(
Vec
<
UniqueBlock
>
),
Promote
(
Uuid
,
GlobalHash
),
Promote
(
Uuid
,
GlobalHash
,
Option
<
u64
>
),
}
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
pub
enum
MoveBlockResponse
{
Store
(
Vec
<
GlobalHash
>
,
Option
<
u64
>
),
Remove
(
Vec
<
GlobalHash
>
),
}
}
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
...
@@ -52,15 +63,86 @@ pub struct DirectRequest {
...
@@ -52,15 +63,86 @@ pub struct DirectRequest {
pub
tokens
:
Vec
<
Token
>
,
pub
tokens
:
Vec
<
Token
>
,
pub
max_output_tokens
:
usize
,
pub
max_output_tokens
:
usize
,
pub
uuid
:
Option
<
Uuid
>
,
pub
uuid
:
Option
<
Uuid
>
,
pub
dp_rank
:
Option
<
u32
>
,
}
}
/// Represents the cost of prefilling content in the cache
/// Represents the cost of prefilling content in the cache
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
PrefillCost
{
pub
struct
PrefillCost
{
pub
new_blocks
:
usize
,
pub
new_tokens
:
usize
,
pub
new_tokens
:
usize
,
pub
prefill_compute
:
f64
,
pub
prefill_compute
:
f64
,
}
}
/// Signal for output token generation with completion status
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
OutputSignal
{
pub
uuid
:
Uuid
,
pub
completed
:
bool
,
}
/// Configuration arguments for MockVllmEngine
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Builder)]
#[builder(pattern
=
"owned"
,
build_fn(public))]
pub
struct
MockEngineArgs
{
#[builder(default
=
"16384"
)]
pub
num_gpu_blocks
:
usize
,
#[builder(default
=
"64"
)]
pub
block_size
:
usize
,
// This was 1024 in the past but reverted back to 256
#[builder(default
=
Some(
256
))]
pub
max_num_seqs
:
Option
<
usize
>
,
// default for open api server, for llm class it's 16384
#[builder(default
=
Some(
8192
))]
pub
max_num_batched_tokens
:
Option
<
usize
>
,
#[builder(default
=
true
)]
pub
enable_prefix_caching
:
bool
,
#[builder(default
=
"0.01"
)]
pub
watermark
:
f64
,
#[builder(default
=
"1.0"
)]
pub
speedup_ratio
:
f64
,
#[builder(default
=
"1"
)]
pub
dp_size
:
u32
,
}
impl
MockEngineArgs
{
pub
fn
builder
()
->
MockEngineArgsBuilder
{
MockEngineArgsBuilder
::
default
()
}
}
/// Note: This assumes block_hash and tokens_hash are the same, which is not correct in rare cases
/// where the sequence-aware hash differs from the token content hash.
pub
fn
block_response_to_kv_event
(
response
:
MoveBlockResponse
)
->
KvCacheEventData
{
match
response
{
MoveBlockResponse
::
Store
(
full_blocks
,
parent_hash
)
=>
{
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
parent_hash
.map
(
ExternalSequenceBlockHash
),
blocks
:
full_blocks
.into_iter
()
.map
(|
block
|
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
block
),
tokens_hash
:
LocalBlockHash
(
block
),
})
.collect
(),
})
}
MoveBlockResponse
::
Remove
(
full_blocks
)
=>
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
full_blocks
.into_iter
()
.map
(
ExternalSequenceBlockHash
)
.collect
(),
}),
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
...
...
lib/llm/src/mocker/scheduler.rs
View file @
f0652d89
This diff is collapsed.
Click to expand it.
lib/llm/src/mocker/sequence.rs
View file @
f0652d89
...
@@ -23,16 +23,23 @@ use uuid;
...
@@ -23,16 +23,23 @@ use uuid;
fn
create_unique_blocks_from_sequence
(
fn
create_unique_blocks_from_sequence
(
tokens
:
&
TokenBlockSequence
,
tokens
:
&
TokenBlockSequence
,
uuid
:
Option
<
uuid
::
Uuid
>
,
uuid
:
Option
<
uuid
::
Uuid
>
,
block_size
:
u32
,
block_size
:
usize
,
enable_prefix_caching
:
bool
,
)
->
Vec
<
UniqueBlock
>
{
)
->
Vec
<
UniqueBlock
>
{
let
mut
unique_blocks
:
Vec
<
UniqueBlock
>
=
tokens
let
mut
unique_blocks
:
Vec
<
UniqueBlock
>
=
tokens
.blocks
()
.blocks
()
.iter
()
.iter
()
.map
(|
block
|
UniqueBlock
::
FullBlock
(
block
.sequence_hash
()))
.map
(|
block
|
{
if
enable_prefix_caching
{
UniqueBlock
::
FullBlock
(
block
.sequence_hash
())
}
else
{
UniqueBlock
::
FullBlock
(
random
::
<
u64
>
())
}
})
.collect
();
.collect
();
// Only push the partial block if tokens count isn't a multiple of block_size
// Only push the partial block if tokens count isn't a multiple of block_size
if
tokens
.total_tokens
()
%
(
block_size
as
usize
)
!=
0
{
if
tokens
.total_tokens
()
%
block_size
!=
0
{
unique_blocks
.push
(
match
uuid
{
unique_blocks
.push
(
match
uuid
{
Some
(
uuid
)
=>
UniqueBlock
::
PartialBlock
(
uuid
),
Some
(
uuid
)
=>
UniqueBlock
::
PartialBlock
(
uuid
),
None
=>
UniqueBlock
::
default
(),
None
=>
UniqueBlock
::
default
(),
...
@@ -50,10 +57,7 @@ pub struct ActiveSequence {
...
@@ -50,10 +57,7 @@ pub struct ActiveSequence {
tokens
:
TokenBlockSequence
,
tokens
:
TokenBlockSequence
,
#[getter(copy)]
#[getter(copy)]
block_size
:
u32
,
block_size
:
usize
,
#[getter(copy)]
chunk_size
:
usize
,
// TODO: not actually used
#[getter(copy)]
#[getter(copy)]
max_output_tokens
:
usize
,
max_output_tokens
:
usize
,
...
@@ -61,10 +65,16 @@ pub struct ActiveSequence {
...
@@ -61,10 +65,16 @@ pub struct ActiveSequence {
#[getter(copy)]
#[getter(copy)]
generated_tokens
:
usize
,
generated_tokens
:
usize
,
#[getter(copy)]
already_generated_tokens
:
usize
,
#[getter(copy)]
#[getter(copy)]
num_input_tokens
:
usize
,
num_input_tokens
:
usize
,
creation_signal
:
Option
<
MoveBlock
>
,
creation_signal
:
Option
<
MoveBlock
>
,
#[getter(copy)]
enable_prefix_caching
:
bool
,
}
}
impl
ActiveSequence
{
impl
ActiveSequence
{
...
@@ -72,32 +82,33 @@ impl ActiveSequence {
...
@@ -72,32 +82,33 @@ impl ActiveSequence {
pub
fn
new
(
pub
fn
new
(
tokens
:
Vec
<
u32
>
,
tokens
:
Vec
<
u32
>
,
max_output_tokens
:
usize
,
max_output_tokens
:
usize
,
block_size
:
Option
<
u
32
>
,
block_size
:
Option
<
u
size
>
,
chunk_size
:
Option
<
usize
>
,
enable_prefix_caching
:
bool
,
)
->
Self
{
)
->
Self
{
let
block_size
=
block_size
.unwrap_or
(
64
);
let
block_size
=
block_size
.unwrap_or
(
64
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
let
chunk_size
=
chunk_size
.unwrap_or
(
256
);
let
num_input_tokens
=
tokens
.len
();
let
num_input_tokens
=
tokens
.len
();
let
tokens
=
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
,
None
);
let
tokens
=
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
let
unique_blocks
=
create_unique_blocks_from_sequence
(
&
tokens
,
None
,
block_size
);
let
unique_blocks
=
let
creation_signal
=
Some
(
MoveBlock
::
Use
(
unique_blocks
.clone
(),
None
));
create_unique_blocks_from_sequence
(
&
tokens
,
None
,
block_size
,
enable_prefix_caching
);
let
creation_signal
=
Some
(
MoveBlock
::
Use
(
unique_blocks
.clone
()));
Self
{
Self
{
unique_blocks
,
unique_blocks
,
tokens
,
tokens
,
block_size
,
block_size
,
chunk_size
,
max_output_tokens
,
max_output_tokens
,
generated_tokens
:
0
,
generated_tokens
:
0
,
already_generated_tokens
:
0
,
num_input_tokens
,
num_input_tokens
,
creation_signal
,
creation_signal
,
enable_prefix_caching
,
}
}
}
}
pub
fn
extra_tokens
(
&
self
)
->
u32
{
pub
fn
extra_tokens
(
&
self
)
->
u32
{
(
self
.len
()
%
self
.block_size
as
usize
)
as
u32
(
self
.len
()
%
self
.block_size
)
as
u32
}
}
pub
fn
len
(
&
self
)
->
usize
{
pub
fn
len
(
&
self
)
->
usize
{
...
@@ -112,20 +123,31 @@ impl ActiveSequence {
...
@@ -112,20 +123,31 @@ impl ActiveSequence {
pub
fn
new_with_signal
(
pub
fn
new_with_signal
(
tokens
:
Vec
<
u32
>
,
tokens
:
Vec
<
u32
>
,
max_output_tokens
:
usize
,
max_output_tokens
:
usize
,
block_size
:
Option
<
u
32
>
,
block_size
:
Option
<
u
size
>
,
chunk_size
:
Option
<
usize
>
,
enable_prefix_caching
:
bool
,
)
->
(
Self
,
Option
<
MoveBlock
>
)
{
)
->
(
Self
,
Option
<
MoveBlock
>
)
{
let
mut
sequence
=
Self
::
new
(
tokens
,
max_output_tokens
,
block_size
,
chunk_size
);
let
mut
sequence
=
Self
::
new
(
tokens
,
max_output_tokens
,
block_size
,
enable_prefix_caching
);
let
signal
=
sequence
.creation_signal
.take
();
let
signal
=
sequence
.creation_signal
.take
();
(
sequence
,
signal
)
(
sequence
,
signal
)
}
}
/// Get the parent hash from the second-to-last block if it exists and is a FullBlock
fn
get_parent_hash
(
&
self
)
->
Option
<
u64
>
{
if
self
.unique_blocks
.len
()
<
2
{
return
None
;
}
match
&
self
.unique_blocks
[
self
.unique_blocks
.len
()
-
2
]
{
UniqueBlock
::
FullBlock
(
hash
)
=>
Some
(
*
hash
),
_
=>
panic!
(
"Cannot have a partial block as parent"
),
}
}
/// Push a token to the sequence
/// Push a token to the sequence
pub
fn
push
(
&
mut
self
,
token
:
u32
)
->
Option
<
Vec
<
MoveBlock
>>
{
pub
fn
push
(
&
mut
self
,
token
:
u32
)
->
Option
<
Vec
<
MoveBlock
>>
{
self
.tokens
.append
(
token
)
.expect
(
"Token push failed."
);
self
.tokens
.append
(
token
)
.expect
(
"Token push failed."
);
self
.generated_tokens
+=
1
;
self
.generated_tokens
+=
1
;
if
self
.len
()
%
(
self
.block_size
as
usize
)
!=
1
{
if
self
.len
()
%
self
.block_size
!=
1
{
return
None
;
return
None
;
}
}
...
@@ -135,16 +157,24 @@ impl ActiveSequence {
...
@@ -135,16 +157,24 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists
// Replace last partial block with full block if it exists
if
let
Some
(
UniqueBlock
::
PartialBlock
(
uuid
))
=
self
.unique_blocks
.last
()
.cloned
()
{
if
let
Some
(
UniqueBlock
::
PartialBlock
(
uuid
))
=
self
.unique_blocks
.last
()
.cloned
()
{
let
last_block_hash
=
self
.tokens
.last_complete_block
()
.unwrap
()
.sequence_hash
();
let
last_block_hash
=
if
self
.enable_prefix_caching
{
self
.tokens
.last_complete_block
()
.unwrap
()
.sequence_hash
()
}
else
{
random
::
<
u64
>
()
};
self
.unique_blocks
.pop
();
self
.unique_blocks
.pop
();
self
.unique_blocks
self
.unique_blocks
.push
(
UniqueBlock
::
FullBlock
(
last_block_hash
));
.push
(
UniqueBlock
::
FullBlock
(
last_block_hash
));
signals
.push
(
MoveBlock
::
Promote
(
uuid
,
last_block_hash
));
signals
.push
(
MoveBlock
::
Promote
(
uuid
,
last_block_hash
,
self
.get_parent_hash
(),
));
}
}
let
new_partial_block
=
UniqueBlock
::
default
();
let
new_partial_block
=
UniqueBlock
::
default
();
self
.unique_blocks
.push
(
new_partial_block
.clone
());
self
.unique_blocks
.push
(
new_partial_block
.clone
());
signals
.push
(
MoveBlock
::
Use
(
vec!
[
new_partial_block
]
,
None
));
signals
.push
(
MoveBlock
::
Use
(
vec!
[
new_partial_block
]));
Some
(
signals
)
Some
(
signals
)
}
}
...
@@ -204,15 +234,19 @@ impl ActiveSequence {
...
@@ -204,15 +234,19 @@ impl ActiveSequence {
}
}
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// maintaining the uuid of the last partial block
pub
fn
reset_with_signal
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
pub
fn
reset_with_signal
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
let
free_signal
=
self
.free_signal
();
let
free_signal
=
self
.free_signal
();
self
.tokens
.truncate
(
self
.num_input_tokens
)
.unwrap
();
self
.tokens
.truncate
(
self
.num_input_tokens
)
.unwrap
();
self
.unique_blocks
=
self
.unique_blocks
=
create_unique_blocks_from_sequence
(
create_unique_blocks_from_sequence
(
&
self
.tokens
,
None
,
self
.block_size
);
&
self
.tokens
,
None
,
self
.block_size
,
self
.enable_prefix_caching
,
);
self
.already_generated_tokens
=
self
.generated_tokens
.max
(
self
.already_generated_tokens
);
self
.generated_tokens
=
0
;
self
.generated_tokens
=
0
;
self
.creation_signal
=
Some
(
MoveBlock
::
Use
(
self
.unique_blocks
.clone
()
,
None
));
self
.creation_signal
=
Some
(
MoveBlock
::
Use
(
self
.unique_blocks
.clone
()));
free_signal
free_signal
}
}
...
@@ -223,7 +257,7 @@ impl ActiveSequence {
...
@@ -223,7 +257,7 @@ impl ActiveSequence {
self
.generated_tokens
=
self
.generated_tokens
.saturating_sub
(
1
);
self
.generated_tokens
=
self
.generated_tokens
.saturating_sub
(
1
);
// Reverts to the last full block
// Reverts to the last full block
if
self
.tokens
.total_tokens
()
%
(
self
.block_size
as
usize
)
==
0
{
if
self
.tokens
.total_tokens
()
%
self
.block_size
==
0
{
self
.unique_blocks
.pop
();
self
.unique_blocks
.pop
();
}
}
}
}
...
@@ -238,14 +272,14 @@ mod tests {
...
@@ -238,14 +272,14 @@ mod tests {
// Create a sequence with block size 16 initialized with tokens [0..15]
// Create a sequence with block size 16 initialized with tokens [0..15]
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
15
)
.collect
();
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
15
)
.collect
();
let
(
mut
seq1
,
signal1
)
=
let
(
mut
seq1
,
signal1
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
100
,
Some
(
16
),
Some
(
256
)
);
ActiveSequence
::
new_with_signal
(
initial_tokens
,
100
,
Some
(
16
),
true
);
assert_eq!
(
seq1
.num_input_tokens
(),
15
);
assert_eq!
(
seq1
.num_input_tokens
(),
15
);
assert_eq!
(
seq1
.len
(),
15
);
assert_eq!
(
seq1
.len
(),
15
);
// Check that we got a Use signal
// Check that we got a Use signal
assert
!
(
signal1
.is_some
());
assert
!
(
signal1
.is_some
());
match
&
signal1
{
match
&
signal1
{
Some
(
MoveBlock
::
Use
(
blocks
,
_
))
=>
{
Some
(
MoveBlock
::
Use
(
blocks
))
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
}
}
_
=>
panic!
(
"Expected Use signal"
),
_
=>
panic!
(
"Expected Use signal"
),
...
@@ -264,33 +298,31 @@ mod tests {
...
@@ -264,33 +298,31 @@ mod tests {
let
signal_16
=
signal_16
.unwrap
();
let
signal_16
=
signal_16
.unwrap
();
assert_eq!
(
signal_16
.len
(),
2
);
assert_eq!
(
signal_16
.len
(),
2
);
// First signal should be Promote for the previous block
match
&
signal_16
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
assert_eq!
(
*
parent_hash
,
None
);
}
_
=>
panic!
(
"Expected Promote signal as second signal"
),
}
// Second signal should be Use for new partial block
// Second signal should be Use for new partial block
match
&
signal_16
[
1
]
{
match
&
signal_16
[
1
]
{
MoveBlock
::
Use
(
blocks
,
_
)
=>
{
MoveBlock
::
Use
(
blocks
)
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
}
_
=>
panic!
(
"Expected Use signal as first signal"
),
_
=>
panic!
(
"Expected Use signal as first signal"
),
}
}
// First signal should be Promote for the previous block
match
&
signal_16
[
0
]
{
MoveBlock
::
Promote
(
uuid
,
_
)
=>
{
// The uuid is generated dynamically, so we just check it exists
let
_
=
uuid
;
}
_
=>
panic!
(
"Expected Promote signal as second signal"
),
}
// Verify state after pushing tokens
// Verify state after pushing tokens
assert_eq!
(
seq1
.unique_blocks
()
.len
(),
2
);
// One full block and one partial block
assert_eq!
(
seq1
.unique_blocks
()
.len
(),
2
);
// One full block and one partial block
assert_eq!
(
seq1
.len
(),
17
);
assert_eq!
(
seq1
.len
(),
17
);
assert_eq!
(
seq1
.len
()
%
(
seq1
.block_size
()
as
usize
)
,
1
);
assert_eq!
(
seq1
.len
()
%
seq1
.block_size
(),
1
);
// Create another sequence with block size 16 initialized with tokens [0..17]
// Create another sequence with block size 16 initialized with tokens [0..17]
let
extended_tokens
:
Vec
<
u32
>
=
(
0
..
16
)
.collect
();
let
extended_tokens
:
Vec
<
u32
>
=
(
0
..
16
)
.collect
();
let
(
mut
seq2
,
_
)
=
let
(
mut
seq2
,
_
)
=
ActiveSequence
::
new_with_signal
(
extended_tokens
,
100
,
Some
(
16
),
true
);
ActiveSequence
::
new_with_signal
(
extended_tokens
,
100
,
Some
(
16
),
Some
(
256
));
seq2
.push
(
16
);
seq2
.push
(
16
);
seq2
.pop
();
seq2
.pop
();
seq2
.push
(
16
);
seq2
.push
(
16
);
...
@@ -335,12 +367,12 @@ mod tests {
...
@@ -335,12 +367,12 @@ mod tests {
"seq2 should have exactly 3 blocks"
"seq2 should have exactly 3 blocks"
);
);
assert_eq!
(
assert_eq!
(
seq1
.len
()
%
(
seq1
.block_size
()
as
usize
)
,
seq1
.len
()
%
seq1
.block_size
(),
1
,
1
,
"seq1 should have 1 partial token"
"seq1 should have 1 partial token"
);
);
assert_eq!
(
assert_eq!
(
seq2
.len
()
%
(
seq2
.block_size
()
as
usize
)
,
seq2
.len
()
%
seq2
.block_size
(),
1
,
1
,
"seq2 should have 1 partial token"
"seq2 should have 1 partial token"
);
);
...
@@ -352,9 +384,38 @@ mod tests {
...
@@ -352,9 +384,38 @@ mod tests {
"First two blocks should be identical"
"First two blocks should be identical"
);
);
// Push tokens 34..47 to seq1
for
token
in
33
..
48
{
seq1
.push
(
token
);
}
// Push token 48 and get the signal - this completes the block and triggers signals
let
signal
=
seq1
.push
(
48
);
let
signal
=
signal
.unwrap
();
// Check that signal[0] is promote
match
&
signal
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
// Check that the parent_hash matches unique_blocks[1], which should be a full block
if
let
UniqueBlock
::
FullBlock
(
expected_hash
)
=
seq1
.unique_blocks
()[
1
]
{
assert_eq!
(
*
parent_hash
,
Some
(
expected_hash
),
"Parent hash should match unique_blocks[1]"
);
}
else
{
panic!
(
"unique_blocks[1] should be a full block"
);
}
}
_
=>
panic!
(
"Expected Promote signal as first signal"
),
}
// Reset seq1 and check that it equals the original clone
// Reset seq1 and check that it equals the original clone
let
free_signals
=
seq1
.reset_with_signal
();
let
free_signals
=
seq1
.reset_with_signal
();
// 49 - 15 generated tokens
assert_eq!
(
seq1
.already_generated_tokens
,
34
);
// Verify the reset signals include proper cleanup events
// Verify the reset signals include proper cleanup events
assert
!
(
!
free_signals
.is_empty
());
assert
!
(
!
free_signals
.is_empty
());
}
}
...
@@ -363,13 +424,12 @@ mod tests {
...
@@ -363,13 +424,12 @@ mod tests {
fn
test_active_sequence_generate_signals
()
{
fn
test_active_sequence_generate_signals
()
{
// Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
// Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
14
)
.collect
();
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
14
)
.collect
();
let
(
mut
seq
,
signal
)
=
let
(
mut
seq
,
signal
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
5
,
Some
(
16
),
true
);
ActiveSequence
::
new_with_signal
(
initial_tokens
,
5
,
Some
(
16
),
Some
(
256
));
// Initial signal - should have received a Use signal for the partial block
// Initial signal - should have received a Use signal for the partial block
assert
!
(
signal
.is_some
());
assert
!
(
signal
.is_some
());
match
signal
{
match
signal
{
Some
(
MoveBlock
::
Use
(
blocks
,
_
))
=>
{
Some
(
MoveBlock
::
Use
(
blocks
))
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
}
...
@@ -385,25 +445,23 @@ mod tests {
...
@@ -385,25 +445,23 @@ mod tests {
let
signals_second
=
seq
.generate
();
let
signals_second
=
seq
.generate
();
assert_eq!
(
signals_second
.len
(),
2
);
assert_eq!
(
signals_second
.len
(),
2
);
// First signal should be Use for new partial block
// First signal should be Promote
match
&
signals_second
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
assert_eq!
(
*
parent_hash
,
None
);
}
_
=>
panic!
(
"Expected Promote signal as first signal after second token"
),
}
// Second signal should be Use for new partial block
match
&
signals_second
[
1
]
{
match
&
signals_second
[
1
]
{
MoveBlock
::
Use
(
blocks
,
_
)
=>
{
MoveBlock
::
Use
(
blocks
)
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
}
_
=>
panic!
(
"Expected Use signal as second signal after second token"
),
_
=>
panic!
(
"Expected Use signal as second signal after second token"
),
}
}
// Second signal should be Promote
match
&
signals_second
[
0
]
{
MoveBlock
::
Promote
(
uuid
,
hash
)
=>
{
// The uuid and hash values are generated dynamically, so we just check the event type
let
_
=
uuid
;
let
_
=
hash
;
}
_
=>
panic!
(
"Expected Promote signal as first signal after second token"
),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
// Generate fourth token - should not trigger new signals as it's adding to partial block
let
signals_third
=
seq
.generate
();
let
signals_third
=
seq
.generate
();
assert_eq!
(
signals_third
.len
(),
0
);
assert_eq!
(
signals_third
.len
(),
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment