Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f0652d89
"tests/vscode:/vscode.git/clone" did not exist on "dea0b201d6679a013f1238839591b0806130c29c"
Unverified
Commit
f0652d89
authored
Jul 01, 2025
by
Yan Ru Pei
Committed by
GitHub
Jul 01, 2025
Browse files
feat: vllm mocker enhancement (#1236)
parent
0d6cae85
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1654 additions
and
410 deletions
+1654
-410
lib/llm/src/mocker.rs
lib/llm/src/mocker.rs
+1
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+764
-0
lib/llm/src/mocker/evictor.rs
lib/llm/src/mocker/evictor.rs
+96
-105
lib/llm/src/mocker/kv_manager.rs
lib/llm/src/mocker/kv_manager.rs
+166
-70
lib/llm/src/mocker/protocols.rs
lib/llm/src/mocker/protocols.rs
+86
-4
lib/llm/src/mocker/scheduler.rs
lib/llm/src/mocker/scheduler.rs
+424
-172
lib/llm/src/mocker/sequence.rs
lib/llm/src/mocker/sequence.rs
+117
-59
No files found.
lib/llm/src/mocker.rs
View file @
f0652d89
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
pub
mod
engine
;
pub
mod
evictor
;
pub
mod
evictor
;
pub
mod
kv_manager
;
pub
mod
kv_manager
;
pub
mod
protocols
;
pub
mod
protocols
;
...
...
lib/llm/src/mocker/engine.rs
0 → 100644
View file @
f0652d89
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler
//!
//! This module provides an AsyncEngine implementation that wraps the Scheduler
//! to provide streaming token generation with realistic timing simulation.
use
crate
::
kv_router
::
publisher
::
WorkerMetricsPublisher
;
use
crate
::
mocker
::
protocols
::
DirectRequest
;
use
crate
::
mocker
::
protocols
::{
MockEngineArgs
,
OutputSignal
};
use
crate
::
mocker
::
scheduler
::
Scheduler
;
use
crate
::
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
};
use
crate
::
protocols
::
TokenIdType
;
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
dynamo_runtime
::
DistributedRuntime
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_runtime
::{
component
::
Component
,
engine
::
AsyncEngineContextProvider
,
pipeline
::{
async_trait
,
AsyncEngine
,
Error
,
ManyOut
,
ResponseStream
,
SingleIn
},
traits
::
DistributedRuntimeProvider
,
Result
,
};
use
crate
::
kv_router
::
protocols
::{
KvCacheEvent
,
KvCacheEventData
};
use
crate
::
kv_router
::
publisher
::
KvEventPublisher
;
use
futures
::
StreamExt
;
use
rand
::
Rng
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
tokio
::
sync
::{
mpsc
,
Mutex
,
OnceCell
};
use
tokio
::
time
::{
interval
,
Duration
};
use
tokio_stream
::
wrappers
::
ReceiverStream
;
use
uuid
::
Uuid
;
pub
const
MOCKER_COMPONENT
:
&
str
=
"mocker"
;
/// Pick a random token ID from the half-open range `1000..5000`
/// (1000 inclusive, 5000 exclusive).
fn generate_random_token() -> TokenIdType {
    rand::rng().random_range(1000..5000)
}
/// AsyncEngine wrapper around the Scheduler that streams randomly generated token IDs.
///
/// The type is `Clone`; all clones share the same request table and scheduler senders
/// because both live behind an `Arc`.
#[derive(Clone)]
pub struct MockVllmEngine {
    /// Output channel of every in-flight request, keyed by request UUID. The per-scheduler
    /// dispatcher tasks look a request up here to forward its `OutputSignal`s.
    active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
    /// One request sender per scheduler, indexed by DP rank. Set exactly once by
    /// `start_schedulers`; empty until then.
    request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
    /// Configuration the engine was created with (`dp_size`, `block_size`,
    /// `enable_prefix_caching`, ...).
    engine_args: MockEngineArgs,
}
impl
MockVllmEngine
{
/// Build an engine from `args` with an empty request table and no schedulers.
///
/// Schedulers are only created once `start` is called.
pub fn new(args: MockEngineArgs) -> Self {
    let active_requests = Arc::new(Mutex::new(HashMap::new()));
    let request_senders = Arc::new(OnceCell::new());
    Self {
        active_requests,
        request_senders,
        engine_args: args,
    }
}
/// Start the schedulers and the background publishers for `component`.
///
/// Metrics publishing is always started; KV event publishing is started only when
/// `enable_prefix_caching` is set in the engine args. All spawned tasks observe a
/// child token of the component's runtime for cancellation.
pub async fn start(&self, component: Component) -> Result<()> {
    let shutdown = component.drt().runtime().child_token();

    let (schedulers, kv_event_receivers) = self.start_schedulers(
        self.engine_args.clone(),
        self.active_requests.clone(),
        shutdown.clone(),
    );

    Self::start_metrics_publishing(&schedulers, Some(component.clone()), shutdown.clone())
        .await?;

    if !self.engine_args.enable_prefix_caching {
        return Ok(());
    }

    // Publish KV events using the actual receivers handed back by the schedulers.
    Self::start_kv_events_publishing(
        kv_event_receivers,
        Some(component),
        self.engine_args.block_size,
        shutdown,
    )
    .await
}
pub
fn
direct
(
&
self
,
request
:
DirectRequest
,
dp_rank
:
usize
)
{
let
senders
=
self
.request_senders
.get
()
.expect
(
"Not initialized"
);
let
_
=
senders
[
dp_rank
]
.send
(
request
);
}
/// Build one scheduler per data-parallel rank and, for each, spawn a dispatcher task that
/// routes the scheduler's output signals to the matching in-flight request.
///
/// Returns the schedulers and the receiving ends of their KV event channels, both in
/// DP-rank order. The schedulers' request senders are stored in `self.request_senders`.
///
/// # Panics
///
/// Panics with "Already initialized" if the request senders were set before.
fn start_schedulers(
    &self,
    args: MockEngineArgs,
    active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
    cancel_token: CancellationToken,
) -> (
    Vec<Scheduler>,
    Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
) {
    let rank_count = args.dp_size as usize;
    let mut schedulers = Vec::<Scheduler>::new();
    let mut kv_event_receivers = Vec::new();
    let mut senders = Vec::with_capacity(rank_count);

    for dp_rank in 0..args.dp_size {
        // Channel on which this scheduler reports generated-token signals.
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
        // Channel on which this scheduler reports KV cache events.
        let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>();

        let scheduler = Scheduler::new(
            args.clone(),
            Some(dp_rank),
            Some(output_tx),
            Some(kv_events_tx),
            Some(cancel_token.clone()),
        );
        senders.push(scheduler.request_sender());
        schedulers.push(scheduler);
        kv_event_receivers.push(kv_events_rx);

        // Dispatcher: forward each signal to the request it belongs to, until the
        // scheduler's channel closes or the token is cancelled.
        let routes = active_requests.clone();
        let shutdown = cancel_token.clone();
        tokio::spawn(async move {
            loop {
                let signal = tokio::select! {
                    received = output_rx.recv() => match received {
                        Some(signal) => signal,
                        None => break, // Channel closed
                    },
                    _ = shutdown.cancelled() => break,
                };

                let routes = routes.lock().await;
                if let Some(request_tx) = routes.get(&signal.uuid) {
                    let _ = request_tx.send(signal);
                }
            }
        });
    }

    // The senders can only be set once.
    self.request_senders
        .set(senders)
        .expect("Already initialized");

    (schedulers, kv_event_receivers)
}
/// Start background tasks to poll and publish metrics every second
async
fn
start_metrics_publishing
(
schedulers
:
&
[
Scheduler
],
component
:
Option
<
Component
>
,
cancel_token
:
CancellationToken
,
)
->
Result
<
()
>
{
tracing
::
info!
(
"Creating metrics publisher"
);
let
metrics_publisher
=
Arc
::
new
(
WorkerMetricsPublisher
::
new
()
?
);
tracing
::
info!
(
"Metrics publisher created"
);
if
let
Some
(
comp
)
=
component
{
tracing
::
info!
(
"Creating metrics endpoint"
);
tokio
::
spawn
({
let
publisher
=
metrics_publisher
.clone
();
async
move
{
if
let
Err
(
e
)
=
publisher
.create_endpoint
(
comp
.clone
())
.await
{
tracing
::
error!
(
"Metrics endpoint failed: {e}"
);
}
}
});
// Give it a moment to start
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
tracing
::
info!
(
"Metrics endpoint started (background)"
);
}
tracing
::
info!
(
"Starting metrics background tasks"
);
for
(
dp_rank
,
scheduler
)
in
schedulers
.iter
()
.enumerate
()
{
let
scheduler
=
scheduler
.clone
();
let
publisher
=
metrics_publisher
.clone
();
let
dp_rank
=
dp_rank
as
u32
;
let
cancel_token
=
cancel_token
.clone
();
tokio
::
spawn
(
async
move
{
let
mut
interval
=
interval
(
Duration
::
from_millis
(
100
));
loop
{
tokio
::
select!
{
_
=
interval
.tick
()
=>
{
// Get metrics from scheduler
let
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
// Publish metrics
if
let
Err
(
e
)
=
publisher
.publish
(
Arc
::
new
(
metrics
))
{
tracing
::
warn!
(
"Failed to publish metrics for DP rank {dp_rank}: {e}"
);
}
else
{
tracing
::
trace!
(
"Published metrics for DP rank {}"
,
dp_rank
);
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
info!
(
"Metrics publishing cancelled for DP rank {dp_rank}"
);
break
;
}
}
}
});
}
tracing
::
info!
(
"Metrics background tasks started"
);
Ok
(())
}
/// Start background tasks to collect and publish KV events from schedulers
async
fn
start_kv_events_publishing
(
kv_event_receivers
:
Vec
<
mpsc
::
UnboundedReceiver
<
KvCacheEventData
>>
,
component
:
Option
<
Component
>
,
block_size
:
usize
,
cancel_token
:
CancellationToken
,
)
->
Result
<
()
>
{
tracing
::
info!
(
"Starting KV events publishing"
);
// Only start KV events publishing if we have a component
let
Some
(
comp
)
=
component
else
{
tracing
::
warn!
(
"No component provided, skipping KV events publishing"
);
return
Ok
(());
};
tracing
::
info!
(
"Component found for KV events publishing"
);
tracing
::
debug!
(
"Getting worker_id"
);
let
worker_id
=
comp
.drt
()
.primary_lease
()
.expect
(
"Cannot publish KV events without lease"
)
// ← This will PANIC on static!
.id
();
// let worker_id = 0;
tracing
::
debug!
(
"Worker_id set to: {worker_id}"
);
tracing
::
info!
(
"Creating KV event publisher"
);
let
kv_event_publisher
=
Arc
::
new
(
KvEventPublisher
::
new
(
comp
.clone
(),
worker_id
,
block_size
as
u32
,
None
,
)
?
);
tracing
::
info!
(
"KV event publisher created"
);
tracing
::
info!
(
"Starting KV event background tasks for {} receivers"
,
kv_event_receivers
.len
()
);
for
(
dp_rank
,
mut
kv_events_rx
)
in
kv_event_receivers
.into_iter
()
.enumerate
()
{
tracing
::
debug!
(
"Starting background task for DP rank {dp_rank}"
);
let
publisher
=
kv_event_publisher
.clone
();
let
dp_rank
=
dp_rank
as
u32
;
let
cancel_token
=
cancel_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
debug!
(
"Background task started for DP rank {dp_rank}"
);
loop
{
tokio
::
select!
{
// Receive actual KV events from the scheduler
Some
(
event_data
)
=
kv_events_rx
.recv
()
=>
{
// Convert KvCacheEventData to KvCacheEvent with random UUID as event_id
let
event
=
KvCacheEvent
{
event_id
:
Uuid
::
new_v4
()
.as_u128
()
as
u64
,
data
:
event_data
,
};
// Publish the event
if
let
Err
(
e
)
=
publisher
.publish
(
event
)
{
tracing
::
warn!
(
"Failed to publish KV event for DP rank {dp_rank}: {e}"
);
}
else
{
tracing
::
trace!
(
"Published KV event for DP rank {dp_rank}"
);
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
info!
(
"KV events publishing cancelled for DP rank {dp_rank}"
);
break
;
}
}
}
});
}
tracing
::
info!
(
"All KV event background tasks started"
);
Ok
(())
}
}
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
LLMEngineOutput
>
,
Error
>
for
MockVllmEngine
{
async
fn
generate
(
&
self
,
input
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
ManyOut
<
LLMEngineOutput
>
,
Error
>
{
let
(
request
,
ctx
)
=
input
.into_parts
();
// Extract dp_rank from annotations if present
let
dp_rank
=
request
.annotations
.iter
()
.find_map
(|
ann
|
{
if
ann
.starts_with
(
"dp_rank:"
)
{
ann
.strip_prefix
(
"dp_rank:"
)
.and_then
(|
s
|
s
.parse
()
.ok
())
}
else
{
None
}
})
.unwrap_or
(
0
);
// Validate dp_rank
if
dp_rank
>=
self
.engine_args.dp_size
{
return
Err
(
Error
::
msg
(
format!
(
"dp_rank {} is out of bounds for dp_size {}"
,
dp_rank
,
self
.engine_args.dp_size
)));
}
let
request_uuid
=
ctx
.id
()
.parse
()
.unwrap_or
(
Uuid
::
new_v4
());
// Convert PreprocessedRequest to DirectRequest for scheduler
let
direct_request
=
DirectRequest
{
tokens
:
request
.token_ids
.clone
(),
max_output_tokens
:
request
.stop_conditions
.max_tokens
.expect
(
"max_output_tokens must be specified for mocker"
)
as
usize
,
uuid
:
Some
(
request_uuid
),
dp_rank
:
Some
(
dp_rank
),
};
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
{
let
mut
active
=
self
.active_requests
.lock
()
.await
;
active
.insert
(
request_uuid
,
request_tx
);
}
// Send the request to the appropriate scheduler based on dp_rank
self
.direct
(
direct_request
,
dp_rank
as
usize
);
// Create a simple channel for the stream
let
(
stream_tx
,
stream_rx
)
=
mpsc
::
channel
::
<
LLMEngineOutput
>
(
64
);
let
active_requests
=
self
.active_requests
.clone
();
let
async_context
=
ctx
.context
();
let
max_tokens
=
request
.stop_conditions.max_tokens
.unwrap_or
(
100
)
as
usize
;
// Spawn a task to handle the complex async logic
tokio
::
spawn
(
async
move
{
let
mut
token_count
=
0
;
loop
{
tokio
::
select!
{
maybe_signal
=
request_rx
.recv
()
=>
{
let
Some
(
signal
)
=
maybe_signal
else
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
error
(
"All output transmitters closed"
.to_string
()))
.await
;
break
;
};
if
signal
.completed
&&
token_count
<
max_tokens
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
error
(
"Completion signal received before max tokens reached"
.to_string
()))
.await
;
break
;
}
if
signal
.completed
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
length
())
.await
;
break
;
}
// Generate a new token
let
token_id
=
generate_random_token
();
token_count
+=
1
;
let
output
=
LLMEngineOutput
{
token_ids
:
vec!
[
token_id
],
tokens
:
None
,
// Let backend handle detokenization
text
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
None
,
index
:
None
,
};
if
stream_tx
.send
(
output
)
.await
.is_err
()
{
break
;
}
}
_
=
async_context
.stopped
()
=>
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
cancelled
())
.await
;
break
;
}
}
}
// Clean up: remove this request from active requests
let
mut
active
=
active_requests
.lock
()
.await
;
active
.remove
(
&
request_uuid
);
});
// Create a simple ReceiverStream which is naturally Send + Sync
let
stream
=
ReceiverStream
::
new
(
stream_rx
);
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
.context
()))
}
}
/// Adapter around [`MockVllmEngine`] whose `generate` yields `Annotated<LLMEngineOutput>`
/// items instead of bare `LLMEngineOutput`s.
pub struct AnnotatedMockEngine {
    /// The wrapped engine; also held by the background task that starts it.
    inner: Arc<MockVllmEngine>,
}
impl
AnnotatedMockEngine
{
pub
fn
new
(
inner
:
MockVllmEngine
,
distributed_runtime
:
DistributedRuntime
,
endpoint
:
dynamo_runtime
::
protocols
::
Endpoint
,
)
->
Self
{
let
inner
=
Arc
::
new
(
inner
);
let
inner_clone
=
inner
.clone
();
// Start background task to wait for component service and start the engine
tokio
::
spawn
(
async
move
{
loop
{
// Try to create component
let
Ok
(
namespace
)
=
distributed_runtime
.namespace
(
&
endpoint
.namespace
)
else
{
tracing
::
debug!
(
"Namespace not available yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
};
let
Ok
(
component
)
=
namespace
.component
(
&
endpoint
.component
)
else
{
tracing
::
debug!
(
"Component not available yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
};
// Check if service is available by trying to list instances
let
Ok
(
instances
)
=
component
.list_instances
()
.await
else
{
tracing
::
debug!
(
"Cannot list instances yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
};
if
instances
.is_empty
()
{
tracing
::
debug!
(
"No instances available yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
}
tracing
::
info!
(
"Component service is now available, starting mocker engine"
);
// Start the engine with the component
if
let
Err
(
e
)
=
inner_clone
.start
(
component
)
.await
{
tracing
::
error!
(
"Failed to start mocker engine: {e}"
);
}
break
;
}
});
Self
{
inner
}
}
}
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
Error
>
for
AnnotatedMockEngine
{
async
fn
generate
(
&
self
,
input
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
Error
>
{
let
stream
=
self
.inner
.generate
(
input
)
.await
?
;
let
context
=
stream
.context
();
// Convert stream of LLMEngineOutput to Annotated<LLMEngineOutput>
let
annotated_stream
=
stream
.map
(
Annotated
::
from_data
);
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
annotated_stream
),
context
))
}
}
/// Create a mocker engine as ExecutionContext
pub
async
fn
make_mocker_engine
(
distributed_runtime
:
DistributedRuntime
,
endpoint
:
dynamo_runtime
::
protocols
::
Endpoint
,
args
:
MockEngineArgs
,
)
->
Result
<
crate
::
backend
::
ExecutionContext
,
Error
>
{
// Create the mocker engine
tracing
::
info!
(
"Creating mocker engine (service will be started in background)"
);
let
annotated_engine
=
AnnotatedMockEngine
::
new
(
MockVllmEngine
::
new
(
args
),
distributed_runtime
,
endpoint
);
Ok
(
Arc
::
new
(
annotated_engine
))
}
#[cfg(test)]
mod
integration_tests
{
use
super
::
*
;
use
crate
::
kv_router
::
indexer
::
RouterEvent
;
use
crate
::
kv_router
::
KV_EVENT_SUBJECT
;
use
crate
::
protocols
::
common
::{
SamplingOptions
,
StopConditions
};
use
dynamo_runtime
::{
pipeline
::
Context
,
pipeline
::{
network
::
Ingress
,
PushRouter
},
traits
::
events
::
EventSubscriber
,
DistributedRuntime
,
Worker
,
};
use
futures
::
StreamExt
;
use
tokio
::
time
::
timeout
;
#[tokio::test]
#[ignore]
// Run with: cargo test -- --ignored
async
fn
test_mock_vllm_engine_full_integration
()
->
Result
<
()
>
{
const
DP_SIZE
:
u32
=
2
;
const
TOKENS_PER_REQUEST
:
usize
=
20
;
const
BLOCK_SIZE
:
usize
=
2
;
// Create runtime and distributed runtime
let
worker
=
Worker
::
from_settings
()
?
;
let
runtime
=
worker
.runtime
();
let
distributed
=
DistributedRuntime
::
from_settings
(
runtime
.clone
())
.await
?
;
tracing
::
info!
(
"✓ Runtime and distributed runtime created"
);
// Create component for MockVllmEngine (needed for publishers)
let
test_component
=
distributed
.namespace
(
"test"
)
?
.component
(
MOCKER_COMPONENT
)
?
.service_builder
()
.create
()
.await
?
;
tracing
::
info!
(
"✓ Test component created"
);
// Create MockVllmEngine WITH component (enables publishers)
let
args
=
MockEngineArgs
::
builder
()
.speedup_ratio
(
10.0
)
.dp_size
(
DP_SIZE
)
.block_size
(
BLOCK_SIZE
)
.build
()
.unwrap
();
let
engine
=
MockVllmEngine
::
new
(
args
);
engine
.start
(
test_component
.clone
())
.await
?
;
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
500
))
.await
;
let
engine
=
Arc
::
new
(
engine
);
tracing
::
info!
(
"✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}"
);
// Set up KV events subscriber
let
mut
kv_events_subscriber
=
test_component
.subscribe
(
KV_EVENT_SUBJECT
)
.await
?
;
tracing
::
info!
(
"✓ KV events subscriber created"
);
// Wrap with Ingress and register with component/endpoint
let
ingress
=
Ingress
::
for_engine
(
engine
)
?
;
tracing
::
info!
(
"✓ Ingress wrapper created"
);
// Start the server in background
let
server_handle
=
tokio
::
spawn
({
let
test_component
=
test_component
.clone
();
async
move
{
if
let
Err
(
e
)
=
test_component
.endpoint
(
"generate"
)
.endpoint_builder
()
.handler
(
ingress
)
.start
()
.await
{
eprintln!
(
"❌ Generate endpoint failed: {e}"
);
}
}
});
tracing
::
info!
(
"✓ Server started in background"
);
// Give server time to start
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
500
))
.await
;
tracing
::
info!
(
"✓ Server startup delay completed"
);
// Print all registered instances from etcd
match
test_component
.list_instances
()
.await
{
Ok
(
instances
)
=>
{
tracing
::
info!
(
"📋 Found {} registered instances:"
,
instances
.len
());
for
instance
in
instances
{
tracing
::
info!
(
" • {}/{}/{} (ID: {})"
,
instance
.namespace
,
instance
.component
,
instance
.endpoint
,
instance
.instance_id
);
}
}
Err
(
e
)
=>
{
tracing
::
error!
(
"❌ Failed to list instances: {e}"
);
}
}
// Create client
let
client
=
distributed
.namespace
(
"test"
)
?
.component
(
MOCKER_COMPONENT
)
?
.endpoint
(
"generate"
)
.client
()
.await
?
;
tracing
::
info!
(
"✓ Client created"
);
let
router
=
PushRouter
::
from_client
(
client
,
Default
::
default
())
.await
?
;
tracing
::
info!
(
"✓ Router created"
);
// Create test requests for both DP workers
let
create_request
=
|
tokens
:
Vec
<
TokenIdType
>
,
dp_rank
:
u32
|
PreprocessedRequest
{
token_ids
:
tokens
,
batch_token_ids
:
None
,
stop_conditions
:
StopConditions
{
max_tokens
:
Some
(
TOKENS_PER_REQUEST
as
u32
),
..
Default
::
default
()
},
sampling_options
:
SamplingOptions
::
default
(),
eos_token_ids
:
vec!
[],
mdc_sum
:
None
,
annotations
:
vec!
[
format!
(
"dp_rank:{dp_rank}"
)],
estimated_prefix_hit_num_blocks
:
None
,
};
let
requests
=
vec!
[
create_request
(
vec!
[
1
,
2
,
3
,
4
,
5
],
0
),
create_request
(
vec!
[
1
,
2
,
3
,
4
,
5
],
0
),
create_request
(
vec!
[
1
,
2
,
3
,
4
,
5
],
1
),
create_request
(
vec!
[
1
,
2
,
3
,
4
,
5
],
1
),
];
tracing
::
info!
(
"✓ Test requests created ({} requests total)"
,
requests
.len
()
);
// Test each request
for
(
i
,
request
)
in
requests
.into_iter
()
.enumerate
()
{
tracing
::
info!
(
"Testing request {}"
,
i
+
1
);
let
response_stream
=
router
.generate
(
Context
::
new
(
request
))
.await
?
;
let
responses
:
Vec
<
LLMEngineOutput
>
=
response_stream
.collect
()
.await
;
// Should have at least one response
assert
!
(
!
responses
.is_empty
(),
"Request {} should produce at least one response"
,
i
+
1
);
// Count total tokens generated (excluding final message)
let
mut
total_tokens
=
0
;
let
mut
has_finish_reason
=
false
;
for
response
in
&
responses
{
total_tokens
+=
response
.token_ids
.len
();
if
response
.finish_reason
.is_some
()
{
has_finish_reason
=
true
;
}
}
// Should have a finish reason in the last response
assert
!
(
has_finish_reason
,
"Request {} should have a finish reason"
,
i
+
1
);
// Verify we got approximately the expected number of tokens
assert
!
(
total_tokens
<=
TOKENS_PER_REQUEST
+
1
,
// +1 for potential final empty response
"Request {} generated {} tokens, expected at most {}"
,
i
+
1
,
total_tokens
,
TOKENS_PER_REQUEST
+
1
);
tracing
::
info!
(
"✓ Request {} completed successfully with {} tokens"
,
i
+
1
,
total_tokens
);
}
tracing
::
info!
(
"🎉 All requests completed successfully!"
);
// Try to receive at least one KV event with 100ms timeout
tracing
::
info!
(
"Waiting for KV event with 100ms timeout..."
);
let
msg
=
timeout
(
Duration
::
from_millis
(
100
),
kv_events_subscriber
.next
())
.await
.map_err
(|
_
|
Error
::
msg
(
"Timeout waiting for KV event"
))
?
.ok_or_else
(||
Error
::
msg
(
"KV events stream ended unexpectedly"
))
?
;
match
serde_json
::
from_slice
::
<
RouterEvent
>
(
&
msg
.payload
)
{
Ok
(
event
)
=>
{
tracing
::
info!
(
"✓ Received KV event: {event:?}"
);
}
Err
(
e
)
=>
{
return
Err
(
Error
::
msg
(
format!
(
"Failed to deserialize KV event: {e}"
)));
}
}
// Use KvMetricsAggregator to get metrics more easily
let
cancel_token
=
test_component
.drt
()
.runtime
()
.child_token
();
let
metrics_aggregator
=
crate
::
kv_router
::
metrics_aggregator
::
KvMetricsAggregator
::
new
(
test_component
.clone
(),
cancel_token
,
)
.await
;
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
500
))
.await
;
let
processed_endpoints
=
metrics_aggregator
.get_endpoints
();
tracing
::
info!
(
"Found {} metrics endpoints"
,
processed_endpoints
.endpoints
.len
()
);
// Verify we found at least one metrics endpoint
assert
!
(
!
processed_endpoints
.endpoints
.is_empty
(),
"Should find at least one metrics endpoint"
);
tracing
::
info!
(
"✓ Successfully found {} metrics endpoints"
,
processed_endpoints
.endpoints
.len
()
);
// Verify the metrics endpoints contain valid data
for
(
worker_id
,
endpoint
)
in
&
processed_endpoints
.endpoints
{
tracing
::
info!
(
"✓ Worker {} metrics: {:?}"
,
worker_id
,
endpoint
.data
);
}
tracing
::
info!
(
"🎉 Event verification completed!"
);
// Cleanup
distributed
.shutdown
();
server_handle
.await
?
;
Ok
(())
}
}
lib/llm/src/mocker/evictor.rs
View file @
f0652d89
...
@@ -13,167 +13,158 @@
...
@@ -13,167 +13,158 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
std
::
cmp
::
Eq
;
use
std
::
cmp
::
{
Eq
,
Ordering
}
;
use
std
::
collections
::{
HashMap
,
VecDeque
};
use
std
::
collections
::{
BTreeSet
,
HashMap
};
use
std
::
hash
::
Hash
;
use
std
::
hash
::
Hash
;
use
std
::
time
::
Instant
;
/// A wrapper for (T, counter) that implements Ord based only on counter
#[derive(Debug,
Clone,
Eq,
PartialEq)]
struct
PriorityItem
<
T
>
{
item
:
T
,
counter
:
i64
,
}
impl
<
T
:
Eq
>
Ord
for
PriorityItem
<
T
>
{
fn
cmp
(
&
self
,
other
:
&
Self
)
->
Ordering
{
self
.counter
.cmp
(
&
other
.counter
)
}
}
impl
<
T
:
Eq
>
PartialOrd
for
PriorityItem
<
T
>
{
fn
partial_cmp
(
&
self
,
other
:
&
Self
)
->
Option
<
Ordering
>
{
Some
(
self
.cmp
(
other
))
}
}
/// An LRU evictor that maintains objects and evicts them based on their
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// priority counter. Lower counter values are evicted first.
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. The user must ensure objects are added in correct priority (temporal order)
/// 4. Remove and update operations are lazy - entries remain in the queue until
/// they are either evicted or cleaned up during maintenance
#[derive(Debug)]
#[derive(Debug)]
pub
struct
LRUEvictor
<
T
:
Clone
+
Eq
+
Hash
>
{
pub
struct
LRUEvictor
<
T
:
Clone
+
Eq
+
Hash
>
{
free_table
:
HashMap
<
T
,
f
64
>
,
free_table
:
HashMap
<
T
,
i
64
>
,
priority_queue
:
VecDeque
<
(
T
,
f64
)
>
,
priority_queue
:
BTreeSet
<
PriorityItem
<
T
>
>
,
cleanup_threshold
:
usize
,
positive_counter
:
i64
,
start_time
:
Instant
,
negative_counter
:
i64
,
}
}
impl
<
T
:
Clone
+
Eq
+
Hash
>
Default
for
LRUEvictor
<
T
>
{
impl
<
T
:
Clone
+
Eq
+
Hash
>
Default
for
LRUEvictor
<
T
>
{
fn
default
()
->
Self
{
fn
default
()
->
Self
{
Self
{
Self
{
free_table
:
HashMap
::
new
(),
free_table
:
HashMap
::
new
(),
priority_queue
:
VecDeque
::
new
(),
priority_queue
:
BTreeSet
::
new
(),
cleanup_threshold
:
5
0
,
positive_counter
:
0
,
start_time
:
Instant
::
now
()
,
negative_counter
:
0
,
}
}
}
}
}
}
impl
<
T
:
Clone
+
Eq
+
Hash
>
LRUEvictor
<
T
>
{
impl
<
T
:
Clone
+
Eq
+
Hash
>
LRUEvictor
<
T
>
{
/// Create a new LRUEvictor with the default cleanup threshold
pub
fn
new
(
_
cleanup_threshold
:
usize
)
->
Self
{
pub
fn
new
(
cleanup_threshold
:
usize
)
->
Self
{
Self
::
default
()
Self
{
cleanup_threshold
,
..
Default
::
default
()
}
}
}
/// Get the current timestamp as seconds since initialization
pub
fn
keys
(
&
self
)
->
std
::
collections
::
hash_map
::
Keys
<
'_
,
T
,
i64
>
{
pub
fn
current_timestamp
(
&
self
)
->
f64
{
self
.free_table
.keys
()
self
.start_time
.elapsed
()
.as_secs_f64
()
}
}
/// Get an iterator over the keys in the evictor
fn
update
(
&
mut
self
,
object
:
T
,
counter
:
i64
)
{
pub
fn
keys
(
&
self
)
->
std
::
collections
::
hash_map
::
Keys
<
'_
,
T
,
f64
>
{
self
.free_table
.insert
(
object
.clone
(),
counter
);
self
.free_table
.keys
()
self
.priority_queue
.insert
(
PriorityItem
{
item
:
object
,
counter
,
});
}
}
/// Insert or update an object in the evictor with current timestamp
pub
fn
insert
(
&
mut
self
,
object
:
T
)
{
pub
fn
insert
(
&
mut
self
,
object
:
T
)
{
let
timestamp
=
self
.current_timestamp
();
// Remove old entry if it exists
self
._insert
(
object
,
timestamp
);
if
let
Some
(
&
old_counter
)
=
self
.free_table
.get
(
&
object
)
{
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
:
old_counter
,
});
}
}
/// Check if the evictor contains the given object
// Increment positive counter and insert
pub
fn
contains
(
&
self
,
object
:
&
T
)
->
bool
{
self
.positive_counter
+=
1
;
self
.free_table
.contains_key
(
object
)
let
counter
=
self
.positive_counter
;
self
.update
(
object
,
counter
);
}
}
/// Evict an object based on LRU policy
/// Push an object to the front with negative counter (highest priority for eviction)
/// Returns the evicted object or None if no objects are available
pub
fn
push_front
(
&
mut
self
,
object
:
T
)
{
pub
fn
evict
(
&
mut
self
)
->
Option
<
T
>
{
// Remove old entry if it exists
if
self
.free_table
.is_empty
()
{
if
let
Some
(
&
old_counter
)
=
self
.free_table
.get
(
&
object
)
{
return
None
;
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
:
old_counter
,
});
}
}
while
let
Some
((
object
,
last_accessed
))
=
self
.priority_queue
.pop_front
()
{
// Decrement negative counter and insert
let
Some
(
&
current_last_accessed
)
=
self
.free_table
.get
(
&
object
)
else
{
self
.negative_counter
-=
1
;
continue
;
// entry is already removed
let
counter
=
self
.negative_counter
;
};
if
current_last_accessed
==
last_accessed
{
self
.update
(
object
,
counter
);
self
.free_table
.remove
(
&
object
);
return
Some
(
object
);
}
// otherwise entry is stale
}
}
None
pub
fn
contains
(
&
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.contains_key
(
object
)
}
}
/// Insert or update an object in the evictor
/// Evict an object based on LRU policy (lowest counter value)
fn
_
insert
(
&
mut
self
,
object
:
T
,
last_accessed
:
f64
)
{
/// Returns the evicted object or None if no objects are available
self
.free_table
.insert
(
object
.clone
(),
last_accessed
);
pub
fn
evict
(
&
mut
self
)
->
Option
<
T
>
{
self
.priority_queue
.push_back
((
object
,
last_accessed
));
self
.priority_queue
.pop_first
()
.map
(|
item
|
{
self
.cleanup_if_necessary
();
self
.free_table
.remove
(
&
item
.item
);
item
.item
})
}
}
/// Remove an object from the evictor
/// We don't remove from the priority queue immediately, as that would be inefficient
/// Outdated entries will be filtered out during eviction or cleanup
pub
fn
remove
(
&
mut
self
,
object
:
&
T
)
->
bool
{
pub
fn
remove
(
&
mut
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.remove
(
object
)
.is_some
()
let
Some
(
&
counter
)
=
self
.free_table
.get
(
object
)
else
{
return
false
;
};
self
.free_table
.remove
(
object
);
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
,
});
true
}
}
/// Get the number of objects in the evictor
pub
fn
len
(
&
self
)
->
usize
{
pub
fn
len
(
&
self
)
->
usize
{
self
.free_table
.len
()
self
.free_table
.len
()
}
}
/// Check if the evictor is empty
pub
fn
is_empty
(
&
self
)
->
bool
{
pub
fn
is_empty
(
&
self
)
->
bool
{
self
.free_table
.is_empty
()
self
.free_table
.is_empty
()
}
}
/// Check if cleanup is necessary and perform it if needed
fn
cleanup_if_necessary
(
&
mut
self
)
{
if
self
.priority_queue
.len
()
>
self
.cleanup_threshold
*
self
.free_table
.len
()
{
self
.cleanup
();
}
}
/// Clean up the priority queue by removing outdated entries
fn
cleanup
(
&
mut
self
)
{
let
mut
new_priority_queue
=
VecDeque
::
new
();
for
(
object
,
timestamp
)
in
self
.priority_queue
.drain
(
..
)
{
let
Some
(
&
current_timestamp
)
=
self
.free_table
.get
(
&
object
)
else
{
continue
;
};
if
current_timestamp
==
timestamp
{
new_priority_queue
.push_back
((
object
,
timestamp
));
}
}
self
.priority_queue
=
new_priority_queue
;
}
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
rstest
::
rstest
;
#[rstest]
#[test]
#[case(
1
)]
fn
test_lru_evictor_eviction_order
()
{
#[case(
2
)]
// Create a new LRUEvictor
#[case(
3
)]
let
mut
evictor
=
LRUEvictor
::
<
i32
>
::
new
(
1
);
// threshold value doesn't matter anymore
fn
test_lru_evictor_eviction_order
(
#[case]
threshold
:
usize
)
{
// Create a new LRUEvictor with the given cleanup threshold
let
mut
evictor
=
LRUEvictor
::
<
i32
>
::
new
(
threshold
);
// Add items in the specified order
with small delays between each
// Add items in the specified order
evictor
.insert
(
4
);
evictor
.insert
(
4
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
3
);
evictor
.insert
(
3
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
evictor
.insert
(
2
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
1
);
evictor
.insert
(
1
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
5
);
evictor
.insert
(
5
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
1
);
// Updates counter for 1
evictor
.insert
(
1
);
// Updates timestamp for 1
evictor
.insert
(
4
);
// Updates counter for 4
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
// Updates counter for 2
evictor
.insert
(
4
);
// Updates timestamp for 4
evictor
.push_front
(
4
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
// Updates timestamp for 2
// Verify the eviction order
// Verify the eviction order
println!
(
"Testing with threshold {}"
,
threshold
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
4
);
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
3
);
assert_eq!
(
evicted
,
3
);
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
...
@@ -181,11 +172,11 @@ mod tests {
...
@@ -181,11 +172,11 @@ mod tests {
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
1
);
assert_eq!
(
evicted
,
1
);
let
evicted
=
evictor
.evict
()
.unwrap
();
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
4
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
2
);
assert_eq!
(
evicted
,
2
);
let
evicted
=
evictor
.evict
();
let
evicted
=
evictor
.evict
();
assert_eq!
(
evicted
,
None
);
assert_eq!
(
evicted
,
None
);
assert_eq!
(
evictor
.len
(),
0
);
assert_eq!
(
evictor
.len
(),
0
);
}
}
// ... existing test_push_front test ...
}
}
lib/llm/src/mocker/kv_manager.rs
View file @
f0652d89
...
@@ -46,10 +46,11 @@
...
@@ -46,10 +46,11 @@
//! implementation of the main block manager.
//! implementation of the main block manager.
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
MoveBlockResponse
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
derive_getters
::
Getters
;
use
derive_getters
::
Getters
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
collections
::{
HashMap
,
HashSet
};
use
tokio
::
sync
::
mpsc
;
#[derive(Getters)]
#[derive(Getters)]
pub
struct
KvManager
{
pub
struct
KvManager
{
...
@@ -57,17 +58,27 @@ pub struct KvManager {
...
@@ -57,17 +58,27 @@ pub struct KvManager {
max_capacity
:
usize
,
max_capacity
:
usize
,
#[getter(copy)]
#[getter(copy)]
block_size
:
u
32
,
block_size
:
u
size
,
active_blocks
:
HashMap
<
UniqueBlock
,
usize
>
,
active_blocks
:
HashMap
<
UniqueBlock
,
usize
>
,
inactive_blocks
:
LRUEvictor
<
UniqueBlock
>
,
inactive_blocks
:
LRUEvictor
<
UniqueBlock
>
,
all_blocks
:
HashSet
<
UniqueBlock
>
,
all_blocks
:
HashSet
<
UniqueBlock
>
,
move_block_response_tx
:
Option
<
mpsc
::
UnboundedSender
<
MoveBlockResponse
>>
,
}
}
impl
KvManager
{
impl
KvManager
{
pub
fn
new
(
max_capacity
:
usize
,
block_size
:
u32
)
->
Self
{
pub
fn
new
(
max_capacity
:
usize
,
block_size
:
usize
)
->
Self
{
Self
::
new_with_sender
(
max_capacity
,
block_size
,
None
)
}
pub
fn
new_with_sender
(
max_capacity
:
usize
,
block_size
:
usize
,
move_block_response_tx
:
Option
<
mpsc
::
UnboundedSender
<
MoveBlockResponse
>>
,
)
->
Self
{
let
active_blocks
=
HashMap
::
new
();
let
active_blocks
=
HashMap
::
new
();
let
inactive_blocks
=
LRUEvictor
::
default
();
let
inactive_blocks
=
LRUEvictor
::
default
();
let
all_blocks
=
HashSet
::
new
();
let
all_blocks
=
HashSet
::
new
();
...
@@ -78,18 +89,46 @@ impl KvManager {
...
@@ -78,18 +89,46 @@ impl KvManager {
active_blocks
,
active_blocks
,
inactive_blocks
,
inactive_blocks
,
all_blocks
,
all_blocks
,
move_block_response_tx
,
}
}
/// Utility method to send block responses with optional reversing
fn
send_block_response
(
&
self
,
mut
blocks
:
Vec
<
u64
>
,
reverse
:
bool
,
store
:
bool
,
parent_hash
:
Option
<
u64
>
,
)
{
if
let
Some
(
ref
tx
)
=
self
.move_block_response_tx
{
if
!
blocks
.is_empty
()
{
if
reverse
{
blocks
.reverse
();
}
let
response
=
if
store
{
MoveBlockResponse
::
Store
(
blocks
,
parent_hash
)
}
else
{
MoveBlockResponse
::
Remove
(
blocks
)
};
tx
.send
(
response
)
.unwrap
();
}
}
}
}
}
/// Process a MoveBlock instruction synchronously
/// Process a MoveBlock instruction synchronously
pub
fn
process
(
&
mut
self
,
event
:
&
MoveBlock
)
->
bool
{
pub
fn
process
(
&
mut
self
,
event
:
&
MoveBlock
)
->
bool
{
match
event
{
match
event
{
MoveBlock
::
Use
(
hashes
,
_
)
=>
{
MoveBlock
::
Use
(
hashes
)
=>
{
let
mut
blocks_stored
=
Vec
::
<
u64
>
::
new
();
let
mut
parent_block
:
Option
<&
UniqueBlock
>
=
None
;
for
hash
in
hashes
{
for
hash
in
hashes
{
// First check if it already exists in active blocks
// First check if it already exists in active blocks
if
let
Some
(
ref_count
)
=
self
.active_blocks
.get_mut
(
hash
)
{
if
let
Some
(
ref_count
)
=
self
.active_blocks
.get_mut
(
hash
)
{
// Block already active, just increment reference count
// Block already active, just increment reference count
*
ref_count
+=
1
;
*
ref_count
+=
1
;
parent_block
=
Some
(
hash
);
continue
;
continue
;
}
}
...
@@ -97,6 +136,7 @@ impl KvManager {
...
@@ -97,6 +136,7 @@ impl KvManager {
if
self
.inactive_blocks
.remove
(
hash
)
{
if
self
.inactive_blocks
.remove
(
hash
)
{
// Insert into active with reference count 1
// Insert into active with reference count 1
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
parent_block
=
Some
(
hash
);
continue
;
continue
;
}
}
...
@@ -106,30 +146,53 @@ impl KvManager {
...
@@ -106,30 +146,53 @@ impl KvManager {
// If at max capacity, evict the oldest entry from inactive blocks
// If at max capacity, evict the oldest entry from inactive blocks
if
active_count
+
inactive_count
>=
self
.max_capacity
{
if
active_count
+
inactive_count
>=
self
.max_capacity
{
if
let
Some
(
evicted
)
=
self
.inactive_blocks
.evict
()
{
let
Some
(
evicted
)
=
self
.inactive_blocks
.evict
()
else
{
// Remove evicted block from all_blocks
self
.all_blocks
.remove
(
&
evicted
);
}
else
{
// Cannot evict block, meaning no free blocks left in inactive pool
// Send a signal, scheduler would expect to handle preemption upon receiving this
return
false
;
return
false
;
};
self
.all_blocks
.remove
(
&
evicted
);
if
let
UniqueBlock
::
FullBlock
(
evicted_full_block
)
=
evicted
{
self
.send_block_response
(
vec!
[
evicted_full_block
],
false
,
false
,
None
);
}
}
}
}
// Now insert the new block in active blocks with reference count 1
// Now insert the new block in active blocks with reference count 1
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
// Add to all_blocks as it's a new block
self
.all_blocks
.insert
(
hash
.clone
());
self
.all_blocks
.insert
(
hash
.clone
());
if
self
.move_block_response_tx
.is_some
()
{
if
let
UniqueBlock
::
FullBlock
(
stored_full_block
)
=
hash
{
blocks_stored
.push
(
*
stored_full_block
);
}
}
}
}
}
let
parent_hash
=
match
parent_block
{
None
=>
None
,
Some
(
UniqueBlock
::
FullBlock
(
block
))
=>
Some
(
*
block
),
Some
(
UniqueBlock
::
PartialBlock
(
_
))
=>
panic!
(
"parent block cannot be partial"
),
};
self
.send_block_response
(
blocks_stored
,
false
,
true
,
parent_hash
);
}
MoveBlock
::
Destroy
(
hashes
)
=>
{
MoveBlock
::
Destroy
(
hashes
)
=>
{
let
mut
blocks_destroyed
=
Vec
::
<
u64
>
::
new
();
// Loop in inverse direction
// Loop in inverse direction
for
hash
in
hashes
.iter
()
.rev
()
{
for
hash
in
hashes
.iter
()
.rev
()
{
self
.active_blocks
.remove
(
hash
)
.unwrap
();
self
.active_blocks
.remove
(
hash
)
.unwrap
();
// Remove from all_blocks when destroyed
// Remove from all_blocks when destroyed
assert
!
(
self
.all_blocks
.remove
(
hash
));
assert
!
(
self
.all_blocks
.remove
(
hash
));
// Track blocks for batch sending
if
self
.move_block_response_tx
.is_some
()
{
if
let
UniqueBlock
::
FullBlock
(
destroyed_full_block
)
=
hash
{
blocks_destroyed
.push
(
*
destroyed_full_block
);
}
}
}
}
self
.send_block_response
(
blocks_destroyed
,
true
,
false
,
None
);
}
}
MoveBlock
::
Deref
(
hashes
)
=>
{
MoveBlock
::
Deref
(
hashes
)
=>
{
// Loop in inverse direction
// Loop in inverse direction
for
hash
in
hashes
.iter
()
.rev
()
{
for
hash
in
hashes
.iter
()
.rev
()
{
...
@@ -149,15 +212,15 @@ impl KvManager {
...
@@ -149,15 +212,15 @@ impl KvManager {
}
}
}
}
}
}
MoveBlock
::
Promote
(
uuid
,
hash
)
=>
{
MoveBlock
::
Promote
(
uuid
,
hash
,
parent_hash
)
=>
{
let
uuid_block
=
UniqueBlock
::
PartialBlock
(
*
uuid
);
let
uuid_block
=
UniqueBlock
::
PartialBlock
(
*
uuid
);
let
hash_block
=
UniqueBlock
::
FullBlock
(
*
hash
);
let
hash_block
=
UniqueBlock
::
FullBlock
(
*
hash
);
let
Some
(
ref_count
)
=
self
.active_blocks
.remove
(
&
uuid_block
)
else
{
let
Some
(
ref_count
)
=
self
.active_blocks
.remove
(
&
uuid_block
)
else
{
let
in_all_blocks
=
self
.all_blocks
.contains
(
&
uuid_block
);
let
in_all_blocks
=
self
.all_blocks
.contains
(
&
uuid_block
);
panic!
(
panic!
(
"Missing active block for promotion: {:?}. Block still exists: {}"
,
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
uuid_block
,
in_all_blocks
);
);
};
};
...
@@ -167,6 +230,7 @@ impl KvManager {
...
@@ -167,6 +230,7 @@ impl KvManager {
// Update all_blocks
// Update all_blocks
assert
!
(
self
.all_blocks
.remove
(
&
uuid_block
));
assert
!
(
self
.all_blocks
.remove
(
&
uuid_block
));
self
.all_blocks
.insert
(
hash_block
);
self
.all_blocks
.insert
(
hash_block
);
self
.send_block_response
(
vec!
[
*
hash
],
false
,
true
,
*
parent_hash
);
}
}
}
}
...
@@ -178,6 +242,7 @@ impl KvManager {
...
@@ -178,6 +242,7 @@ impl KvManager {
pub
fn
probe_new_blocks
(
&
self
,
blocks
:
&
[
UniqueBlock
])
->
usize
{
pub
fn
probe_new_blocks
(
&
self
,
blocks
:
&
[
UniqueBlock
])
->
usize
{
blocks
blocks
.iter
()
.iter
()
// .filter(|&block| !self.active_blocks.contains_key(block))
.filter
(|
&
block
|
!
self
.all_blocks
.contains
(
block
))
.filter
(|
&
block
|
!
self
.all_blocks
.contains
(
block
))
.count
()
.count
()
}
}
...
@@ -200,6 +265,11 @@ impl KvManager {
...
@@ -200,6 +265,11 @@ impl KvManager {
self
.active_blocks
.len
()
self
.active_blocks
.len
()
}
}
/// Get the percentage of active blocks relative to maximum capacity
pub
fn
get_active_perc
(
&
self
)
->
f64
{
self
.active_blocks
.len
()
as
f64
/
self
.max_capacity
as
f64
}
/// Get the number of inactive blocks
/// Get the number of inactive blocks
pub
fn
num_inactive_blocks
(
&
self
)
->
usize
{
pub
fn
num_inactive_blocks
(
&
self
)
->
usize
{
self
.inactive_blocks
.len
()
self
.inactive_blocks
.len
()
...
@@ -216,63 +286,28 @@ impl KvManager {
...
@@ -216,63 +286,28 @@ impl KvManager {
}
}
/// Check if a sequence can be scheduled and calculate cost if possible
/// Check if a sequence can be scheduled and calculate cost if possible
pub
fn
try_schedule
(
pub
fn
get_prefill_cost
(
&
self
,
sequence
:
&
ActiveSequence
)
->
PrefillCost
{
&
self
,
let
seq_blocks
=
sequence
.unique_blocks
();
sequence
:
&
ActiveSequence
,
let
new_blocks
=
self
.probe_new_blocks
(
seq_blocks
);
watermark
:
f64
,
let
overlap_blocks
=
seq_blocks
.len
()
-
new_blocks
;
tokens_budget
:
usize
,
let
new_tokens
=
sequence
.num_input_tokens
()
-
overlap_blocks
*
self
.block_size
;
)
->
Option
<
PrefillCost
>
{
// Return None immediately if tokens_budget is 0
if
tokens_budget
==
0
{
return
None
;
}
// Get unique blocks from the sequence
let
unique_blocks
=
sequence
.unique_blocks
();
// Get the count of new blocks
let
new_blocks
=
self
.probe_new_blocks
(
unique_blocks
);
// Calculate current usage and available capacity
let
active_count
=
self
.active_blocks
.len
();
// Check if we can schedule based on the watermark
if
(
active_count
+
new_blocks
)
as
f64
>
(
1.0
-
watermark
)
*
self
.max_capacity
as
f64
{
return
None
;
}
// Calculate overlap blocks
let
overlap_blocks
=
unique_blocks
.len
()
-
new_blocks
;
// Calculate new tokens
let
new_tokens
=
sequence
.num_input_tokens
()
-
overlap_blocks
*
(
self
.block_size
as
usize
);
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
// new_tokens,
// sequence.num_input_tokens(),
// overlap_blocks,
// self.block_size);
// Return None if new_tokens exceeds tokens_budget
if
new_tokens
>
tokens_budget
{
return
None
;
}
// Calculate prefill compute
// Calculate prefill compute
let
prefill_compute
=
let
prefill_compute
=
new_tokens
as
f64
*
(
new_tokens
+
overlap_blocks
*
(
self
.block_size
as
usize
))
as
f64
;
1.25e-6
*
(
new_tokens
as
f64
)
.powi
(
2
)
+
7.41e-2
*
(
new_tokens
as
f64
)
+
2.62e1
;
Some
(
PrefillCost
{
PrefillCost
{
new_blocks
,
new_tokens
,
new_tokens
,
prefill_compute
,
prefill_compute
,
}
)
}
}
}
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
tokio
::
sync
::
mpsc
;
#[test]
#[test]
fn
test_failure_on_max_capacity
()
{
fn
test_failure_on_max_capacity
()
{
...
@@ -282,7 +317,7 @@ mod tests {
...
@@ -282,7 +317,7 @@ mod tests {
// Helper function to use multiple blocks that returns the response
// Helper function to use multiple blocks that returns the response
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
->
bool
{
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
->
bool
{
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
manager
.process
(
&
MoveBlock
::
Use
(
blocks
,
None
))
manager
.process
(
&
MoveBlock
::
Use
(
blocks
))
}
}
// First use 10 blocks (0 to 9) in a batch
// First use 10 blocks (0 to 9) in a batch
...
@@ -301,15 +336,17 @@ mod tests {
...
@@ -301,15 +336,17 @@ mod tests {
}
}
#[test]
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn
test_block_lifecycle_stringent
()
{
fn
test_block_lifecycle_stringent
()
{
// Create a KvManager with 10 blocks capacity
// Create a channel to listen to block responses
let
mut
manager
=
KvManager
::
new
(
10
,
16
);
let
(
tx
,
mut
rx
)
=
mpsc
::
unbounded_channel
::
<
MoveBlockResponse
>
();
// Create a KvManager with 10 blocks capacity and the response sender
let
mut
manager
=
KvManager
::
new_with_sender
(
10
,
16
,
Some
(
tx
));
// Helper function to use multiple blocks
// Helper function to use multiple blocks
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
{
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
{
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
manager
.process
(
&
MoveBlock
::
Use
(
blocks
,
None
));
manager
.process
(
&
MoveBlock
::
Use
(
blocks
));
}
}
// Helper function to destroy multiple blocks
// Helper function to destroy multiple blocks
...
@@ -324,6 +361,56 @@ mod tests {
...
@@ -324,6 +361,56 @@ mod tests {
manager
.process
(
&
MoveBlock
::
Deref
(
blocks
));
manager
.process
(
&
MoveBlock
::
Deref
(
blocks
));
}
}
// Helper function to assert block responses
fn
assert_block_response
(
rx
:
&
mut
mpsc
::
UnboundedReceiver
<
MoveBlockResponse
>
,
expected_type
:
&
str
,
expected_blocks
:
Vec
<
u64
>
,
description
:
&
str
,
)
{
let
response
=
rx
.try_recv
()
.unwrap_or_else
(|
_
|
panic!
(
"Expected {expected_type} response {description}"
));
match
(
&
response
,
expected_type
)
{
(
MoveBlockResponse
::
Store
(
blocks
,
_
parent_hash
),
"Store"
)
=>
{
assert_eq!
(
blocks
.len
(),
expected_blocks
.len
(),
"Expected {} blocks in Store response {}"
,
expected_blocks
.len
(),
description
);
assert_eq!
(
*
blocks
,
expected_blocks
,
"Store blocks don't match expected {description}"
);
}
(
MoveBlockResponse
::
Remove
(
blocks
),
"Remove"
)
=>
{
assert_eq!
(
blocks
.len
(),
expected_blocks
.len
(),
"Expected {} blocks in Remove response {}"
,
expected_blocks
.len
(),
description
);
assert_eq!
(
*
blocks
,
expected_blocks
,
"Remove blocks don't match expected {description}"
);
}
_
=>
panic!
(
"Expected {expected_type} response, got {response:?} {description}"
),
}
}
// Helper function to assert no response is received
fn
assert_no_response
(
rx
:
&
mut
mpsc
::
UnboundedReceiver
<
MoveBlockResponse
>
,
description
:
&
str
,
)
{
assert
!
(
rx
.try_recv
()
.is_err
(),
"Expected no response {description}"
,);
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn
assert_active_blocks
(
manager
:
&
KvManager
,
expected_blocks
:
&
[(
u64
,
usize
)])
{
fn
assert_active_blocks
(
manager
:
&
KvManager
,
expected_blocks
:
&
[(
u64
,
usize
)])
{
assert_eq!
(
assert_eq!
(
...
@@ -336,14 +423,12 @@ mod tests {
...
@@ -336,14 +423,12 @@ mod tests {
let
block
=
UniqueBlock
::
FullBlock
(
id
);
let
block
=
UniqueBlock
::
FullBlock
(
id
);
assert
!
(
assert
!
(
manager
.active_blocks
()
.contains_key
(
&
block
),
manager
.active_blocks
()
.contains_key
(
&
block
),
"Block {} not found in active blocks"
,
"Block {id} not found in active blocks"
,
id
);
);
assert_eq!
(
assert_eq!
(
manager
.active_blocks
()
.get
(
&
block
),
manager
.active_blocks
()
.get
(
&
block
),
Some
(
&
ref_count
),
Some
(
&
ref_count
),
"Block {} has wrong reference count"
,
"Block {id} has wrong reference count"
,
id
);
);
}
}
}
}
...
@@ -366,17 +451,18 @@ mod tests {
...
@@ -366,17 +451,18 @@ mod tests {
let
block
=
UniqueBlock
::
FullBlock
(
id
);
let
block
=
UniqueBlock
::
FullBlock
(
id
);
assert
!
(
assert
!
(
inactive_blocks
.iter
()
.any
(|
&
b
|
*
b
==
block
),
inactive_blocks
.iter
()
.any
(|
&
b
|
*
b
==
block
),
"Block {} not found in inactive blocks"
,
"Block {id} not found in inactive blocks"
,
id
);
);
}
}
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks
(
&
mut
manager
,
(
0
..
5
)
.collect
());
use_blocks
(
&
mut
manager
,
(
0
..
5
)
.collect
());
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
0
,
1
,
2
,
3
,
4
],
"after first use"
);
// Then use blocks 0, 1, 5, 6 in a batch
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
,
6
]);
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
,
6
]);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
5
,
6
],
"after second use"
);
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks
(
assert_active_blocks
(
...
@@ -386,9 +472,11 @@ mod tests {
...
@@ -386,9 +472,11 @@ mod tests {
// Now destroy block 4
// Now destroy block 4
destroy_blocks
(
&
mut
manager
,
vec!
[
4
]);
destroy_blocks
(
&
mut
manager
,
vec!
[
4
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
4
],
"after destroy block 4"
);
// And deref blocks 3, 2, 1, 0 in this order as a batch
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
3
]);
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
3
]);
assert_no_response
(
&
mut
rx
,
"after deref operation"
);
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
2
]);
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
2
]);
...
@@ -396,6 +484,7 @@ mod tests {
...
@@ -396,6 +484,7 @@ mod tests {
// Now destroy block 6
// Now destroy block 6
destroy_blocks
(
&
mut
manager
,
vec!
[
6
]);
destroy_blocks
(
&
mut
manager
,
vec!
[
6
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
6
],
"after block 6 eviction"
);
// And deref blocks 5, 1, 0 as a batch
// And deref blocks 5, 1, 0 as a batch
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
]);
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
]);
...
@@ -406,6 +495,7 @@ mod tests {
...
@@ -406,6 +495,7 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
7
,
8
,
9
]);
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
7
,
8
,
9
]);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
7
,
8
,
9
],
"after [7, 8, 9] use"
);
// Check that the inactive_blocks is size 2, and contains 3 and 5
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
5
]);
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
5
]);
...
@@ -420,8 +510,14 @@ mod tests {
...
@@ -420,8 +510,14 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch
// Now use blocks 10, 11, 12 as a batch
use_blocks
(
&
mut
manager
,
vec!
[
10
,
11
,
12
]);
use_blocks
(
&
mut
manager
,
vec!
[
10
,
11
,
12
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
3
],
"after block 5 eviction"
);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
10
,
11
,
12
],
"after [10, 11, 12] use"
);
// Check that the inactive_blocks is size 1 and contains only 5
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks
(
&
manager
,
1
,
&
[
5
]);
assert_inactive_blocks
(
&
manager
,
1
,
&
[
5
]);
use_blocks
(
&
mut
manager
,
vec!
[
13
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
5
],
"after block 5 eviction"
);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
13
],
"after block 13 use"
);
}
}
}
}
lib/llm/src/mocker/protocols.rs
View file @
f0652d89
...
@@ -13,12 +13,16 @@
...
@@ -13,12 +13,16 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
derive_builder
::
Builder
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
use
crate
::
kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEventData
,
KvCacheRemoveData
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
};
pub
type
Token
=
u32
;
pub
type
Token
=
u32
;
pub
type
LocalBlockHash
=
u64
;
/// A global hash identifier for blocks
pub
type
GlobalHash
=
u64
;
pub
type
GlobalHash
=
u64
;
pub
type
NumBlocks
=
usize
;
pub
type
NumBlocks
=
usize
;
...
@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
...
@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
}
}
/// Represents different block movement operations in the cache
/// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
pub
enum
MoveBlock
{
pub
enum
MoveBlock
{
Use
(
Vec
<
UniqueBlock
>
,
Option
<
f64
>
),
Use
(
Vec
<
UniqueBlock
>
),
Destroy
(
Vec
<
UniqueBlock
>
),
Destroy
(
Vec
<
UniqueBlock
>
),
Deref
(
Vec
<
UniqueBlock
>
),
Deref
(
Vec
<
UniqueBlock
>
),
Promote
(
Uuid
,
GlobalHash
),
Promote
(
Uuid
,
GlobalHash
,
Option
<
u64
>
),
}
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
pub
enum
MoveBlockResponse
{
Store
(
Vec
<
GlobalHash
>
,
Option
<
u64
>
),
Remove
(
Vec
<
GlobalHash
>
),
}
}
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
...
@@ -52,15 +63,86 @@ pub struct DirectRequest {
...
@@ -52,15 +63,86 @@ pub struct DirectRequest {
pub
tokens
:
Vec
<
Token
>
,
pub
tokens
:
Vec
<
Token
>
,
pub
max_output_tokens
:
usize
,
pub
max_output_tokens
:
usize
,
pub
uuid
:
Option
<
Uuid
>
,
pub
uuid
:
Option
<
Uuid
>
,
pub
dp_rank
:
Option
<
u32
>
,
}
}
/// Represents the cost of prefilling content in the cache
/// Represents the cost of prefilling content in the cache
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
PrefillCost
{
pub
struct
PrefillCost
{
pub
new_blocks
:
usize
,
pub
new_tokens
:
usize
,
pub
new_tokens
:
usize
,
pub
prefill_compute
:
f64
,
pub
prefill_compute
:
f64
,
}
}
/// Signal for output token generation with completion status
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
OutputSignal
{
pub
uuid
:
Uuid
,
pub
completed
:
bool
,
}
/// Configuration arguments for MockVllmEngine
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Builder)]
#[builder(pattern
=
"owned"
,
build_fn(public))]
pub
struct
MockEngineArgs
{
#[builder(default
=
"16384"
)]
pub
num_gpu_blocks
:
usize
,
#[builder(default
=
"64"
)]
pub
block_size
:
usize
,
// This was 1024 in the past but reverted back to 256
#[builder(default
=
Some(
256
))]
pub
max_num_seqs
:
Option
<
usize
>
,
// default for open api server, for llm class it's 16384
#[builder(default
=
Some(
8192
))]
pub
max_num_batched_tokens
:
Option
<
usize
>
,
#[builder(default
=
true
)]
pub
enable_prefix_caching
:
bool
,
#[builder(default
=
"0.01"
)]
pub
watermark
:
f64
,
#[builder(default
=
"1.0"
)]
pub
speedup_ratio
:
f64
,
#[builder(default
=
"1"
)]
pub
dp_size
:
u32
,
}
impl
MockEngineArgs
{
pub
fn
builder
()
->
MockEngineArgsBuilder
{
MockEngineArgsBuilder
::
default
()
}
}
/// Note: This assumes block_hash and tokens_hash are the same, which is not correct in rare cases
/// where the sequence-aware hash differs from the token content hash.
pub
fn
block_response_to_kv_event
(
response
:
MoveBlockResponse
)
->
KvCacheEventData
{
match
response
{
MoveBlockResponse
::
Store
(
full_blocks
,
parent_hash
)
=>
{
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
parent_hash
.map
(
ExternalSequenceBlockHash
),
blocks
:
full_blocks
.into_iter
()
.map
(|
block
|
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
block
),
tokens_hash
:
LocalBlockHash
(
block
),
})
.collect
(),
})
}
MoveBlockResponse
::
Remove
(
full_blocks
)
=>
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
full_blocks
.into_iter
()
.map
(
ExternalSequenceBlockHash
)
.collect
(),
}),
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
...
...
lib/llm/src/mocker/scheduler.rs
View file @
f0652d89
...
@@ -40,11 +40,13 @@
...
@@ -40,11 +40,13 @@
//! ## NOTE
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
use
crate
::
kv_router
::
protocols
::
{
ForwardPassMetrics
,
KvCacheEventData
}
;
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
kv_manager
::
KvManager
;
use
crate
::
mocker
::
kv_manager
::
KvManager
;
use
crate
::
mocker
::
protocols
::
DirectRequest
;
use
crate
::
mocker
::
protocols
::{
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
PrefillCost
,
UniqueBlock
};
block_response_to_kv_event
,
MoveBlock
,
OutputSignal
,
PrefillCost
,
UniqueBlock
,
};
use
crate
::
mocker
::
protocols
::{
DirectRequest
,
MockEngineArgs
,
MoveBlockResponse
};
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
VecDeque
;
use
std
::
collections
::
VecDeque
;
...
@@ -63,8 +65,8 @@ pub enum Request {
...
@@ -63,8 +65,8 @@ pub enum Request {
#[derive(Default)]
#[derive(Default)]
struct
SchedulerState
{
struct
SchedulerState
{
waiting
:
VecDeque
<
Uuid
>
,
waiting
:
VecDeque
<
Uuid
>
,
re
ady
:
VecDeque
<
Uuid
>
,
p
re
fill
:
VecDeque
<
Uuid
>
,
running
:
LRUEvictor
<
Uuid
>
,
decode
:
LRUEvictor
<
Uuid
>
,
requests
:
HashMap
<
Uuid
,
Request
>
,
requests
:
HashMap
<
Uuid
,
Request
>
,
prefill_costs
:
HashMap
<
Uuid
,
Option
<
PrefillCost
>>
,
prefill_costs
:
HashMap
<
Uuid
,
Option
<
PrefillCost
>>
,
}
}
...
@@ -74,61 +76,70 @@ impl SchedulerState {
...
@@ -74,61 +76,70 @@ impl SchedulerState {
fn
receive
(
&
mut
self
,
request
:
DirectRequest
)
->
Uuid
{
fn
receive
(
&
mut
self
,
request
:
DirectRequest
)
->
Uuid
{
// Use the provided UUID if available, otherwise generate a new one
// Use the provided UUID if available, otherwise generate a new one
let
uuid
=
request
.uuid
.unwrap_or_else
(
Uuid
::
new_v4
);
let
uuid
=
request
.uuid
.unwrap_or_else
(
Uuid
::
new_v4
);
// Add the request to the map and waiting queue
self
.requests
.insert
(
uuid
,
Request
::
Direct
(
request
));
self
.requests
.insert
(
uuid
,
Request
::
Direct
(
request
));
self
.waiting
.push_back
(
uuid
);
self
.waiting
.push_back
(
uuid
);
uuid
uuid
}
}
/// Get the next UUID from ready or waiting queue and its associated Request.
/// Get the next UUID from ready or waiting queue and its associated Request.
/// Returns from ready if not empty, otherwise from waiting, or None if both are empty.
/// Also removes the Request from the requests HashMap.
fn
next
(
&
mut
self
)
->
Option
<
(
Uuid
,
Request
)
>
{
fn
next
(
&
mut
self
)
->
Option
<
(
Uuid
,
Request
)
>
{
let
uuid
=
self
let
uuid
=
self
.waiting
.pop_front
()
?
;
.ready
let
request
=
self
.
pop_front
()
.
requests
.
or_else
(||
self
.waiting
.pop_front
())
?
;
.
remove
(
&
uuid
)
let
request
=
self
.requests
.remove
(
&
uuid
)
?
;
.expect
(
"Request does not exist."
)
;
Some
((
uuid
,
request
))
Some
((
uuid
,
request
))
}
}
/// Move a UUID and its Request to the waiting queue (front).
fn
first_in_line
(
&
mut
self
,
uuid
:
Uuid
,
request
:
Request
)
{
self
.requests
.insert
(
uuid
,
request
);
self
.waiting
.push_front
(
uuid
);
}
/// Move a UUID and its Request to the ready queue.
/// Move a UUID and its Request to the ready queue.
fn
make_ready
(
&
mut
self
,
uuid
:
Uuid
,
active_seq
:
ActiveSequence
)
{
fn
start_prefill
(
&
mut
self
,
uuid
:
Uuid
,
active_seq
:
ActiveSequence
,
cost
:
Option
<
PrefillCost
>
)
{
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_seq
));
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_seq
));
self
.ready
.push_back
(
uuid
);
self
.prefill
.push_back
(
uuid
);
self
.prefill_costs
.insert
(
uuid
,
cost
);
}
}
/// Schedule the request with the given UUID.
/// Pop from prefill queue and move to decode queue.
/// Returns the creation signal from the ActiveSequence.
/// Returns the prefill_compute value if available.
fn
run
(
&
mut
self
,
uuid
:
Uuid
,
active_seq
:
ActiveSequence
)
->
MoveBlock
{
fn
start_decode
(
&
mut
self
)
->
Option
<
(
f64
,
MoveBlock
)
>
{
// Insert the request into the map
let
uuid
=
self
.prefill
.pop_front
()
?
;
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_seq
));
self
.decode
.insert
(
uuid
);
// Remove and extract prefill_compute from prefill_costs
let
prefill_cost
=
self
.prefill_costs
.remove
(
&
uuid
)
.flatten
()
.expect
(
"Expects valid prefill cost."
);
// Get the creation signal
let
Some
(
Request
::
Active
(
sequence
))
=
self
.requests
.get
(
&
uuid
)
else
{
let
Some
(
Request
::
Active
(
sequence
))
=
self
.requests
.get
(
&
uuid
)
else
{
panic!
(
"Failed to get ActiveSequence for UUID"
);
panic!
(
"Request does not exist."
);
};
let
Some
(
signal
)
=
sequence
.creation_signal
()
else
{
panic!
(
"Failed to get creation signal from ActiveSequence"
);
};
};
let
creation_signal
=
sequence
.creation_signal
()
.clone
()
.expect
(
"Must have creation signal."
);
// Add to running requests
Some
((
prefill_cost
.prefill_compute
,
creation_signal
))
self
.running
.insert
(
uuid
);
signal
.clone
()
}
}
/// Set the prefill cost for a UUID
fn
run
(
&
mut
self
,
uuid
:
Uuid
)
->
Option
<&
mut
ActiveSequence
>
{
fn
set_prefill_cost
(
&
mut
self
,
uuid
:
Uuid
,
cost
:
Option
<
PrefillCost
>
)
{
if
!
self
.decode
.contains
(
&
uuid
)
{
self
.prefill_costs
.insert
(
uuid
,
cost
);
return
None
;
}
let
Some
(
Request
::
Active
(
sequence
))
=
self
.requests
.get_mut
(
&
uuid
)
else
{
panic!
(
"Request does not exist."
);
};
Some
(
sequence
)
}
}
/// Get the prefill compute value for a UUID if available
fn
num_active_requests
(
&
self
)
->
usize
{
fn
get_prefill_compute
(
&
self
,
uuid
:
&
Uuid
)
->
Option
<
f64
>
{
self
.prefill
.len
()
+
self
.decode
.len
()
self
.prefill_costs
.get
(
uuid
)
.and_then
(|
cost
|
cost
.as_ref
())
.map
(|
cost
|
cost
.prefill_compute
)
}
}
/// Calculate the current running batched tokens
/// Calculate the current running batched tokens
...
@@ -145,7 +156,7 @@ impl SchedulerState {
...
@@ -145,7 +156,7 @@ impl SchedulerState {
/// Remove a UUID and its associated Request from collections.
/// Remove a UUID and its associated Request from collections.
fn
complete
(
&
mut
self
,
uuid
:
&
Uuid
)
{
fn
complete
(
&
mut
self
,
uuid
:
&
Uuid
)
{
// println!("Request {} will complete", uuid);
// println!("Request {} will complete", uuid);
self
.
running
.remove
(
uuid
);
self
.
decode
.remove
(
uuid
);
self
.requests
.remove
(
uuid
);
self
.requests
.remove
(
uuid
);
self
.prefill_costs
.remove
(
uuid
);
self
.prefill_costs
.remove
(
uuid
);
}
}
...
@@ -153,76 +164,93 @@ impl SchedulerState {
...
@@ -153,76 +164,93 @@ impl SchedulerState {
/// Preempt the oldest running request by evicting it from running, resetting the sequence,
/// Preempt the oldest running request by evicting it from running, resetting the sequence,
/// and adding it back to the waiting queue.
/// and adding it back to the waiting queue.
/// Returns the signal from reset_with_signal or None if no requests are running.
/// Returns the signal from reset_with_signal or None if no requests are running.
fn
preempt
(
&
mut
self
)
->
Option
<
Vec
<
MoveBlock
>
>
{
fn
preempt
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
// Evict the oldest UUID from running
// Evict the oldest UUID from running
let
uuid
=
self
.running
.evict
()
?
;
let
uuid
=
self
eprintln!
(
"Request {} will be preempted"
,
uuid
);
.decode
.evict
()
// Remove the request from the requests HashMap and ensure it's an ActiveSequence
.expect
(
"Nothing to evict for preemption."
);
let
request
=
self
.requests
.remove
(
&
uuid
)
?
;
let
request
=
self
.requests
// Remove the prefill cost to force recomputation
.remove
(
&
uuid
)
.expect
(
"Request does not exist."
);
self
.prefill_costs
.remove
(
&
uuid
);
self
.prefill_costs
.remove
(
&
uuid
);
eprintln!
(
"Request {uuid} will be preempted"
);
// Extract the ActiveSequence from the Request enum
// Reset the sequence and get the new sequence and signal
// Insert the new sequence back into the requests map and add to waiting queue
let
Request
::
Active
(
mut
active_sequence
)
=
request
else
{
let
Request
::
Active
(
mut
active_sequence
)
=
request
else
{
panic!
(
"Expected ActiveSequence in running queue"
)
panic!
(
"Expected ActiveSequence in running queue"
)
};
};
// Reset the sequence and get the new sequence and signal
let
signals
=
active_sequence
.reset_with_signal
();
let
signals
=
active_sequence
.reset_with_signal
();
// Insert the new sequence back into the requests map and add to waiting queue
// Note: For preemption, we don't compute hit rate since we don't have access to new_tokens
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_sequence
));
// and the sequence is being reset anyway. Hit rate tracking is primarily for new scheduling attempts.
self
.waiting
.push_back
(
uuid
);
self
.first_in_line
(
uuid
,
Request
::
Active
(
active_sequence
));
Some
(
signals
)
signals
}
}
}
}
/// Manages scheduling of requests using KvManager resources
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
#[derive(Clone)]
pub
struct
Scheduler
{
pub
struct
Scheduler
{
dp_rank
:
Option
<
u32
>
,
state
:
Arc
<
Mutex
<
SchedulerState
>>
,
state
:
Arc
<
Mutex
<
SchedulerState
>>
,
kv_manager
:
Arc
<
Mutex
<
KvManager
>>
,
kv_manager
:
Arc
<
Mutex
<
KvManager
>>
,
request_tx
:
mpsc
::
Sender
<
DirectRequest
>
,
request_tx
:
mpsc
::
UnboundedSender
<
DirectRequest
>
,
hit_rates
:
Arc
<
Mutex
<
VecDeque
<
f32
>>>
,
}
}
impl
Scheduler
{
impl
Scheduler
{
/// Create a new Scheduler with the given parameters
/// Create a new Scheduler with the given parameters
pub
fn
new
(
pub
fn
new
(
kv_capacity
:
usize
,
args
:
MockEngineArgs
,
watermark
:
f64
,
dp_rank
:
Option
<
u32
>
,
block_size
:
u32
,
output_tx
:
Option
<
mpsc
::
UnboundedSender
<
OutputSignal
>>
,
chunk_size
:
Option
<
usize
>
,
kv_events_tx
:
Option
<
mpsc
::
UnboundedSender
<
KvCacheEventData
>>
,
output_tx
:
Option
<
mpsc
::
Sender
<
Uuid
>>
,
cancellation_token
:
Option
<
CancellationToken
>
,
cancellation_token
:
Option
<
CancellationToken
>
,
)
->
Self
{
)
->
Self
{
// Create KvManager internally
let
kv_manager
=
KvManager
::
new
(
kv_capacity
,
block_size
);
let
token_capacity
:
usize
=
8192
;
let
state
=
Arc
::
new
(
Mutex
::
new
(
SchedulerState
::
default
()));
let
state
=
Arc
::
new
(
Mutex
::
new
(
SchedulerState
::
default
()));
let
kv_manager
=
Arc
::
new
(
Mutex
::
new
(
kv_manager
));
// Create internal channel for KV events only if needed
let
chunk_size
=
chunk_size
.unwrap_or
(
256
);
let
(
block_resp_tx
,
mut
block_resp_rx
)
=
if
kv_events_tx
.is_some
()
{
let
(
tx
,
rx
)
=
mpsc
::
unbounded_channel
::
<
MoveBlockResponse
>
();
(
Some
(
tx
),
Some
(
rx
))
}
else
{
(
None
,
None
)
};
let
kv_manager
=
Arc
::
new
(
Mutex
::
new
(
KvManager
::
new_with_sender
(
args
.num_gpu_blocks
,
args
.block_size
,
block_resp_tx
,
)));
let
hit_rates
=
Arc
::
new
(
Mutex
::
new
(
VecDeque
::
with_capacity
(
1000
)));
// Create channel for request handling
// Assert speedup_ratio is greater than 0
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
channel
::
<
DirectRequest
>
(
1024
);
assert
!
(
args
.speedup_ratio
>
0.0
,
"speedup_ratio must be greater than 0, got: {}"
,
args
.speedup_ratio
);
// Use provided cancellation token or create new one
// Create channel for request handling
let
cancellation_token
=
cancellation_token
.unwrap_or_default
();
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
DirectRequest
>
();
let
token_clone
=
cancellation_token
.clone
();
// Create a clone for the background task
// Create a clone for the background task
let
state_clone
=
state
.clone
();
let
state_clone
=
state
.clone
();
let
kv_manager_clone
=
kv_manager
.clone
();
let
kv_manager_clone
=
kv_manager
.clone
();
let
output_tx_clone
=
output_tx
.clone
();
let
output_tx_clone
=
output_tx
.clone
();
let
cancel_token_clone
=
cancellation_token
.unwrap_or_default
()
.clone
();
let
hit_rates_clone
=
hit_rates
.clone
();
// Spawn main background task with cancellation token
// Spawn main background task with cancellation token
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
let
mut
schedule_interval
=
interval
(
Duration
::
from_millis
(
5
));
let
mut
schedule_interval
=
interval
(
Duration
::
from_secs_f64
(
1e-3
));
let
mut
simulate_interval
=
interval
(
Duration
::
from_millis
(
1
));
let
mut
simulate_interval
=
interval
(
Duration
::
from_secs_f64
(
1e-4
));
let
mut
should_schedule
=
true
;
loop
{
loop
{
tokio
::
select!
{
tokio
::
select!
{
...
@@ -234,35 +262,63 @@ impl Scheduler {
...
@@ -234,35 +262,63 @@ impl Scheduler {
state
.receive
(
request
);
state
.receive
(
request
);
}
}
// Try Scheduling Requests
// Try Scheduling Requests
- runs on normal interval or after simulation
_
=
schedule_interval
.tick
()
=>
{
_
=
schedule_interval
.tick
()
=>
{
// Skip if we just ran scheduling after simulation to prevent consecutive runs
if
!
should_schedule
{
continue
;
}
let
mut
state_guard
=
state_clone
.lock
()
.await
;
let
mut
state_guard
=
state_clone
.lock
()
.await
;
let
mut
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
let
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
// schedule anymore.
let
mut
current_blocks
=
kv_manager_guard
.num_active_blocks
();
let
mut
current_tokens
=
state_guard
.num_batched_tokens
();
let
mut
current_seqs
=
state_guard
.num_active_requests
();
while
let
Some
((
uuid
,
request
))
=
state_guard
.next
()
{
while
let
Some
((
uuid
,
request
))
=
state_guard
.next
()
{
let
active_sequence
=
get_active_sequence
(
request
,
block_size
,
chunk_size
);
let
active_sequence
=
get_active_sequence
(
request
,
args
.
block_size
,
args
.enable_prefix_caching
);
// Calculate token budget using new_tokens from PrefillCost
// Update predictive budgets
let
total_prefill_tokens
=
state_guard
.num_batched_tokens
();
let
prefill_cost
=
kv_manager_guard
.get_prefill_cost
(
&
active_sequence
);
let
tokens_budget
=
token_capacity
.saturating_sub
(
total_prefill_tokens
);
let
total_tokens
=
active_sequence
.len
();
let
new_blocks
=
(
total_tokens
+
1
)
/
args
.block_size
;
// this is conservative, assumes no cache hit
let
new_tokens
=
prefill_cost
.new_tokens
;
current_blocks
+=
new_blocks
;
current_tokens
+=
new_tokens
;
current_seqs
+=
1
;
// Check if it can be scheduled
// Check if it can be scheduled
let
Some
(
prefill_cost
)
=
kv_manager_guard
.try_schedule
(
&
active_sequence
,
watermark
,
tokens_budget
)
else
{
let
under_block_budget
=
current_blocks
as
f64
<=
(
1
.
-
args
.watermark
)
*
kv_manager_guard
.max_capacity
()
as
f64
;
state_guard
.make_ready
(
uuid
,
active_sequence
);
let
under_token_budget
=
args
.max_num_batched_tokens
.is_none_or
(|
limit
|
current_tokens
<=
limit
);
let
under_seq_budget
=
args
.max_num_seqs
.is_none_or
(|
limit
|
current_seqs
<=
limit
);
// Cannot schedule, put first in line instead
if
!
(
under_block_budget
&&
under_token_budget
&&
under_seq_budget
)
{
state_guard
.first_in_line
(
uuid
,
Request
::
Active
(
active_sequence
));
break
;
break
;
}
;
}
// Get creation signal and schedule the request
// Compute and store hit rate
let
signal
=
state_guard
.run
(
uuid
,
active_sequence
);
let
hit_rate
=
if
!
active_sequence
.is_empty
()
{
1.0
-
(
new_tokens
as
f32
/
active_sequence
.len
()
as
f32
)
}
else
{
0.0
};
kv_manager_guard
.process
(
&
signal
);
{
state_guard
.set_prefill_cost
(
uuid
,
Some
(
prefill_cost
));
let
mut
hit_rates_guard
=
hit_rates_clone
.lock
()
.await
;
hit_rates_guard
.push_back
(
hit_rate
);
if
hit_rates_guard
.len
()
>
1000
{
hit_rates_guard
.pop_front
();
}
}
state_guard
.start_prefill
(
uuid
,
active_sequence
,
Some
(
prefill_cost
));
should_schedule
=
false
;
}
}
}
}
// Check for cancellation
// Check for cancellation
_
=
token_clone
.cancelled
()
=>
{
_
=
cancel_
token_clone
.cancelled
()
=>
{
break
;
break
;
}
}
...
@@ -271,75 +327,84 @@ impl Scheduler {
...
@@ -271,75 +327,84 @@ impl Scheduler {
let
mut
state_guard
=
state_clone
.lock
()
.await
;
let
mut
state_guard
=
state_clone
.lock
()
.await
;
let
mut
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
let
mut
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
// Base time needed for decoding
(assumed memory bound on KV cache)
// Base time needed for decoding
using active percentage and quadratic formula
let
active_
tokens
=
kv_manager_guard
.
num
_active_
blocks
()
*
(
block_size
as
usize
);
let
active_
perc
=
kv_manager_guard
.
get
_active_
perc
(
);
// TODO: 2 is a dummy / magic scaling factor
let
decoding_time
=
-
5.47
*
active_perc
.powi
(
2
)
+
43.88
*
active_perc
+
19.44
;
let
mut
generation
_time
=
Duration
::
from_
micros
((
active_tokens
/
2
)
as
u64
);
let
mut
total
_time
=
Duration
::
from_
secs_f64
(
decoding_time
/
1000.0
);
// Process each running request
// Process prefilling
let
uuids
:
Vec
<
Uuid
>
=
state_guard
.running
.keys
()
.cloned
()
.collect
();
while
let
Some
((
prefill_compute
,
creation_signal
))
=
state_guard
.start_decode
()
{
for
uuid
in
uuids
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// Check if UUID is still in running_requests, if not skip this iteration
// could be cached by other requests in the same batch. This matches vLLM behavior.
if
!
state_guard
.running
.contains
(
&
uuid
)
{
total_time
+=
Duration
::
from_secs_f64
(
prefill_compute
/
1000.0
);
continue
;
let
prefill_success
=
process_signals
(
&
mut
kv_manager_guard
,
std
::
slice
::
from_ref
(
&
creation_signal
));
if
!
prefill_success
{
panic!
(
"Block allocation for prefilling cannot fail."
);
}
}
//
Get prefill compute value first
//
Drain KV events and forward to relay after prefill signal processing
let
prefill_compute
=
state_guard
.get_prefill_compute
(
&
uuid
);
if
let
(
Some
(
ref
relay_tx
),
Some
(
ref
mut
rx
))
=
(
&
kv_events_tx
,
&
mut
block_resp_rx
)
{
while
let
Ok
(
event
)
=
rx
.try_recv
()
{
// Get the active sequence for this UUID
let
_
=
relay_tx
.send
(
block_response_to_kv_event
(
event
));
let
sequence
=
state_guard
.requests
.get_mut
(
&
uuid
)
}
.and_then
(|
req
|
if
let
Request
::
Active
(
seq
)
=
req
{
Some
(
seq
)
}
else
{
None
})
}
.expect
(
"UUID in running_requests must have a corresponding active sequence"
);
}
// Generate token and get signals
// Process decoding
let
uuids
:
Vec
<
Uuid
>
=
state_guard
.decode
.keys
()
.cloned
()
.collect
();
if
!
uuids
.is_empty
()
{
should_schedule
=
true
};
for
uuid
in
uuids
{
let
Some
(
sequence
)
=
state_guard
.run
(
uuid
)
else
{
continue
;
};
let
signals
=
sequence
.generate
();
let
signals
=
sequence
.generate
();
// Accumulate sleep duration based on prefill_compute if available
// prefill compute = (cached_tokens + new_tokens) * new_tokens
let
sleep_ms
=
if
let
Some
(
compute
)
=
prefill_compute
{
// TODO: 1024 is a dummy / magic scaling factor
(
compute
/
1024.0
)
as
u64
}
else
{
0
};
generation_time
+=
Duration
::
from_micros
(
sleep_ms
);
// Process all signals with the KvManager
// Process all signals with the KvManager
// Handling of preemption on failure
// Handling of preemption on failure
if
!
process_signals
(
&
mut
kv_manager_guard
,
&
signals
)
{
if
!
process_signals
(
&
mut
kv_manager_guard
,
&
signals
)
{
sequence
.pop
();
// revert the failed generation op
sequence
.pop
();
// revert the failed generation op
for
signal
in
state_guard
.preempt
()
{
// free_signal derefs the preempted blocks
let
Some
(
free_signal
)
=
state_guard
.preempt
()
else
{
panic!
(
"Failed to acquire signal to free KV blocks from preemption"
);
};
for
signal
in
free_signal
{
kv_manager_guard
.process
(
&
signal
);
kv_manager_guard
.process
(
&
signal
);
}
}
continue
;
continue
;
}
}
// Send UUID notification for each generated token
// Drain KV events and forward to relay after decode signal processing
// TODO: hook this up to an AsyncEngine
if
let
(
Some
(
ref
relay_tx
),
Some
(
ref
mut
rx
))
=
(
&
kv_events_tx
,
&
mut
block_resp_rx
)
{
if
let
Some
(
tx
)
=
&
output_tx_clone
{
while
let
Ok
(
event
)
=
rx
.try_recv
()
{
let
_
=
tx
.try_send
(
uuid
);
let
_
=
relay_tx
.send
(
block_response_to_kv_event
(
event
));
}
}
}
// Check if we're done after generating
// Check completion and send notification
if
sequence
.generated_tokens
()
>=
sequence
.max_output_tokens
()
{
let
is_complete
=
sequence
.generated_tokens
()
>=
sequence
.max_output_tokens
();
state_guard
.complete
(
&
uuid
);
let
should_output
=
sequence
.generated_tokens
()
>
sequence
.already_generated_tokens
();
continue
;
let
mut
send_failed
=
false
;
if
should_output
{
send_failed
=
output_tx_clone
.as_ref
()
.is_some_and
(|
tx
|
{
tx
.send
(
OutputSignal
{
uuid
,
completed
:
is_complete
})
.is_err
()
});
}
if
send_failed
{
for
signal
in
&
sequence
.free_signal
()
{
kv_manager_guard
.process
(
signal
);
}
}
}
// Transition to decode (no prefill cost)
if
send_failed
||
is_complete
{
if
sequence
.generated_tokens
()
==
1
{
state_guard
.complete
(
&
uuid
);
state_guard
.set_prefill_cost
(
uuid
,
None
)
;
continue
;
}
}
}
}
// Sleep once for the accumulated duration
// Sleep once for the adjusted duration
if
generation_time
.as_millis
()
>
0
{
drop
(
kv_manager_guard
);
tokio
::
time
::
sleep
(
generation_time
)
.await
;
drop
(
state_guard
);
let
adjusted_time
=
Duration
::
from_secs_f64
(
total_time
.as_secs_f64
()
/
args
.speedup_ratio
);
if
adjusted_time
.as_millis
()
>
0
{
tokio
::
time
::
sleep
(
adjusted_time
)
.await
;
}
}
}
}
}
}
...
@@ -347,15 +412,22 @@ impl Scheduler {
...
@@ -347,15 +412,22 @@ impl Scheduler {
});
});
Self
{
Self
{
dp_rank
,
state
,
state
,
kv_manager
,
kv_manager
,
request_tx
,
request_tx
,
hit_rates
,
}
}
}
}
/// Add a new request to the waiting queue
/// Add a new request to the waiting queue
pub
async
fn
receive
(
&
self
,
request
:
DirectRequest
)
{
pub
async
fn
receive
(
&
self
,
request
:
DirectRequest
)
{
let
_
=
self
.request_tx
.send
(
request
)
.await
;
let
_
=
self
.request_tx
.send
(
request
);
}
/// Expose the sender
pub
fn
request_sender
(
&
self
)
->
mpsc
::
UnboundedSender
<
DirectRequest
>
{
self
.request_tx
.clone
()
}
}
/// Get the count of waiting requests
/// Get the count of waiting requests
...
@@ -367,7 +439,7 @@ impl Scheduler {
...
@@ -367,7 +439,7 @@ impl Scheduler {
/// Get the count of running requests
/// Get the count of running requests
pub
async
fn
running_count
(
&
self
)
->
usize
{
pub
async
fn
running_count
(
&
self
)
->
usize
{
let
state
=
self
.state
.lock
()
.await
;
let
state
=
self
.state
.lock
()
.await
;
state
.
running
.len
()
state
.
decode
.len
()
}
}
/// Get the current capacity of the KvManager
/// Get the current capacity of the KvManager
...
@@ -378,35 +450,53 @@ impl Scheduler {
...
@@ -378,35 +450,53 @@ impl Scheduler {
/// Returns forward pass metrics for monitoring purposes
/// Returns forward pass metrics for monitoring purposes
pub
async
fn
get_forward_pass_metrics
(
&
self
)
->
ForwardPassMetrics
{
pub
async
fn
get_forward_pass_metrics
(
&
self
)
->
ForwardPassMetrics
{
// Acquire all locks in consistent order: state -> kv_manager -> hit_rates
let
state
=
self
.state
.lock
()
.await
;
let
state
=
self
.state
.lock
()
.await
;
let
kv_manager
=
self
.kv_manager
.lock
()
.await
;
let
kv_manager
=
self
.kv_manager
.lock
()
.await
;
let
hit_rates_guard
=
self
.hit_rates
.lock
()
.await
;
// Get the active blocks and total capacity from KvManager
// Get state metrics
let
request_active_slots
=
state
.decode
.len
()
as
u64
;
let
num_requests_waiting
=
state
.waiting
.len
()
as
u64
;
// Get KV manager metrics
let
active_blocks_count
=
kv_manager
.active_blocks
()
.len
()
as
u64
;
let
active_blocks_count
=
kv_manager
.active_blocks
()
.len
()
as
u64
;
let
total_capacity
=
kv_manager
.max_capacity
()
as
u64
;
let
total_capacity
=
kv_manager
.max_capacity
()
as
u64
;
// Calculate GPU cache usage percentage
let
gpu_cache_usage_perc
=
if
total_capacity
>
0
{
let
gpu_cache_usage_perc
=
if
total_capacity
>
0
{
active_blocks_count
as
f32
/
total_capacity
as
f32
active_blocks_count
as
f32
/
total_capacity
as
f32
}
else
{
}
else
{
0.0
0.0
};
};
// Get hit rate metrics
let
gpu_prefix_cache_hit_rate
=
if
hit_rates_guard
.is_empty
()
{
0.0
}
else
{
let
sum
:
f32
=
hit_rates_guard
.iter
()
.sum
();
sum
/
hit_rates_guard
.len
()
as
f32
};
ForwardPassMetrics
{
ForwardPassMetrics
{
data_parallel_rank
:
None
,
// Default for backwards compatibility
data_parallel_rank
:
self
.dp_rank
,
request_active_slots
:
state
.running
.len
()
as
u64
,
request_active_slots
,
request_total_slots
:
420
,
// Dummy value as specified
// vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
request_total_slots
:
1024
,
kv_active_blocks
:
active_blocks_count
,
kv_active_blocks
:
active_blocks_count
,
kv_total_blocks
:
total_capacity
,
kv_total_blocks
:
total_capacity
,
num_requests_waiting
:
state
.waiting
.len
()
as
u64
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
:
0.0
,
// Placeholder value as specified
gpu_prefix_cache_hit_rate
,
}
}
// Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state
}
}
}
}
/// Convert a Request to an ActiveSequence
/// Convert a Request to an ActiveSequence
fn
get_active_sequence
(
request
:
Request
,
block_size
:
u32
,
chunk_size
:
usize
)
->
ActiveSequence
{
fn
get_active_sequence
(
request
:
Request
,
block_size
:
usize
,
enable_prefix_caching
:
bool
,
)
->
ActiveSequence
{
if
let
Request
::
Active
(
active_seq
)
=
request
{
if
let
Request
::
Active
(
active_seq
)
=
request
{
return
active_seq
;
return
active_seq
;
}
}
...
@@ -419,7 +509,7 @@ fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) ->
...
@@ -419,7 +509,7 @@ fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) ->
direct_request
.tokens
,
direct_request
.tokens
,
direct_request
.max_output_tokens
,
direct_request
.max_output_tokens
,
Some
(
block_size
),
Some
(
block_size
),
Some
(
chunk_size
)
,
enable_prefix_caching
,
)
)
}
}
...
@@ -440,7 +530,7 @@ fn process_signals(
...
@@ -440,7 +530,7 @@ fn process_signals(
}
}
// Check we have a Use signal with blocks
// Check we have a Use signal with blocks
let
MoveBlock
::
Use
(
blocks
,
_
)
=
signal
else
{
let
MoveBlock
::
Use
(
blocks
)
=
signal
else
{
panic!
(
"Failed signal is Invalid. Has to fail on generation signal."
);
panic!
(
"Failed signal is Invalid. Has to fail on generation signal."
);
};
};
...
@@ -467,32 +557,37 @@ mod tests {
...
@@ -467,32 +557,37 @@ mod tests {
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
#[rstest]
#[rstest]
#[case::random(
false
)]
#[case::random_no_prefix_caching(
false
,
false
)]
#[case::caching(
true
)]
#[case::random_with_prefix_caching(
false
,
true
)]
#[case::caching_no_prefix_caching(
true
,
false
)]
#[case::caching_with_prefix_caching(
true
,
true
)]
#[tokio::test]
#[tokio::test]
async
fn
test_scheduler_token_generation_patterns
(
#[case]
use_shared_tokens
:
bool
)
{
async
fn
test_scheduler_token_generation_patterns
(
#[case]
use_shared_tokens
:
bool
,
#[case]
enable_prefix_caching
:
bool
,
)
{
std
::
env
::
set_var
(
"RUST_LOG"
,
"debug"
);
std
::
env
::
set_var
(
"RUST_LOG"
,
"debug"
);
let
kv_capacity
:
usize
=
500
;
let
kv_capacity
:
usize
=
500
;
let
watermark
:
f64
=
0.01
;
// 1% watermark
let
block_size
:
usize
=
64
;
let
block_size
:
u32
=
64
;
let
chunk_size
:
usize
=
256
;
let
num_requests
:
usize
=
100
;
let
num_requests
:
usize
=
100
;
let
input_len
:
usize
=
1000
;
let
input_len
:
usize
=
1000
;
let
max_output_tokens
:
usize
=
100
;
let
max_output_tokens
:
usize
=
100
;
// Create channel for token output
// Create channel for token output
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
channel
::
<
Uuid
>
(
1024
);
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
// Create scheduler with internal KvManager
// Create scheduler args using builder - now including enable_prefix_caching
let
scheduler
=
Scheduler
::
new
(
let
args
=
MockEngineArgs
::
builder
()
kv_capacity
,
.num_gpu_blocks
(
kv_capacity
)
watermark
,
.block_size
(
block_size
)
block_size
,
.speedup_ratio
(
10.0
)
Some
(
chunk_size
),
.enable_prefix_caching
(
enable_prefix_caching
)
Some
(
output_tx
),
.build
()
None
,
.unwrap
();
);
// Create scheduler with new args struct
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
// Create shared tokens for caching case
// Create shared tokens for caching case
let
shared_tokens
=
if
use_shared_tokens
{
let
shared_tokens
=
if
use_shared_tokens
{
...
@@ -523,6 +618,7 @@ mod tests {
...
@@ -523,6 +618,7 @@ mod tests {
tokens
:
input_tokens
,
tokens
:
input_tokens
,
max_output_tokens
,
max_output_tokens
,
uuid
:
None
,
uuid
:
None
,
dp_rank
:
None
,
};
};
scheduler
.receive
(
request
)
.await
;
scheduler
.receive
(
request
)
.await
;
}
}
...
@@ -547,7 +643,7 @@ mod tests {
...
@@ -547,7 +643,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics
// Manual debug ticker that prints forward pass metrics
_
=
debug_interval
.tick
()
=>
{
_
=
debug_interval
.tick
()
=>
{
let
_
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
let
_
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
//
println!("Forward Pass Metrics: {
:#?}",
_metrics);
println!
(
"Forward Pass Metrics: {_metrics
:#?}"
);
}
}
Some
(
_
)
=
output_rx
.recv
()
=>
{
Some
(
_
)
=
output_rx
.recv
()
=>
{
...
@@ -566,21 +662,177 @@ mod tests {
...
@@ -566,21 +662,177 @@ mod tests {
// Calculate and print elapsed time
// Calculate and print elapsed time
let
elapsed
=
start_time
.elapsed
();
let
elapsed
=
start_time
.elapsed
();
println!
(
println!
(
"Test completed in: {:?} for {} case"
,
"Test completed in: {:?} for {} case
with prefix_caching={}
"
,
elapsed
,
elapsed
,
if
use_shared_tokens
{
if
use_shared_tokens
{
"caching"
"caching"
}
else
{
}
else
{
"random"
"random"
}
},
enable_prefix_caching
);
);
// Assert that we received the expected number of tokens
// Assert that we received the expected number of tokens
assert
!
(
assert
!
(
received_tokens
>
expected_tokens
,
received_tokens
==
expected_tokens
,
"Received {} tokens but expected more than {}"
,
"Received {received_tokens} tokens but expected exactly {expected_tokens}"
received_tokens
,
);
expected_tokens
}
#[tokio::test]
async
fn
test_cache_hit_rate_with_identical_requests
()
{
let
block_size
:
usize
=
64
;
let
max_output_tokens
:
usize
=
10
;
let
speedup_ratio
=
10.0
;
let
num_requests
=
10
;
let
token_length
=
65
;
// Create channel for token output
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
// Create scheduler args
let
args
=
MockEngineArgs
::
builder
()
.num_gpu_blocks
(
100
)
// Large enough to not be a constraint
.block_size
(
block_size
)
.speedup_ratio
(
speedup_ratio
)
.build
()
.unwrap
();
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
// Create identical tokens for all requests
let
identical_tokens
:
Vec
<
u32
>
=
(
0
..
token_length
)
.map
(|
i
|
i
as
u32
)
.collect
();
// Send all requests with identical tokens
for
_
in
0
..
num_requests
{
let
request
=
DirectRequest
{
tokens
:
identical_tokens
.clone
(),
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
};
scheduler
.receive
(
request
)
.await
;
// Sleep for 0.1 second after each request
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
}
// Collect all generated tokens
let
mut
received_tokens
=
0
;
// Set up a timeout that resets to 0.5 seconds on each received token
let
timeout
=
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
500
));
tokio
::
pin!
(
timeout
);
// Set up debug ticker interval
let
mut
debug_interval
=
interval
(
Duration
::
from_millis
(
500
));
loop
{
tokio
::
select!
{
biased
;
// Manual debug ticker that prints forward pass metrics
_
=
debug_interval
.tick
()
=>
{
let
_
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
println!
(
"Forward Pass Metrics: {_metrics:#?}"
);
}
Some
(
_
signal
)
=
output_rx
.recv
()
=>
{
received_tokens
+=
1
;
// Reset timeout whenever we receive a token
timeout
.set
(
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
500
)));
}
_
=
&
mut
timeout
=>
{
// Break when timeout occurs (no more tokens for 0.5 seconds)
break
;
}
}
}
// Verify forward pass metrics
let
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
assert_eq!
(
metrics
.num_requests_waiting
,
0
,
"Expected no waiting requests, got {}"
,
metrics
.num_requests_waiting
);
assert
!
(
metrics
.gpu_prefix_cache_hit_rate
>
0.8
,
"Expected cache hit rate > 0.8, got {}"
,
metrics
.gpu_prefix_cache_hit_rate
);
println!
(
"Test passed! Cache hit rate: {:.3}"
,
metrics
.gpu_prefix_cache_hit_rate
);
println!
(
"Received {received_tokens} tokens"
);
}
#[tokio::test]
async
fn
test_receiver_drop_cleans_up_resources
()
{
let
block_size
:
usize
=
64
;
let
input_tokens
=
256
;
let
max_output_tokens
=
200
;
// More than we'll receive
// Create channel for token output
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
// Create scheduler args
let
args
=
MockEngineArgs
::
builder
()
.num_gpu_blocks
(
10
)
// Enough for 256 tokens (4 blocks)
.block_size
(
block_size
)
.speedup_ratio
(
100.0
)
// Fast simulation
.build
()
.unwrap
();
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
// Create request with 256 tokens
let
tokens
:
Vec
<
u32
>
=
(
0
..
input_tokens
)
.map
(|
i
|
i
as
u32
)
.collect
();
let
request
=
DirectRequest
{
tokens
,
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
};
scheduler
.receive
(
request
)
.await
;
// Receive exactly 129 tokens
let
mut
received_count
=
0
;
while
received_count
<
129
{
if
let
Some
(
_
signal
)
=
output_rx
.recv
()
.await
{
received_count
+=
1
;
}
else
{
panic!
(
"Channel closed before receiving 129 tokens"
);
}
}
// Drop the receiver immediately
drop
(
output_rx
);
// Wait for 1 second to allow cleanup
tokio
::
time
::
sleep
(
Duration
::
from_secs
(
1
))
.await
;
// Check forward pass metrics
let
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
assert_eq!
(
metrics
.gpu_cache_usage_perc
,
0.0
,
"Expected GPU cache usage to be 0%, got {}%"
,
metrics
.gpu_cache_usage_perc
*
100.0
);
assert_eq!
(
metrics
.kv_active_blocks
,
0
,
"Expected 0 active blocks, got {}"
,
metrics
.kv_active_blocks
);
);
}
}
}
}
lib/llm/src/mocker/sequence.rs
View file @
f0652d89
...
@@ -23,16 +23,23 @@ use uuid;
...
@@ -23,16 +23,23 @@ use uuid;
fn
create_unique_blocks_from_sequence
(
fn
create_unique_blocks_from_sequence
(
tokens
:
&
TokenBlockSequence
,
tokens
:
&
TokenBlockSequence
,
uuid
:
Option
<
uuid
::
Uuid
>
,
uuid
:
Option
<
uuid
::
Uuid
>
,
block_size
:
u32
,
block_size
:
usize
,
enable_prefix_caching
:
bool
,
)
->
Vec
<
UniqueBlock
>
{
)
->
Vec
<
UniqueBlock
>
{
let
mut
unique_blocks
:
Vec
<
UniqueBlock
>
=
tokens
let
mut
unique_blocks
:
Vec
<
UniqueBlock
>
=
tokens
.blocks
()
.blocks
()
.iter
()
.iter
()
.map
(|
block
|
UniqueBlock
::
FullBlock
(
block
.sequence_hash
()))
.map
(|
block
|
{
if
enable_prefix_caching
{
UniqueBlock
::
FullBlock
(
block
.sequence_hash
())
}
else
{
UniqueBlock
::
FullBlock
(
random
::
<
u64
>
())
}
})
.collect
();
.collect
();
// Only push the partial block if tokens count isn't a multiple of block_size
// Only push the partial block if tokens count isn't a multiple of block_size
if
tokens
.total_tokens
()
%
(
block_size
as
usize
)
!=
0
{
if
tokens
.total_tokens
()
%
block_size
!=
0
{
unique_blocks
.push
(
match
uuid
{
unique_blocks
.push
(
match
uuid
{
Some
(
uuid
)
=>
UniqueBlock
::
PartialBlock
(
uuid
),
Some
(
uuid
)
=>
UniqueBlock
::
PartialBlock
(
uuid
),
None
=>
UniqueBlock
::
default
(),
None
=>
UniqueBlock
::
default
(),
...
@@ -50,10 +57,7 @@ pub struct ActiveSequence {
...
@@ -50,10 +57,7 @@ pub struct ActiveSequence {
tokens
:
TokenBlockSequence
,
tokens
:
TokenBlockSequence
,
#[getter(copy)]
#[getter(copy)]
block_size
:
u32
,
block_size
:
usize
,
#[getter(copy)]
chunk_size
:
usize
,
// TODO: not actually used
#[getter(copy)]
#[getter(copy)]
max_output_tokens
:
usize
,
max_output_tokens
:
usize
,
...
@@ -61,10 +65,16 @@ pub struct ActiveSequence {
...
@@ -61,10 +65,16 @@ pub struct ActiveSequence {
#[getter(copy)]
#[getter(copy)]
generated_tokens
:
usize
,
generated_tokens
:
usize
,
#[getter(copy)]
already_generated_tokens
:
usize
,
#[getter(copy)]
#[getter(copy)]
num_input_tokens
:
usize
,
num_input_tokens
:
usize
,
creation_signal
:
Option
<
MoveBlock
>
,
creation_signal
:
Option
<
MoveBlock
>
,
#[getter(copy)]
enable_prefix_caching
:
bool
,
}
}
impl
ActiveSequence
{
impl
ActiveSequence
{
...
@@ -72,32 +82,33 @@ impl ActiveSequence {
...
@@ -72,32 +82,33 @@ impl ActiveSequence {
pub
fn
new
(
pub
fn
new
(
tokens
:
Vec
<
u32
>
,
tokens
:
Vec
<
u32
>
,
max_output_tokens
:
usize
,
max_output_tokens
:
usize
,
block_size
:
Option
<
u
32
>
,
block_size
:
Option
<
u
size
>
,
chunk_size
:
Option
<
usize
>
,
enable_prefix_caching
:
bool
,
)
->
Self
{
)
->
Self
{
let
block_size
=
block_size
.unwrap_or
(
64
);
let
block_size
=
block_size
.unwrap_or
(
64
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
let
chunk_size
=
chunk_size
.unwrap_or
(
256
);
let
num_input_tokens
=
tokens
.len
();
let
num_input_tokens
=
tokens
.len
();
let
tokens
=
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
,
None
);
let
tokens
=
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
let
unique_blocks
=
create_unique_blocks_from_sequence
(
&
tokens
,
None
,
block_size
);
let
unique_blocks
=
let
creation_signal
=
Some
(
MoveBlock
::
Use
(
unique_blocks
.clone
(),
None
));
create_unique_blocks_from_sequence
(
&
tokens
,
None
,
block_size
,
enable_prefix_caching
);
let
creation_signal
=
Some
(
MoveBlock
::
Use
(
unique_blocks
.clone
()));
Self
{
Self
{
unique_blocks
,
unique_blocks
,
tokens
,
tokens
,
block_size
,
block_size
,
chunk_size
,
max_output_tokens
,
max_output_tokens
,
generated_tokens
:
0
,
generated_tokens
:
0
,
already_generated_tokens
:
0
,
num_input_tokens
,
num_input_tokens
,
creation_signal
,
creation_signal
,
enable_prefix_caching
,
}
}
}
}
pub
fn
extra_tokens
(
&
self
)
->
u32
{
pub
fn
extra_tokens
(
&
self
)
->
u32
{
(
self
.len
()
%
self
.block_size
as
usize
)
as
u32
(
self
.len
()
%
self
.block_size
)
as
u32
}
}
pub
fn
len
(
&
self
)
->
usize
{
pub
fn
len
(
&
self
)
->
usize
{
...
@@ -112,20 +123,31 @@ impl ActiveSequence {
...
@@ -112,20 +123,31 @@ impl ActiveSequence {
pub
fn
new_with_signal
(
pub
fn
new_with_signal
(
tokens
:
Vec
<
u32
>
,
tokens
:
Vec
<
u32
>
,
max_output_tokens
:
usize
,
max_output_tokens
:
usize
,
block_size
:
Option
<
u
32
>
,
block_size
:
Option
<
u
size
>
,
chunk_size
:
Option
<
usize
>
,
enable_prefix_caching
:
bool
,
)
->
(
Self
,
Option
<
MoveBlock
>
)
{
)
->
(
Self
,
Option
<
MoveBlock
>
)
{
let
mut
sequence
=
Self
::
new
(
tokens
,
max_output_tokens
,
block_size
,
chunk_size
);
let
mut
sequence
=
Self
::
new
(
tokens
,
max_output_tokens
,
block_size
,
enable_prefix_caching
);
let
signal
=
sequence
.creation_signal
.take
();
let
signal
=
sequence
.creation_signal
.take
();
(
sequence
,
signal
)
(
sequence
,
signal
)
}
}
/// Get the parent hash from the second-to-last block if it exists and is a FullBlock
fn
get_parent_hash
(
&
self
)
->
Option
<
u64
>
{
if
self
.unique_blocks
.len
()
<
2
{
return
None
;
}
match
&
self
.unique_blocks
[
self
.unique_blocks
.len
()
-
2
]
{
UniqueBlock
::
FullBlock
(
hash
)
=>
Some
(
*
hash
),
_
=>
panic!
(
"Cannot have a partial block as parent"
),
}
}
/// Push a token to the sequence
/// Push a token to the sequence
pub
fn
push
(
&
mut
self
,
token
:
u32
)
->
Option
<
Vec
<
MoveBlock
>>
{
pub
fn
push
(
&
mut
self
,
token
:
u32
)
->
Option
<
Vec
<
MoveBlock
>>
{
self
.tokens
.append
(
token
)
.expect
(
"Token push failed."
);
self
.tokens
.append
(
token
)
.expect
(
"Token push failed."
);
self
.generated_tokens
+=
1
;
self
.generated_tokens
+=
1
;
if
self
.len
()
%
(
self
.block_size
as
usize
)
!=
1
{
if
self
.len
()
%
self
.block_size
!=
1
{
return
None
;
return
None
;
}
}
...
@@ -135,16 +157,24 @@ impl ActiveSequence {
...
@@ -135,16 +157,24 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists
// Replace last partial block with full block if it exists
if
let
Some
(
UniqueBlock
::
PartialBlock
(
uuid
))
=
self
.unique_blocks
.last
()
.cloned
()
{
if
let
Some
(
UniqueBlock
::
PartialBlock
(
uuid
))
=
self
.unique_blocks
.last
()
.cloned
()
{
let
last_block_hash
=
self
.tokens
.last_complete_block
()
.unwrap
()
.sequence_hash
();
let
last_block_hash
=
if
self
.enable_prefix_caching
{
self
.tokens
.last_complete_block
()
.unwrap
()
.sequence_hash
()
}
else
{
random
::
<
u64
>
()
};
self
.unique_blocks
.pop
();
self
.unique_blocks
.pop
();
self
.unique_blocks
self
.unique_blocks
.push
(
UniqueBlock
::
FullBlock
(
last_block_hash
));
.push
(
UniqueBlock
::
FullBlock
(
last_block_hash
));
signals
.push
(
MoveBlock
::
Promote
(
uuid
,
last_block_hash
));
signals
.push
(
MoveBlock
::
Promote
(
uuid
,
last_block_hash
,
self
.get_parent_hash
(),
));
}
}
let
new_partial_block
=
UniqueBlock
::
default
();
let
new_partial_block
=
UniqueBlock
::
default
();
self
.unique_blocks
.push
(
new_partial_block
.clone
());
self
.unique_blocks
.push
(
new_partial_block
.clone
());
signals
.push
(
MoveBlock
::
Use
(
vec!
[
new_partial_block
]
,
None
));
signals
.push
(
MoveBlock
::
Use
(
vec!
[
new_partial_block
]));
Some
(
signals
)
Some
(
signals
)
}
}
...
@@ -204,15 +234,19 @@ impl ActiveSequence {
...
@@ -204,15 +234,19 @@ impl ActiveSequence {
}
}
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// maintaining the uuid of the last partial block
pub
fn
reset_with_signal
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
pub
fn
reset_with_signal
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
let
free_signal
=
self
.free_signal
();
let
free_signal
=
self
.free_signal
();
self
.tokens
.truncate
(
self
.num_input_tokens
)
.unwrap
();
self
.tokens
.truncate
(
self
.num_input_tokens
)
.unwrap
();
self
.unique_blocks
=
self
.unique_blocks
=
create_unique_blocks_from_sequence
(
create_unique_blocks_from_sequence
(
&
self
.tokens
,
None
,
self
.block_size
);
&
self
.tokens
,
None
,
self
.block_size
,
self
.enable_prefix_caching
,
);
self
.already_generated_tokens
=
self
.generated_tokens
.max
(
self
.already_generated_tokens
);
self
.generated_tokens
=
0
;
self
.generated_tokens
=
0
;
self
.creation_signal
=
Some
(
MoveBlock
::
Use
(
self
.unique_blocks
.clone
()
,
None
));
self
.creation_signal
=
Some
(
MoveBlock
::
Use
(
self
.unique_blocks
.clone
()));
free_signal
free_signal
}
}
...
@@ -223,7 +257,7 @@ impl ActiveSequence {
...
@@ -223,7 +257,7 @@ impl ActiveSequence {
self
.generated_tokens
=
self
.generated_tokens
.saturating_sub
(
1
);
self
.generated_tokens
=
self
.generated_tokens
.saturating_sub
(
1
);
// Reverts to the last full block
// Reverts to the last full block
if
self
.tokens
.total_tokens
()
%
(
self
.block_size
as
usize
)
==
0
{
if
self
.tokens
.total_tokens
()
%
self
.block_size
==
0
{
self
.unique_blocks
.pop
();
self
.unique_blocks
.pop
();
}
}
}
}
...
@@ -238,14 +272,14 @@ mod tests {
...
@@ -238,14 +272,14 @@ mod tests {
// Create a sequence with block size 16 initialized with tokens [0..15]
// Create a sequence with block size 16 initialized with tokens [0..15]
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
15
)
.collect
();
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
15
)
.collect
();
let
(
mut
seq1
,
signal1
)
=
let
(
mut
seq1
,
signal1
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
100
,
Some
(
16
),
Some
(
256
)
);
ActiveSequence
::
new_with_signal
(
initial_tokens
,
100
,
Some
(
16
),
true
);
assert_eq!
(
seq1
.num_input_tokens
(),
15
);
assert_eq!
(
seq1
.num_input_tokens
(),
15
);
assert_eq!
(
seq1
.len
(),
15
);
assert_eq!
(
seq1
.len
(),
15
);
// Check that we got a Use signal
// Check that we got a Use signal
assert
!
(
signal1
.is_some
());
assert
!
(
signal1
.is_some
());
match
&
signal1
{
match
&
signal1
{
Some
(
MoveBlock
::
Use
(
blocks
,
_
))
=>
{
Some
(
MoveBlock
::
Use
(
blocks
))
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
}
}
_
=>
panic!
(
"Expected Use signal"
),
_
=>
panic!
(
"Expected Use signal"
),
...
@@ -264,33 +298,31 @@ mod tests {
...
@@ -264,33 +298,31 @@ mod tests {
let
signal_16
=
signal_16
.unwrap
();
let
signal_16
=
signal_16
.unwrap
();
assert_eq!
(
signal_16
.len
(),
2
);
assert_eq!
(
signal_16
.len
(),
2
);
// First signal should be Promote for the previous block
match
&
signal_16
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
assert_eq!
(
*
parent_hash
,
None
);
}
_
=>
panic!
(
"Expected Promote signal as second signal"
),
}
// Second signal should be Use for new partial block
// Second signal should be Use for new partial block
match
&
signal_16
[
1
]
{
match
&
signal_16
[
1
]
{
MoveBlock
::
Use
(
blocks
,
_
)
=>
{
MoveBlock
::
Use
(
blocks
)
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
}
_
=>
panic!
(
"Expected Use signal as first signal"
),
_
=>
panic!
(
"Expected Use signal as first signal"
),
}
}
// First signal should be Promote for the previous block
match
&
signal_16
[
0
]
{
MoveBlock
::
Promote
(
uuid
,
_
)
=>
{
// The uuid is generated dynamically, so we just check it exists
let
_
=
uuid
;
}
_
=>
panic!
(
"Expected Promote signal as second signal"
),
}
// Verify state after pushing tokens
// Verify state after pushing tokens
assert_eq!
(
seq1
.unique_blocks
()
.len
(),
2
);
// One full block and one partial block
assert_eq!
(
seq1
.unique_blocks
()
.len
(),
2
);
// One full block and one partial block
assert_eq!
(
seq1
.len
(),
17
);
assert_eq!
(
seq1
.len
(),
17
);
assert_eq!
(
seq1
.len
()
%
(
seq1
.block_size
()
as
usize
)
,
1
);
assert_eq!
(
seq1
.len
()
%
seq1
.block_size
(),
1
);
// Create another sequence with block size 16 initialized with tokens [0..17]
// Create another sequence with block size 16 initialized with tokens [0..17]
let
extended_tokens
:
Vec
<
u32
>
=
(
0
..
16
)
.collect
();
let
extended_tokens
:
Vec
<
u32
>
=
(
0
..
16
)
.collect
();
let
(
mut
seq2
,
_
)
=
let
(
mut
seq2
,
_
)
=
ActiveSequence
::
new_with_signal
(
extended_tokens
,
100
,
Some
(
16
),
true
);
ActiveSequence
::
new_with_signal
(
extended_tokens
,
100
,
Some
(
16
),
Some
(
256
));
seq2
.push
(
16
);
seq2
.push
(
16
);
seq2
.pop
();
seq2
.pop
();
seq2
.push
(
16
);
seq2
.push
(
16
);
...
@@ -335,12 +367,12 @@ mod tests {
...
@@ -335,12 +367,12 @@ mod tests {
"seq2 should have exactly 3 blocks"
"seq2 should have exactly 3 blocks"
);
);
assert_eq!
(
assert_eq!
(
seq1
.len
()
%
(
seq1
.block_size
()
as
usize
)
,
seq1
.len
()
%
seq1
.block_size
(),
1
,
1
,
"seq1 should have 1 partial token"
"seq1 should have 1 partial token"
);
);
assert_eq!
(
assert_eq!
(
seq2
.len
()
%
(
seq2
.block_size
()
as
usize
)
,
seq2
.len
()
%
seq2
.block_size
(),
1
,
1
,
"seq2 should have 1 partial token"
"seq2 should have 1 partial token"
);
);
...
@@ -352,9 +384,38 @@ mod tests {
...
@@ -352,9 +384,38 @@ mod tests {
"First two blocks should be identical"
"First two blocks should be identical"
);
);
// Push tokens 34..47 to seq1
for
token
in
33
..
48
{
seq1
.push
(
token
);
}
// Push token 48 and get the signal - this completes the block and triggers signals
let
signal
=
seq1
.push
(
48
);
let
signal
=
signal
.unwrap
();
// Check that signal[0] is promote
match
&
signal
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
// Check that the parent_hash matches unique_blocks[1], which should be a full block
if
let
UniqueBlock
::
FullBlock
(
expected_hash
)
=
seq1
.unique_blocks
()[
1
]
{
assert_eq!
(
*
parent_hash
,
Some
(
expected_hash
),
"Parent hash should match unique_blocks[1]"
);
}
else
{
panic!
(
"unique_blocks[1] should be a full block"
);
}
}
_
=>
panic!
(
"Expected Promote signal as first signal"
),
}
// Reset seq1 and check that it equals the original clone
// Reset seq1 and check that it equals the original clone
let
free_signals
=
seq1
.reset_with_signal
();
let
free_signals
=
seq1
.reset_with_signal
();
// 49 - 15 generated tokens
assert_eq!
(
seq1
.already_generated_tokens
,
34
);
// Verify the reset signals include proper cleanup events
// Verify the reset signals include proper cleanup events
assert
!
(
!
free_signals
.is_empty
());
assert
!
(
!
free_signals
.is_empty
());
}
}
...
@@ -363,13 +424,12 @@ mod tests {
...
@@ -363,13 +424,12 @@ mod tests {
fn
test_active_sequence_generate_signals
()
{
fn
test_active_sequence_generate_signals
()
{
// Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
// Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
14
)
.collect
();
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
14
)
.collect
();
let
(
mut
seq
,
signal
)
=
let
(
mut
seq
,
signal
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
5
,
Some
(
16
),
true
);
ActiveSequence
::
new_with_signal
(
initial_tokens
,
5
,
Some
(
16
),
Some
(
256
));
// Initial signal - should have received a Use signal for the partial block
// Initial signal - should have received a Use signal for the partial block
assert
!
(
signal
.is_some
());
assert
!
(
signal
.is_some
());
match
signal
{
match
signal
{
Some
(
MoveBlock
::
Use
(
blocks
,
_
))
=>
{
Some
(
MoveBlock
::
Use
(
blocks
))
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
}
...
@@ -385,25 +445,23 @@ mod tests {
...
@@ -385,25 +445,23 @@ mod tests {
let
signals_second
=
seq
.generate
();
let
signals_second
=
seq
.generate
();
assert_eq!
(
signals_second
.len
(),
2
);
assert_eq!
(
signals_second
.len
(),
2
);
// First signal should be Use for new partial block
// First signal should be Promote
match
&
signals_second
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
assert_eq!
(
*
parent_hash
,
None
);
}
_
=>
panic!
(
"Expected Promote signal as first signal after second token"
),
}
// Second signal should be Use for new partial block
match
&
signals_second
[
1
]
{
match
&
signals_second
[
1
]
{
MoveBlock
::
Use
(
blocks
,
_
)
=>
{
MoveBlock
::
Use
(
blocks
)
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
}
_
=>
panic!
(
"Expected Use signal as second signal after second token"
),
_
=>
panic!
(
"Expected Use signal as second signal after second token"
),
}
}
// Second signal should be Promote
match
&
signals_second
[
0
]
{
MoveBlock
::
Promote
(
uuid
,
hash
)
=>
{
// The uuid and hash values are generated dynamically, so we just check the event type
let
_
=
uuid
;
let
_
=
hash
;
}
_
=>
panic!
(
"Expected Promote signal as first signal after second token"
),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
// Generate fourth token - should not trigger new signals as it's adding to partial block
let
signals_third
=
seq
.generate
();
let
signals_third
=
seq
.generate
();
assert_eq!
(
signals_third
.len
(),
0
);
assert_eq!
(
signals_third
.len
(),
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment