Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f0652d89
Unverified
Commit
f0652d89
authored
Jul 01, 2025
by
Yan Ru Pei
Committed by
GitHub
Jul 01, 2025
Browse files
feat: vllm mocker enhancement (#1236)
parent
0d6cae85
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1654 additions
and
410 deletions
+1654
-410
lib/llm/src/mocker.rs
lib/llm/src/mocker.rs
+1
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+764
-0
lib/llm/src/mocker/evictor.rs
lib/llm/src/mocker/evictor.rs
+96
-105
lib/llm/src/mocker/kv_manager.rs
lib/llm/src/mocker/kv_manager.rs
+166
-70
lib/llm/src/mocker/protocols.rs
lib/llm/src/mocker/protocols.rs
+86
-4
lib/llm/src/mocker/scheduler.rs
lib/llm/src/mocker/scheduler.rs
+424
-172
lib/llm/src/mocker/sequence.rs
lib/llm/src/mocker/sequence.rs
+117
-59
No files found.
lib/llm/src/mocker.rs
View file @
f0652d89
...
...
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub
mod
engine
;
pub
mod
evictor
;
pub
mod
kv_manager
;
pub
mod
protocols
;
...
...
lib/llm/src/mocker/engine.rs
0 → 100644
View file @
f0652d89
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler
//!
//! This module provides an AsyncEngine implementation that wraps the Scheduler
//! to provide streaming token generation with realistic timing simulation.
use
crate
::
kv_router
::
publisher
::
WorkerMetricsPublisher
;
use
crate
::
mocker
::
protocols
::
DirectRequest
;
use
crate
::
mocker
::
protocols
::{
MockEngineArgs
,
OutputSignal
};
use
crate
::
mocker
::
scheduler
::
Scheduler
;
use
crate
::
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
};
use
crate
::
protocols
::
TokenIdType
;
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
dynamo_runtime
::
DistributedRuntime
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_runtime
::{
component
::
Component
,
engine
::
AsyncEngineContextProvider
,
pipeline
::{
async_trait
,
AsyncEngine
,
Error
,
ManyOut
,
ResponseStream
,
SingleIn
},
traits
::
DistributedRuntimeProvider
,
Result
,
};
use
crate
::
kv_router
::
protocols
::{
KvCacheEvent
,
KvCacheEventData
};
use
crate
::
kv_router
::
publisher
::
KvEventPublisher
;
use
futures
::
StreamExt
;
use
rand
::
Rng
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
tokio
::
sync
::{
mpsc
,
Mutex
,
OnceCell
};
use
tokio
::
time
::{
interval
,
Duration
};
use
tokio_stream
::
wrappers
::
ReceiverStream
;
use
uuid
::
Uuid
;
pub
const
MOCKER_COMPONENT
:
&
str
=
"mocker"
;
/// Pick a random mock token ID in the half-open range `1000..5000`.
///
/// The value only stands in for a generated token; it is not tied to any vocabulary.
fn generate_random_token() -> TokenIdType {
    rand::rng().random_range(1000..5000)
}
/// AsyncEngine wrapper around the Scheduler that streams randomly generated token IDs.
///
/// Cloning is cheap for the shared state: the request map and the per-rank sender
/// list live behind `Arc`s, so clones observe the same schedulers and requests.
#[derive(Clone)]
pub struct MockVllmEngine {
    /// In-flight requests keyed by request UUID; each value is the sender used to
    /// hand that request's `OutputSignal`s to its streaming task.
    active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
    /// One request sender per scheduler, indexed by data-parallel rank.
    /// Unset until `start_schedulers` has run; set exactly once.
    request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
    /// Configuration the engine was created with (e.g. `dp_size`, `block_size`,
    /// `enable_prefix_caching`).
    engine_args: MockEngineArgs,
}
impl
MockVllmEngine
{
/// Build an engine from `args` without starting anything.
///
/// The request map starts empty and the per-rank scheduler senders stay unset
/// until `start` is called.
pub fn new(args: MockEngineArgs) -> Self {
    let active_requests = Arc::new(Mutex::new(HashMap::new()));
    let request_senders = Arc::new(OnceCell::new());
    Self {
        active_requests,
        request_senders,
        engine_args: args,
    }
}
/// Start the schedulers and the background publishers for `component`.
///
/// Steps, in order:
/// 1. derive a child cancellation token from the component's runtime,
/// 2. create the schedulers and their signal-forwarding tasks,
/// 3. start metrics publishing,
/// 4. start KV-event publishing, but only when `enable_prefix_caching` is set.
///
/// # Errors
///
/// Propagates any error from the metrics or KV-event publisher setup.
pub async fn start(&self, component: Component) -> Result<()> {
    let shutdown = component.drt().runtime().child_token();

    let (schedulers, kv_event_receivers) = self.start_schedulers(
        self.engine_args.clone(),
        self.active_requests.clone(),
        shutdown.clone(),
    );

    Self::start_metrics_publishing(&schedulers, Some(component.clone()), shutdown.clone())
        .await?;

    // Without prefix caching nobody consumes KV events, so the receivers are simply dropped.
    if !self.engine_args.enable_prefix_caching {
        return Ok(());
    }

    Self::start_kv_events_publishing(
        kv_event_receivers,
        Some(component.clone()),
        self.engine_args.block_size,
        shutdown.clone(),
    )
    .await?;
    Ok(())
}
/// Hand `request` straight to the scheduler that serves `dp_rank`.
///
/// A failed send (the scheduler's receiving side is gone) is ignored on purpose.
///
/// # Panics
///
/// Panics with "Not initialized" when the schedulers have not been started yet, and
/// panics on an out-of-range `dp_rank` because the sender list is indexed directly.
pub fn direct(&self, request: DirectRequest, dp_rank: usize) {
    let scheduler_tx = &self.request_senders.get().expect("Not initialized")[dp_rank];
    let _ = scheduler_tx.send(request);
}
/// Create one scheduler per data-parallel rank plus a task that routes its output signals.
///
/// For every rank in `0..args.dp_size` this opens an output-signal channel and a
/// KV-event channel, gives the sending halves to a new `Scheduler`, and spawns a task
/// that forwards each `OutputSignal` to the sender registered in `active_requests`
/// under the signal's UUID. Signals whose UUID is not registered are dropped. A
/// forwarding task stops when its scheduler's output channel closes or when
/// `cancel_token` is cancelled.
///
/// Returns the schedulers and, in the same rank order, the receiving halves of their
/// KV-event channels.
///
/// # Panics
///
/// Panics with "Already initialized" if the per-rank request senders were already set,
/// i.e. on a second call for the same engine.
fn start_schedulers(
    &self,
    args: MockEngineArgs,
    active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
    cancel_token: CancellationToken,
) -> (
    Vec<Scheduler>,
    Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
) {
    let mut request_txs = Vec::with_capacity(args.dp_size as usize);
    let mut schedulers = Vec::<Scheduler>::new();
    let mut kv_rxs = Vec::new();

    for dp_rank in 0..args.dp_size {
        // Channel the scheduler uses to announce generated tokens / completion.
        let (signal_tx, mut signal_rx) = mpsc::unbounded_channel::<OutputSignal>();
        // Channel the scheduler uses to report KV cache events.
        let (kv_tx, kv_rx) = mpsc::unbounded_channel::<KvCacheEventData>();

        let scheduler = Scheduler::new(
            args.clone(),
            Some(dp_rank),
            Some(signal_tx),
            Some(kv_tx),
            Some(cancel_token.clone()),
        );
        request_txs.push(scheduler.request_sender());
        schedulers.push(scheduler);
        kv_rxs.push(kv_rx);

        let routes = active_requests.clone();
        let shutdown = cancel_token.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    received = signal_rx.recv() => {
                        match received {
                            Some(signal) => {
                                // Route the signal to the request it belongs to, if still active.
                                let table = routes.lock().await;
                                if let Some(request_tx) = table.get(&signal.uuid) {
                                    let _ = request_tx.send(signal);
                                }
                            }
                            // Scheduler dropped its sender: nothing more to forward.
                            None => break,
                        }
                    }
                    _ = shutdown.cancelled() => {
                        break;
                    }
                }
            }
        });
    }

    // The sender list may be published only once per engine.
    self.request_senders
        .set(request_txs)
        .expect("Already initialized");

    (schedulers, kv_rxs)
}
/// Start background tasks to poll and publish metrics every second
async
fn
start_metrics_publishing
(
schedulers
:
&
[
Scheduler
],
component
:
Option
<
Component
>
,
cancel_token
:
CancellationToken
,
)
->
Result
<
()
>
{
tracing
::
info!
(
"Creating metrics publisher"
);
let
metrics_publisher
=
Arc
::
new
(
WorkerMetricsPublisher
::
new
()
?
);
tracing
::
info!
(
"Metrics publisher created"
);
if
let
Some
(
comp
)
=
component
{
tracing
::
info!
(
"Creating metrics endpoint"
);
tokio
::
spawn
({
let
publisher
=
metrics_publisher
.clone
();
async
move
{
if
let
Err
(
e
)
=
publisher
.create_endpoint
(
comp
.clone
())
.await
{
tracing
::
error!
(
"Metrics endpoint failed: {e}"
);
}
}
});
// Give it a moment to start
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
tracing
::
info!
(
"Metrics endpoint started (background)"
);
}
tracing
::
info!
(
"Starting metrics background tasks"
);
for
(
dp_rank
,
scheduler
)
in
schedulers
.iter
()
.enumerate
()
{
let
scheduler
=
scheduler
.clone
();
let
publisher
=
metrics_publisher
.clone
();
let
dp_rank
=
dp_rank
as
u32
;
let
cancel_token
=
cancel_token
.clone
();
tokio
::
spawn
(
async
move
{
let
mut
interval
=
interval
(
Duration
::
from_millis
(
100
));
loop
{
tokio
::
select!
{
_
=
interval
.tick
()
=>
{
// Get metrics from scheduler
let
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
// Publish metrics
if
let
Err
(
e
)
=
publisher
.publish
(
Arc
::
new
(
metrics
))
{
tracing
::
warn!
(
"Failed to publish metrics for DP rank {dp_rank}: {e}"
);
}
else
{
tracing
::
trace!
(
"Published metrics for DP rank {}"
,
dp_rank
);
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
info!
(
"Metrics publishing cancelled for DP rank {dp_rank}"
);
break
;
}
}
}
});
}
tracing
::
info!
(
"Metrics background tasks started"
);
Ok
(())
}
/// Start background tasks to collect and publish KV events from schedulers
async
fn
start_kv_events_publishing
(
kv_event_receivers
:
Vec
<
mpsc
::
UnboundedReceiver
<
KvCacheEventData
>>
,
component
:
Option
<
Component
>
,
block_size
:
usize
,
cancel_token
:
CancellationToken
,
)
->
Result
<
()
>
{
tracing
::
info!
(
"Starting KV events publishing"
);
// Only start KV events publishing if we have a component
let
Some
(
comp
)
=
component
else
{
tracing
::
warn!
(
"No component provided, skipping KV events publishing"
);
return
Ok
(());
};
tracing
::
info!
(
"Component found for KV events publishing"
);
tracing
::
debug!
(
"Getting worker_id"
);
let
worker_id
=
comp
.drt
()
.primary_lease
()
.expect
(
"Cannot publish KV events without lease"
)
// ← This will PANIC on static!
.id
();
// let worker_id = 0;
tracing
::
debug!
(
"Worker_id set to: {worker_id}"
);
tracing
::
info!
(
"Creating KV event publisher"
);
let
kv_event_publisher
=
Arc
::
new
(
KvEventPublisher
::
new
(
comp
.clone
(),
worker_id
,
block_size
as
u32
,
None
,
)
?
);
tracing
::
info!
(
"KV event publisher created"
);
tracing
::
info!
(
"Starting KV event background tasks for {} receivers"
,
kv_event_receivers
.len
()
);
for
(
dp_rank
,
mut
kv_events_rx
)
in
kv_event_receivers
.into_iter
()
.enumerate
()
{
tracing
::
debug!
(
"Starting background task for DP rank {dp_rank}"
);
let
publisher
=
kv_event_publisher
.clone
();
let
dp_rank
=
dp_rank
as
u32
;
let
cancel_token
=
cancel_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
debug!
(
"Background task started for DP rank {dp_rank}"
);
loop
{
tokio
::
select!
{
// Receive actual KV events from the scheduler
Some
(
event_data
)
=
kv_events_rx
.recv
()
=>
{
// Convert KvCacheEventData to KvCacheEvent with random UUID as event_id
let
event
=
KvCacheEvent
{
event_id
:
Uuid
::
new_v4
()
.as_u128
()
as
u64
,
data
:
event_data
,
};
// Publish the event
if
let
Err
(
e
)
=
publisher
.publish
(
event
)
{
tracing
::
warn!
(
"Failed to publish KV event for DP rank {dp_rank}: {e}"
);
}
else
{
tracing
::
trace!
(
"Published KV event for DP rank {dp_rank}"
);
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
info!
(
"KV events publishing cancelled for DP rank {dp_rank}"
);
break
;
}
}
}
});
}
tracing
::
info!
(
"All KV event background tasks started"
);
Ok
(())
}
}
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
LLMEngineOutput
>
,
Error
>
for
MockVllmEngine
{
async
fn
generate
(
&
self
,
input
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
ManyOut
<
LLMEngineOutput
>
,
Error
>
{
let
(
request
,
ctx
)
=
input
.into_parts
();
// Extract dp_rank from annotations if present
let
dp_rank
=
request
.annotations
.iter
()
.find_map
(|
ann
|
{
if
ann
.starts_with
(
"dp_rank:"
)
{
ann
.strip_prefix
(
"dp_rank:"
)
.and_then
(|
s
|
s
.parse
()
.ok
())
}
else
{
None
}
})
.unwrap_or
(
0
);
// Validate dp_rank
if
dp_rank
>=
self
.engine_args.dp_size
{
return
Err
(
Error
::
msg
(
format!
(
"dp_rank {} is out of bounds for dp_size {}"
,
dp_rank
,
self
.engine_args.dp_size
)));
}
let
request_uuid
=
ctx
.id
()
.parse
()
.unwrap_or
(
Uuid
::
new_v4
());
// Convert PreprocessedRequest to DirectRequest for scheduler
let
direct_request
=
DirectRequest
{
tokens
:
request
.token_ids
.clone
(),
max_output_tokens
:
request
.stop_conditions
.max_tokens
.expect
(
"max_output_tokens must be specified for mocker"
)
as
usize
,
uuid
:
Some
(
request_uuid
),
dp_rank
:
Some
(
dp_rank
),
};
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
{
let
mut
active
=
self
.active_requests
.lock
()
.await
;
active
.insert
(
request_uuid
,
request_tx
);
}
// Send the request to the appropriate scheduler based on dp_rank
self
.direct
(
direct_request
,
dp_rank
as
usize
);
// Create a simple channel for the stream
let
(
stream_tx
,
stream_rx
)
=
mpsc
::
channel
::
<
LLMEngineOutput
>
(
64
);
let
active_requests
=
self
.active_requests
.clone
();
let
async_context
=
ctx
.context
();
let
max_tokens
=
request
.stop_conditions.max_tokens
.unwrap_or
(
100
)
as
usize
;
// Spawn a task to handle the complex async logic
tokio
::
spawn
(
async
move
{
let
mut
token_count
=
0
;
loop
{
tokio
::
select!
{
maybe_signal
=
request_rx
.recv
()
=>
{
let
Some
(
signal
)
=
maybe_signal
else
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
error
(
"All output transmitters closed"
.to_string
()))
.await
;
break
;
};
if
signal
.completed
&&
token_count
<
max_tokens
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
error
(
"Completion signal received before max tokens reached"
.to_string
()))
.await
;
break
;
}
if
signal
.completed
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
length
())
.await
;
break
;
}
// Generate a new token
let
token_id
=
generate_random_token
();
token_count
+=
1
;
let
output
=
LLMEngineOutput
{
token_ids
:
vec!
[
token_id
],
tokens
:
None
,
// Let backend handle detokenization
text
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
None
,
index
:
None
,
};
if
stream_tx
.send
(
output
)
.await
.is_err
()
{
break
;
}
}
_
=
async_context
.stopped
()
=>
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
cancelled
())
.await
;
break
;
}
}
}
// Clean up: remove this request from active requests
let
mut
active
=
active_requests
.lock
()
.await
;
active
.remove
(
&
request_uuid
);
});
// Create a simple ReceiverStream which is naturally Send + Sync
let
stream
=
ReceiverStream
::
new
(
stream_rx
);
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
.context
()))
}
}
/// Adapter around a shared `MockVllmEngine` whose `generate` yields
/// `Annotated<LLMEngineOutput>` items instead of bare `LLMEngineOutput`s.
pub struct AnnotatedMockEngine {
    /// The wrapped engine; also held by the background task that starts it.
    inner: Arc<MockVllmEngine>,
}
impl
AnnotatedMockEngine
{
pub
fn
new
(
inner
:
MockVllmEngine
,
distributed_runtime
:
DistributedRuntime
,
endpoint
:
dynamo_runtime
::
protocols
::
Endpoint
,
)
->
Self
{
let
inner
=
Arc
::
new
(
inner
);
let
inner_clone
=
inner
.clone
();
// Start background task to wait for component service and start the engine
tokio
::
spawn
(
async
move
{
loop
{
// Try to create component
let
Ok
(
namespace
)
=
distributed_runtime
.namespace
(
&
endpoint
.namespace
)
else
{
tracing
::
debug!
(
"Namespace not available yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
};
let
Ok
(
component
)
=
namespace
.component
(
&
endpoint
.component
)
else
{
tracing
::
debug!
(
"Component not available yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
};
// Check if service is available by trying to list instances
let
Ok
(
instances
)
=
component
.list_instances
()
.await
else
{
tracing
::
debug!
(
"Cannot list instances yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
};
if
instances
.is_empty
()
{
tracing
::
debug!
(
"No instances available yet, retrying..."
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
continue
;
}
tracing
::
info!
(
"Component service is now available, starting mocker engine"
);
// Start the engine with the component
if
let
Err
(
e
)
=
inner_clone
.start
(
component
)
.await
{
tracing
::
error!
(
"Failed to start mocker engine: {e}"
);
}
break
;
}
});
Self
{
inner
}
}
}
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
Error
>
for
AnnotatedMockEngine
{
async
fn
generate
(
&
self
,
input
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
Error
>
{
let
stream
=
self
.inner
.generate
(
input
)
.await
?
;
let
context
=
stream
.context
();
// Convert stream of LLMEngineOutput to Annotated<LLMEngineOutput>
let
annotated_stream
=
stream
.map
(
Annotated
::
from_data
);
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
annotated_stream
),
context
))
}
}
/// Create a mocker engine as ExecutionContext
pub
async
fn
make_mocker_engine
(
distributed_runtime
:
DistributedRuntime
,
endpoint
:
dynamo_runtime
::
protocols
::
Endpoint
,
args
:
MockEngineArgs
,
)
->
Result
<
crate
::
backend
::
ExecutionContext
,
Error
>
{
// Create the mocker engine
tracing
::
info!
(
"Creating mocker engine (service will be started in background)"
);
let
annotated_engine
=
AnnotatedMockEngine
::
new
(
MockVllmEngine
::
new
(
args
),
distributed_runtime
,
endpoint
);
Ok
(
Arc
::
new
(
annotated_engine
))
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::kv_router::indexer::RouterEvent;
    use crate::kv_router::KV_EVENT_SUBJECT;
    use crate::protocols::common::{SamplingOptions, StopConditions};
    use dynamo_runtime::{
        pipeline::Context,
        pipeline::{network::Ingress, PushRouter},
        traits::events::EventSubscriber,
        DistributedRuntime, Worker,
    };
    use futures::StreamExt;
    use tokio::time::timeout;

    /// End-to-end check: serve a `MockVllmEngine` behind an endpoint, push requests
    /// through a `PushRouter`, then verify token output, KV events and metrics endpoints.
    ///
    /// Ignored by default because it needs a live distributed runtime.
    #[tokio::test]
    #[ignore] // Run with: cargo test -- --ignored
    async fn test_mock_vllm_engine_full_integration() -> Result<()> {
        const DP_SIZE: u32 = 2;
        const TOKENS_PER_REQUEST: usize = 20;
        const BLOCK_SIZE: usize = 2;

        let settle_time = tokio::time::Duration::from_millis(500);

        // Create runtime and distributed runtime
        let worker = Worker::from_settings()?;
        let runtime = worker.runtime();
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
        tracing::info!("✓ Runtime and distributed runtime created");

        // Create component for MockVllmEngine (needed for publishers)
        let test_component = distributed
            .namespace("test")?
            .component(MOCKER_COMPONENT)?
            .service_builder()
            .create()
            .await?;
        tracing::info!("✓ Test component created");

        // Create MockVllmEngine WITH component (enables publishers)
        let args = MockEngineArgs::builder()
            .speedup_ratio(10.0)
            .dp_size(DP_SIZE)
            .block_size(BLOCK_SIZE)
            .build()
            .unwrap();

        let engine = MockVllmEngine::new(args);
        engine.start(test_component.clone()).await?;
        tokio::time::sleep(settle_time).await;
        let engine = Arc::new(engine);
        tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}");

        // Set up KV events subscriber
        let mut kv_events_subscriber = test_component.subscribe(KV_EVENT_SUBJECT).await?;
        tracing::info!("✓ KV events subscriber created");

        // Wrap with Ingress and register with component/endpoint
        let ingress = Ingress::for_engine(engine)?;
        tracing::info!("✓ Ingress wrapper created");

        // Start the server in background
        let serving_component = test_component.clone();
        let server_handle = tokio::spawn(async move {
            let served = serving_component
                .endpoint("generate")
                .endpoint_builder()
                .handler(ingress)
                .start()
                .await;
            if let Err(e) = served {
                eprintln!("❌ Generate endpoint failed: {e}");
            }
        });
        tracing::info!("✓ Server started in background");

        // Give server time to start
        tokio::time::sleep(settle_time).await;
        tracing::info!("✓ Server startup delay completed");

        // Print all registered instances from etcd
        match test_component.list_instances().await {
            Err(e) => {
                tracing::error!("❌ Failed to list instances: {e}");
            }
            Ok(instances) => {
                tracing::info!("📋 Found {} registered instances:", instances.len());
                for instance in instances {
                    tracing::info!(
                        " • {}/{}/{} (ID: {})",
                        instance.namespace,
                        instance.component,
                        instance.endpoint,
                        instance.instance_id
                    );
                }
            }
        }

        // Create client
        let client = distributed
            .namespace("test")?
            .component(MOCKER_COMPONENT)?
            .endpoint("generate")
            .client()
            .await?;
        tracing::info!("✓ Client created");

        let router = PushRouter::from_client(client, Default::default()).await?;
        tracing::info!("✓ Router created");

        // Create test requests for both DP workers
        let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| PreprocessedRequest {
            token_ids: tokens,
            batch_token_ids: None,
            stop_conditions: StopConditions {
                max_tokens: Some(TOKENS_PER_REQUEST as u32),
                ..Default::default()
            },
            sampling_options: SamplingOptions::default(),
            eos_token_ids: vec![],
            mdc_sum: None,
            annotations: vec![format!("dp_rank:{dp_rank}")],
            estimated_prefix_hit_num_blocks: None,
        };

        // Two identical prompts per DP rank.
        let requests: Vec<PreprocessedRequest> = [0, 0, 1, 1]
            .into_iter()
            .map(|rank| create_request(vec![1, 2, 3, 4, 5], rank))
            .collect();
        tracing::info!(
            "✓ Test requests created ({} requests total)",
            requests.len()
        );

        // Test each request
        for (i, request) in requests.into_iter().enumerate() {
            let request_no = i + 1;
            tracing::info!("Testing request {}", request_no);

            let response_stream = router.generate(Context::new(request)).await?;
            let responses: Vec<LLMEngineOutput> = response_stream.collect().await;

            // Should have at least one response
            assert!(
                !responses.is_empty(),
                "Request {} should produce at least one response",
                request_no
            );

            // Count total tokens generated (the final message carries no tokens)
            let total_tokens: usize = responses.iter().map(|r| r.token_ids.len()).sum();
            let has_finish_reason = responses.iter().any(|r| r.finish_reason.is_some());

            // Some response must carry a finish reason
            assert!(
                has_finish_reason,
                "Request {} should have a finish reason",
                request_no
            );

            // Verify we got approximately the expected number of tokens
            assert!(
                total_tokens <= TOKENS_PER_REQUEST + 1, // +1 for potential final empty response
                "Request {} generated {} tokens, expected at most {}",
                request_no,
                total_tokens,
                TOKENS_PER_REQUEST + 1
            );

            tracing::info!(
                "✓ Request {} completed successfully with {} tokens",
                request_no,
                total_tokens
            );
        }
        tracing::info!("🎉 All requests completed successfully!");

        // Try to receive at least one KV event with 100ms timeout
        tracing::info!("Waiting for KV event with 100ms timeout...");
        let msg = timeout(Duration::from_millis(100), kv_events_subscriber.next())
            .await
            .map_err(|_| Error::msg("Timeout waiting for KV event"))?
            .ok_or_else(|| Error::msg("KV events stream ended unexpectedly"))?;

        let event = serde_json::from_slice::<RouterEvent>(&msg.payload)
            .map_err(|e| Error::msg(format!("Failed to deserialize KV event: {e}")))?;
        tracing::info!("✓ Received KV event: {event:?}");

        // Use KvMetricsAggregator to get metrics more easily
        let cancel_token = test_component.drt().runtime().child_token();
        let metrics_aggregator = crate::kv_router::metrics_aggregator::KvMetricsAggregator::new(
            test_component.clone(),
            cancel_token,
        )
        .await;
        tokio::time::sleep(settle_time).await;

        let processed_endpoints = metrics_aggregator.get_endpoints();
        tracing::info!(
            "Found {} metrics endpoints",
            processed_endpoints.endpoints.len()
        );

        // Verify we found at least one metrics endpoint
        assert!(
            !processed_endpoints.endpoints.is_empty(),
            "Should find at least one metrics endpoint"
        );
        tracing::info!(
            "✓ Successfully found {} metrics endpoints",
            processed_endpoints.endpoints.len()
        );

        // Log the data reported by each metrics endpoint
        for (worker_id, endpoint) in &processed_endpoints.endpoints {
            tracing::info!("✓ Worker {} metrics: {:?}", worker_id, endpoint.data);
        }
        tracing::info!("🎉 Event verification completed!");

        // Cleanup
        distributed.shutdown();
        server_handle.await?;

        Ok(())
    }
}
lib/llm/src/mocker/evictor.rs
View file @
f0652d89
...
...
@@ -13,167 +13,158 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
cmp
::
Eq
;
use
std
::
collections
::{
HashMap
,
VecDeque
};
use
std
::
cmp
::
{
Eq
,
Ordering
}
;
use
std
::
collections
::{
BTreeSet
,
HashMap
};
use
std
::
hash
::
Hash
;
use
std
::
time
::
Instant
;
/// An `(item, counter)` pair that is ordered, and compared for equality, by `counter` alone.
///
/// `Ord` requires `a.cmp(b) == Ordering::Equal` exactly when `a == b`. Equality is
/// therefore implemented by hand on `counter` instead of being derived over both
/// fields, so it can never disagree with the ordering used by sorted collections.
#[derive(Debug, Clone)]
struct PriorityItem<T> {
    item: T,
    counter: i64,
}

impl<T> PartialEq for PriorityItem<T> {
    /// Two entries are equal when their counters are equal; `item` is not consulted.
    fn eq(&self, other: &Self) -> bool {
        self.counter == other.counter
    }
}

impl<T> Eq for PriorityItem<T> {}

impl<T: Eq> Ord for PriorityItem<T> {
    /// Order by `counter` only: a lower counter sorts first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.counter.cmp(&other.counter)
    }
}

impl<T: Eq> PartialOrd for PriorityItem<T> {
    /// Always `Some`, delegating to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. The user must ensure objects are added in correct priority (temporal order)
/// 4. Remove and update operations are lazy - entries remain in the queue until
/// they are either evicted or cleaned up during maintenance
/// priority counter. Lower counter values are evicted first.
#[derive(Debug)]
pub
struct
LRUEvictor
<
T
:
Clone
+
Eq
+
Hash
>
{
free_table
:
HashMap
<
T
,
f
64
>
,
priority_queue
:
VecDeque
<
(
T
,
f64
)
>
,
cleanup_threshold
:
usize
,
start_time
:
Instant
,
free_table
:
HashMap
<
T
,
i
64
>
,
priority_queue
:
BTreeSet
<
PriorityItem
<
T
>
>
,
positive_counter
:
i64
,
negative_counter
:
i64
,
}
impl
<
T
:
Clone
+
Eq
+
Hash
>
Default
for
LRUEvictor
<
T
>
{
fn
default
()
->
Self
{
Self
{
free_table
:
HashMap
::
new
(),
priority_queue
:
VecDeque
::
new
(),
cleanup_threshold
:
5
0
,
start_time
:
Instant
::
now
()
,
priority_queue
:
BTreeSet
::
new
(),
positive_counter
:
0
,
negative_counter
:
0
,
}
}
}
impl
<
T
:
Clone
+
Eq
+
Hash
>
LRUEvictor
<
T
>
{
/// Create a new LRUEvictor with the default cleanup threshold
pub
fn
new
(
cleanup_threshold
:
usize
)
->
Self
{
Self
{
cleanup_threshold
,
..
Default
::
default
()
}
pub
fn
new
(
_
cleanup_threshold
:
usize
)
->
Self
{
Self
::
default
()
}
/// Get the current timestamp as seconds since initialization
pub
fn
current_timestamp
(
&
self
)
->
f64
{
self
.start_time
.elapsed
()
.as_secs_f64
()
pub
fn
keys
(
&
self
)
->
std
::
collections
::
hash_map
::
Keys
<
'_
,
T
,
i64
>
{
self
.free_table
.keys
()
}
/// Get an iterator over the keys in the evictor
pub
fn
keys
(
&
self
)
->
std
::
collections
::
hash_map
::
Keys
<
'_
,
T
,
f64
>
{
self
.free_table
.keys
()
fn
update
(
&
mut
self
,
object
:
T
,
counter
:
i64
)
{
self
.free_table
.insert
(
object
.clone
(),
counter
);
self
.priority_queue
.insert
(
PriorityItem
{
item
:
object
,
counter
,
});
}
/// Insert or update an object in the evictor with current timestamp
pub
fn
insert
(
&
mut
self
,
object
:
T
)
{
let
timestamp
=
self
.current_timestamp
();
self
._insert
(
object
,
timestamp
);
// Remove old entry if it exists
if
let
Some
(
&
old_counter
)
=
self
.free_table
.get
(
&
object
)
{
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
:
old_counter
,
});
}
/// Check if the evictor contains the given object
pub
fn
contains
(
&
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.contains_key
(
object
)
// Increment positive counter and insert
self
.positive_counter
+=
1
;
let
counter
=
self
.positive_counter
;
self
.update
(
object
,
counter
);
}
/// Evict an object based on LRU policy
/// Returns the evicted object or None if no objects are available
pub
fn
evict
(
&
mut
self
)
->
Option
<
T
>
{
if
self
.free_table
.is_empty
()
{
return
None
;
/// Push an object to the front with negative counter (highest priority for eviction)
pub
fn
push_front
(
&
mut
self
,
object
:
T
)
{
// Remove old entry if it exists
if
let
Some
(
&
old_counter
)
=
self
.free_table
.get
(
&
object
)
{
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
:
old_counter
,
});
}
while
let
Some
((
object
,
last_accessed
))
=
self
.priority_queue
.pop_front
()
{
let
Some
(
&
current_last_accessed
)
=
self
.free_table
.get
(
&
object
)
else
{
continue
;
// entry is already removed
};
// Decrement negative counter and insert
self
.negative_counter
-=
1
;
let
counter
=
self
.negative_counter
;
if
current_last_accessed
==
last_accessed
{
self
.free_table
.remove
(
&
object
);
return
Some
(
object
);
}
// otherwise entry is stale
self
.update
(
object
,
counter
);
}
None
pub
fn
contains
(
&
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.contains_key
(
object
)
}
/// Insert or update an object in the evictor
fn
_
insert
(
&
mut
self
,
object
:
T
,
last_accessed
:
f64
)
{
self
.free_table
.insert
(
object
.clone
(),
last_accessed
);
self
.priority_queue
.push_back
((
object
,
last_accessed
));
self
.cleanup_if_necessary
();
/// Evict an object based on LRU policy (lowest counter value)
/// Returns the evicted object or None if no objects are available
pub
fn
evict
(
&
mut
self
)
->
Option
<
T
>
{
self
.priority_queue
.pop_first
()
.map
(|
item
|
{
self
.free_table
.remove
(
&
item
.item
);
item
.item
})
}
/// Remove an object from the evictor
/// We don't remove from the priority queue immediately, as that would be inefficient
/// Outdated entries will be filtered out during eviction or cleanup
pub
fn
remove
(
&
mut
self
,
object
:
&
T
)
->
bool
{
self
.free_table
.remove
(
object
)
.is_some
()
let
Some
(
&
counter
)
=
self
.free_table
.get
(
object
)
else
{
return
false
;
};
self
.free_table
.remove
(
object
);
self
.priority_queue
.remove
(
&
PriorityItem
{
item
:
object
.clone
(),
counter
,
});
true
}
/// Get the number of objects in the evictor
pub
fn
len
(
&
self
)
->
usize
{
self
.free_table
.len
()
}
/// Check if the evictor is empty
pub
fn
is_empty
(
&
self
)
->
bool
{
self
.free_table
.is_empty
()
}
/// Check if cleanup is necessary and perform it if needed
fn
cleanup_if_necessary
(
&
mut
self
)
{
if
self
.priority_queue
.len
()
>
self
.cleanup_threshold
*
self
.free_table
.len
()
{
self
.cleanup
();
}
}
/// Clean up the priority queue by removing outdated entries
fn
cleanup
(
&
mut
self
)
{
let
mut
new_priority_queue
=
VecDeque
::
new
();
for
(
object
,
timestamp
)
in
self
.priority_queue
.drain
(
..
)
{
let
Some
(
&
current_timestamp
)
=
self
.free_table
.get
(
&
object
)
else
{
continue
;
};
if
current_timestamp
==
timestamp
{
new_priority_queue
.push_back
((
object
,
timestamp
));
}
}
self
.priority_queue
=
new_priority_queue
;
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
rstest
::
rstest
;
#[rstest]
#[case(
1
)]
#[case(
2
)]
#[case(
3
)]
fn
test_lru_evictor_eviction_order
(
#[case]
threshold
:
usize
)
{
// Create a new LRUEvictor with the given cleanup threshold
let
mut
evictor
=
LRUEvictor
::
<
i32
>
::
new
(
threshold
);
#[test]
fn
test_lru_evictor_eviction_order
()
{
// Create a new LRUEvictor
let
mut
evictor
=
LRUEvictor
::
<
i32
>
::
new
(
1
);
// threshold value doesn't matter anymore
// Add items in the specified order
with small delays between each
// Add items in the specified order
evictor
.insert
(
4
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
3
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
1
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
5
);
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
1
);
// Updates timestamp for 1
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
4
);
// Updates timestamp for 4
std
::
thread
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
1
));
evictor
.insert
(
2
);
// Updates timestamp for 2
evictor
.insert
(
1
);
// Updates counter for 1
evictor
.insert
(
4
);
// Updates counter for 4
evictor
.insert
(
2
);
// Updates counter for 2
evictor
.push_front
(
4
);
// Verify the eviction order
println!
(
"Testing with threshold {}"
,
threshold
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
4
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
3
);
let
evicted
=
evictor
.evict
()
.unwrap
();
...
...
@@ -181,11 +172,11 @@ mod tests {
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
1
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
4
);
let
evicted
=
evictor
.evict
()
.unwrap
();
assert_eq!
(
evicted
,
2
);
let
evicted
=
evictor
.evict
();
assert_eq!
(
evicted
,
None
);
assert_eq!
(
evictor
.len
(),
0
);
}
// ... existing test_push_front test ...
}
lib/llm/src/mocker/kv_manager.rs
View file @
f0652d89
...
...
@@ -46,10 +46,11 @@
//! implementation of the main block manager.
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
MoveBlockResponse
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
derive_getters
::
Getters
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
tokio
::
sync
::
mpsc
;
#[derive(Getters)]
pub
struct
KvManager
{
...
...
@@ -57,17 +58,27 @@ pub struct KvManager {
max_capacity
:
usize
,
#[getter(copy)]
block_size
:
u
32
,
block_size
:
u
size
,
active_blocks
:
HashMap
<
UniqueBlock
,
usize
>
,
inactive_blocks
:
LRUEvictor
<
UniqueBlock
>
,
all_blocks
:
HashSet
<
UniqueBlock
>
,
move_block_response_tx
:
Option
<
mpsc
::
UnboundedSender
<
MoveBlockResponse
>>
,
}
impl
KvManager
{
pub
fn
new
(
max_capacity
:
usize
,
block_size
:
u32
)
->
Self
{
pub
fn
new
(
max_capacity
:
usize
,
block_size
:
usize
)
->
Self
{
Self
::
new_with_sender
(
max_capacity
,
block_size
,
None
)
}
pub
fn
new_with_sender
(
max_capacity
:
usize
,
block_size
:
usize
,
move_block_response_tx
:
Option
<
mpsc
::
UnboundedSender
<
MoveBlockResponse
>>
,
)
->
Self
{
let
active_blocks
=
HashMap
::
new
();
let
inactive_blocks
=
LRUEvictor
::
default
();
let
all_blocks
=
HashSet
::
new
();
...
...
@@ -78,18 +89,46 @@ impl KvManager {
active_blocks
,
inactive_blocks
,
all_blocks
,
move_block_response_tx
,
}
}
/// Utility method to send block responses with optional reversing
fn
send_block_response
(
&
self
,
mut
blocks
:
Vec
<
u64
>
,
reverse
:
bool
,
store
:
bool
,
parent_hash
:
Option
<
u64
>
,
)
{
if
let
Some
(
ref
tx
)
=
self
.move_block_response_tx
{
if
!
blocks
.is_empty
()
{
if
reverse
{
blocks
.reverse
();
}
let
response
=
if
store
{
MoveBlockResponse
::
Store
(
blocks
,
parent_hash
)
}
else
{
MoveBlockResponse
::
Remove
(
blocks
)
};
tx
.send
(
response
)
.unwrap
();
}
}
}
/// Process a MoveBlock instruction synchronously
pub
fn
process
(
&
mut
self
,
event
:
&
MoveBlock
)
->
bool
{
match
event
{
MoveBlock
::
Use
(
hashes
,
_
)
=>
{
MoveBlock
::
Use
(
hashes
)
=>
{
let
mut
blocks_stored
=
Vec
::
<
u64
>
::
new
();
let
mut
parent_block
:
Option
<&
UniqueBlock
>
=
None
;
for
hash
in
hashes
{
// First check if it already exists in active blocks
if
let
Some
(
ref_count
)
=
self
.active_blocks
.get_mut
(
hash
)
{
// Block already active, just increment reference count
*
ref_count
+=
1
;
parent_block
=
Some
(
hash
);
continue
;
}
...
...
@@ -97,6 +136,7 @@ impl KvManager {
if
self
.inactive_blocks
.remove
(
hash
)
{
// Insert into active with reference count 1
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
parent_block
=
Some
(
hash
);
continue
;
}
...
...
@@ -106,30 +146,53 @@ impl KvManager {
// If at max capacity, evict the oldest entry from inactive blocks
if
active_count
+
inactive_count
>=
self
.max_capacity
{
if
let
Some
(
evicted
)
=
self
.inactive_blocks
.evict
()
{
// Remove evicted block from all_blocks
self
.all_blocks
.remove
(
&
evicted
);
}
else
{
// Cannot evict block, meaning no free blocks left in inactive pool
// Send a signal, scheduler would expect to handle preemption upon receiving this
let
Some
(
evicted
)
=
self
.inactive_blocks
.evict
()
else
{
return
false
;
};
self
.all_blocks
.remove
(
&
evicted
);
if
let
UniqueBlock
::
FullBlock
(
evicted_full_block
)
=
evicted
{
self
.send_block_response
(
vec!
[
evicted_full_block
],
false
,
false
,
None
);
}
}
// Now insert the new block in active blocks with reference count 1
self
.active_blocks
.insert
(
hash
.clone
(),
1
);
// Add to all_blocks as it's a new block
self
.all_blocks
.insert
(
hash
.clone
());
if
self
.move_block_response_tx
.is_some
()
{
if
let
UniqueBlock
::
FullBlock
(
stored_full_block
)
=
hash
{
blocks_stored
.push
(
*
stored_full_block
);
}
}
}
let
parent_hash
=
match
parent_block
{
None
=>
None
,
Some
(
UniqueBlock
::
FullBlock
(
block
))
=>
Some
(
*
block
),
Some
(
UniqueBlock
::
PartialBlock
(
_
))
=>
panic!
(
"parent block cannot be partial"
),
};
self
.send_block_response
(
blocks_stored
,
false
,
true
,
parent_hash
);
}
MoveBlock
::
Destroy
(
hashes
)
=>
{
let
mut
blocks_destroyed
=
Vec
::
<
u64
>
::
new
();
// Loop in inverse direction
for
hash
in
hashes
.iter
()
.rev
()
{
self
.active_blocks
.remove
(
hash
)
.unwrap
();
// Remove from all_blocks when destroyed
assert
!
(
self
.all_blocks
.remove
(
hash
));
// Track blocks for batch sending
if
self
.move_block_response_tx
.is_some
()
{
if
let
UniqueBlock
::
FullBlock
(
destroyed_full_block
)
=
hash
{
blocks_destroyed
.push
(
*
destroyed_full_block
);
}
}
}
self
.send_block_response
(
blocks_destroyed
,
true
,
false
,
None
);
}
MoveBlock
::
Deref
(
hashes
)
=>
{
// Loop in inverse direction
for
hash
in
hashes
.iter
()
.rev
()
{
...
...
@@ -149,15 +212,15 @@ impl KvManager {
}
}
}
MoveBlock
::
Promote
(
uuid
,
hash
)
=>
{
MoveBlock
::
Promote
(
uuid
,
hash
,
parent_hash
)
=>
{
let
uuid_block
=
UniqueBlock
::
PartialBlock
(
*
uuid
);
let
hash_block
=
UniqueBlock
::
FullBlock
(
*
hash
);
let
Some
(
ref_count
)
=
self
.active_blocks
.remove
(
&
uuid_block
)
else
{
let
in_all_blocks
=
self
.all_blocks
.contains
(
&
uuid_block
);
panic!
(
"Missing active block for promotion: {:?}. Block still exists: {}"
,
uuid_block
,
in_all_blocks
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
);
};
...
...
@@ -167,6 +230,7 @@ impl KvManager {
// Update all_blocks
assert
!
(
self
.all_blocks
.remove
(
&
uuid_block
));
self
.all_blocks
.insert
(
hash_block
);
self
.send_block_response
(
vec!
[
*
hash
],
false
,
true
,
*
parent_hash
);
}
}
...
...
@@ -178,6 +242,7 @@ impl KvManager {
/// Counts how many entries of `blocks` are not present in `all_blocks`.
///
/// Membership is tested against `all_blocks` rather than only `active_blocks`.
/// Each element of the slice is checked independently, so a repeated unknown
/// block is counted once per occurrence.
pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
    let mut unseen = 0;
    for block in blocks {
        if !self.all_blocks.contains(block) {
            unseen += 1;
        }
    }
    unseen
}
...
...
@@ -200,6 +265,11 @@ impl KvManager {
self
.active_blocks
.len
()
}
/// Returns the number of active blocks divided by `max_capacity`.
///
/// Despite the name, the result is a plain ratio (`active / capacity`); it is
/// not multiplied by 100.
pub fn get_active_perc(&self) -> f64 {
    let active = self.active_blocks.len() as f64;
    let capacity = self.max_capacity as f64;
    active / capacity
}
/// Get the number of inactive blocks
pub
fn
num_inactive_blocks
(
&
self
)
->
usize
{
self
.inactive_blocks
.len
()
...
...
@@ -216,63 +286,28 @@ impl KvManager {
}
/// Check if a sequence can be scheduled and calculate cost if possible
pub
fn
try_schedule
(
&
self
,
sequence
:
&
ActiveSequence
,
watermark
:
f64
,
tokens_budget
:
usize
,
)
->
Option
<
PrefillCost
>
{
// Return None immediately if tokens_budget is 0
if
tokens_budget
==
0
{
return
None
;
}
// Get unique blocks from the sequence
let
unique_blocks
=
sequence
.unique_blocks
();
// Get the count of new blocks
let
new_blocks
=
self
.probe_new_blocks
(
unique_blocks
);
// Calculate current usage and available capacity
let
active_count
=
self
.active_blocks
.len
();
// Check if we can schedule based on the watermark
if
(
active_count
+
new_blocks
)
as
f64
>
(
1.0
-
watermark
)
*
self
.max_capacity
as
f64
{
return
None
;
}
// Calculate overlap blocks
let
overlap_blocks
=
unique_blocks
.len
()
-
new_blocks
;
// Calculate new tokens
let
new_tokens
=
sequence
.num_input_tokens
()
-
overlap_blocks
*
(
self
.block_size
as
usize
);
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
// new_tokens,
// sequence.num_input_tokens(),
// overlap_blocks,
// self.block_size);
// Return None if new_tokens exceeds tokens_budget
if
new_tokens
>
tokens_budget
{
return
None
;
}
pub
fn
get_prefill_cost
(
&
self
,
sequence
:
&
ActiveSequence
)
->
PrefillCost
{
let
seq_blocks
=
sequence
.unique_blocks
();
let
new_blocks
=
self
.probe_new_blocks
(
seq_blocks
);
let
overlap_blocks
=
seq_blocks
.len
()
-
new_blocks
;
let
new_tokens
=
sequence
.num_input_tokens
()
-
overlap_blocks
*
self
.block_size
;
// Calculate prefill compute
let
prefill_compute
=
new_tokens
as
f64
*
(
new_tokens
+
overlap_blocks
*
(
self
.block_size
as
usize
))
as
f64
;
1.25e-6
*
(
new_tokens
as
f64
)
.powi
(
2
)
+
7.41e-2
*
(
new_tokens
as
f64
)
+
2.62e1
;
Some
(
PrefillCost
{
PrefillCost
{
new_blocks
,
new_tokens
,
prefill_compute
,
}
)
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
tokio
::
sync
::
mpsc
;
#[test]
fn
test_failure_on_max_capacity
()
{
...
...
@@ -282,7 +317,7 @@ mod tests {
// Helper function to use multiple blocks that returns the response
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
->
bool
{
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
manager
.process
(
&
MoveBlock
::
Use
(
blocks
,
None
))
manager
.process
(
&
MoveBlock
::
Use
(
blocks
))
}
// First use 10 blocks (0 to 9) in a batch
...
...
@@ -301,15 +336,17 @@ mod tests {
}
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn
test_block_lifecycle_stringent
()
{
// Create a KvManager with 10 blocks capacity
let
mut
manager
=
KvManager
::
new
(
10
,
16
);
// Create a channel to listen to block responses
let
(
tx
,
mut
rx
)
=
mpsc
::
unbounded_channel
::
<
MoveBlockResponse
>
();
// Create a KvManager with 10 blocks capacity and the response sender
let
mut
manager
=
KvManager
::
new_with_sender
(
10
,
16
,
Some
(
tx
));
// Helper function to use multiple blocks
fn
use_blocks
(
manager
:
&
mut
KvManager
,
ids
:
Vec
<
u64
>
)
{
let
blocks
=
ids
.into_iter
()
.map
(
UniqueBlock
::
FullBlock
)
.collect
();
manager
.process
(
&
MoveBlock
::
Use
(
blocks
,
None
));
manager
.process
(
&
MoveBlock
::
Use
(
blocks
));
}
// Helper function to destroy multiple blocks
...
...
@@ -324,6 +361,56 @@ mod tests {
manager
.process
(
&
MoveBlock
::
Deref
(
blocks
));
}
// Test helper: pulls the next pending response off `rx` without waiting and
// checks both its variant and its block list.
//
// `expected_type` must be "Store" or "Remove"; any other value, or a response
// of the other variant, fails the test. `description` is appended to every
// failure message to identify the calling step.
fn assert_block_response(
    rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
    expected_type: &str,
    expected_blocks: Vec<u64>,
    description: &str,
) {
    let Ok(response) = rx.try_recv() else {
        panic!("Expected {expected_type} response {description}")
    };

    match &response {
        // The parent hash of a Store response is not checked here.
        MoveBlockResponse::Store(blocks, _) if expected_type == "Store" => {
            assert_eq!(
                blocks.len(),
                expected_blocks.len(),
                "Expected {} blocks in Store response {}",
                expected_blocks.len(),
                description
            );
            assert_eq!(
                blocks, &expected_blocks,
                "Store blocks don't match expected {description}"
            );
        }
        MoveBlockResponse::Remove(blocks) if expected_type == "Remove" => {
            assert_eq!(
                blocks.len(),
                expected_blocks.len(),
                "Expected {} blocks in Remove response {}",
                expected_blocks.len(),
                description
            );
            assert_eq!(
                blocks, &expected_blocks,
                "Remove blocks don't match expected {description}"
            );
        }
        _ => panic!("Expected {expected_type} response, got {response:?} {description}"),
    }
}
// Test helper: fails the test if a response is already waiting on `rx`.
// `description` is appended to the failure message to identify the calling step.
fn assert_no_response(
    rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
    description: &str,
) {
    if rx.try_recv().is_ok() {
        panic!("Expected no response {description}");
    }
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn
assert_active_blocks
(
manager
:
&
KvManager
,
expected_blocks
:
&
[(
u64
,
usize
)])
{
assert_eq!
(
...
...
@@ -336,14 +423,12 @@ mod tests {
let
block
=
UniqueBlock
::
FullBlock
(
id
);
assert
!
(
manager
.active_blocks
()
.contains_key
(
&
block
),
"Block {} not found in active blocks"
,
id
"Block {id} not found in active blocks"
,
);
assert_eq!
(
manager
.active_blocks
()
.get
(
&
block
),
Some
(
&
ref_count
),
"Block {} has wrong reference count"
,
id
"Block {id} has wrong reference count"
,
);
}
}
...
...
@@ -366,17 +451,18 @@ mod tests {
let
block
=
UniqueBlock
::
FullBlock
(
id
);
assert
!
(
inactive_blocks
.iter
()
.any
(|
&
b
|
*
b
==
block
),
"Block {} not found in inactive blocks"
,
id
"Block {id} not found in inactive blocks"
,
);
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks
(
&
mut
manager
,
(
0
..
5
)
.collect
());
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
0
,
1
,
2
,
3
,
4
],
"after first use"
);
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
,
6
]);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
5
,
6
],
"after second use"
);
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks
(
...
...
@@ -386,9 +472,11 @@ mod tests {
// Now destroy block 4
destroy_blocks
(
&
mut
manager
,
vec!
[
4
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
4
],
"after destroy block 4"
);
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
3
]);
assert_no_response
(
&
mut
rx
,
"after deref operation"
);
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
2
]);
...
...
@@ -396,6 +484,7 @@ mod tests {
// Now destroy block 6
destroy_blocks
(
&
mut
manager
,
vec!
[
6
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
6
],
"after block 6 eviction"
);
// And deref blocks 5, 1, 0 as a batch
deref_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
5
]);
...
...
@@ -406,6 +495,7 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks
(
&
mut
manager
,
vec!
[
0
,
1
,
2
,
7
,
8
,
9
]);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
7
,
8
,
9
],
"after [7, 8, 9] use"
);
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks
(
&
manager
,
2
,
&
[
3
,
5
]);
...
...
@@ -420,8 +510,14 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch
use_blocks
(
&
mut
manager
,
vec!
[
10
,
11
,
12
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
3
],
"after block 5 eviction"
);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
10
,
11
,
12
],
"after [10, 11, 12] use"
);
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks
(
&
manager
,
1
,
&
[
5
]);
use_blocks
(
&
mut
manager
,
vec!
[
13
]);
assert_block_response
(
&
mut
rx
,
"Remove"
,
vec!
[
5
],
"after block 5 eviction"
);
assert_block_response
(
&
mut
rx
,
"Store"
,
vec!
[
13
],
"after block 13 use"
);
}
}
lib/llm/src/mocker/protocols.rs
View file @
f0652d89
...
...
@@ -13,12 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
derive_builder
::
Builder
;
use
serde
::{
Deserialize
,
Serialize
};
use
uuid
::
Uuid
;
use
crate
::
kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
KvCacheEventData
,
KvCacheRemoveData
,
KvCacheStoreData
,
KvCacheStoredBlockData
,
LocalBlockHash
,
};
pub
type
Token
=
u32
;
pub
type
LocalBlockHash
=
u64
;
/// A global hash identifier for blocks
pub
type
GlobalHash
=
u64
;
pub
type
NumBlocks
=
usize
;
...
...
@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
}
/// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field
#[derive(Debug,
Clone,
PartialEq,
Serialize,
Deserialize)]
pub
enum
MoveBlock
{
Use
(
Vec
<
UniqueBlock
>
,
Option
<
f64
>
),
Use
(
Vec
<
UniqueBlock
>
),
Destroy
(
Vec
<
UniqueBlock
>
),
Deref
(
Vec
<
UniqueBlock
>
),
Promote
(
Uuid
,
GlobalHash
),
Promote
(
Uuid
,
GlobalHash
,
Option
<
u64
>
),
}
/// Notification describing which full blocks were stored or removed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
    /// Hashes of blocks that were stored, followed by an optional parent hash.
    Store(Vec<GlobalHash>, Option<u64>),
    /// Hashes of blocks that were removed.
    Remove(Vec<GlobalHash>),
}
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
...
...
@@ -52,15 +63,86 @@ pub struct DirectRequest {
pub
tokens
:
Vec
<
Token
>
,
pub
max_output_tokens
:
usize
,
pub
uuid
:
Option
<
Uuid
>
,
pub
dp_rank
:
Option
<
u32
>
,
}
/// Cost estimate for prefilling a sequence into the cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of blocks the sequence needs that are not already known to the cache.
    pub new_blocks: usize,
    /// Number of input tokens not covered by already-cached blocks.
    pub new_tokens: usize,
    /// Estimated compute cost of the prefill.
    // NOTE(review): the scheduler divides this by 1000.0 before building a
    // `Duration` in seconds, so the unit appears to be milliseconds — confirm.
    pub prefill_compute: f64,
}
/// Signal emitted for output token generation, carrying a completion flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier this signal refers to (presumably the request UUID — confirm against the sender).
    pub uuid: Uuid,
    /// Completion status reported with this signal.
    pub completed: bool,
}
/// Configuration arguments for MockVllmEngine.
///
/// Every field has a builder default; construct values through
/// `MockEngineArgs::builder()`.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    /// Block capacity handed to the KV manager. Defaults to 16384.
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    /// Block size used by the KV manager and sequences. Defaults to 64.
    #[builder(default = "64")]
    pub block_size: usize,

    /// Optional cap on the number of concurrently scheduled sequences;
    /// `None` disables the check. Defaults to `Some(256)`.
    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    /// Optional cap on the number of batched tokens; `None` disables the
    /// check. Defaults to `Some(8192)`.
    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    /// Whether prefix caching is enabled. Defaults to `true`.
    #[builder(default = true)]
    pub enable_prefix_caching: bool,

    /// Fraction of block capacity held back when scheduling: the block budget
    /// is `(1 - watermark) * capacity`. Defaults to 0.01.
    #[builder(default = "0.01")]
    pub watermark: f64,

    /// Speedup ratio; the scheduler asserts it is greater than 0. Defaults to 1.0.
    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    /// Data-parallel size (presumably the number of dp ranks — confirm). Defaults to 1.
    #[builder(default = "1")]
    pub dp_size: u32,
}
impl
MockEngineArgs
{
pub
fn
builder
()
->
MockEngineArgsBuilder
{
MockEngineArgsBuilder
::
default
()
}
}
/// Converts a `MoveBlockResponse` into the matching `KvCacheEventData`.
///
/// * `Store` becomes `KvCacheEventData::Stored`; every block hash is used for
///   both `block_hash` and `tokens_hash`, and the optional parent hash is wrapped.
/// * `Remove` becomes `KvCacheEventData::Removed` with the wrapped block hashes.
///
/// Note: using the same value for `block_hash` and `tokens_hash` is not
/// correct in the rare cases where the sequence-aware hash differs from the
/// token content hash.
pub fn block_response_to_kv_event(response: MoveBlockResponse) -> KvCacheEventData {
    match response {
        MoveBlockResponse::Store(full_blocks, parent_hash) => {
            let to_stored = |block| KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(block),
                tokens_hash: LocalBlockHash(block),
            };
            let parent_hash = parent_hash.map(ExternalSequenceBlockHash);
            let blocks = full_blocks.into_iter().map(to_stored).collect();
            KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash,
                blocks,
            })
        }
        MoveBlockResponse::Remove(full_blocks) => {
            let block_hashes = full_blocks
                .into_iter()
                .map(ExternalSequenceBlockHash)
                .collect();
            KvCacheEventData::Removed(KvCacheRemoveData { block_hashes })
        }
    }
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
...
...
lib/llm/src/mocker/scheduler.rs
View file @
f0652d89
...
...
@@ -40,11 +40,13 @@
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
use
crate
::
kv_router
::
protocols
::
{
ForwardPassMetrics
,
KvCacheEventData
}
;
use
crate
::
mocker
::
evictor
::
LRUEvictor
;
use
crate
::
mocker
::
kv_manager
::
KvManager
;
use
crate
::
mocker
::
protocols
::
DirectRequest
;
use
crate
::
mocker
::
protocols
::{
MoveBlock
,
PrefillCost
,
UniqueBlock
};
use
crate
::
mocker
::
protocols
::{
block_response_to_kv_event
,
MoveBlock
,
OutputSignal
,
PrefillCost
,
UniqueBlock
,
};
use
crate
::
mocker
::
protocols
::{
DirectRequest
,
MockEngineArgs
,
MoveBlockResponse
};
use
crate
::
mocker
::
sequence
::
ActiveSequence
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
VecDeque
;
...
...
@@ -63,8 +65,8 @@ pub enum Request {
#[derive(Default)]
struct
SchedulerState
{
waiting
:
VecDeque
<
Uuid
>
,
re
ady
:
VecDeque
<
Uuid
>
,
running
:
LRUEvictor
<
Uuid
>
,
p
re
fill
:
VecDeque
<
Uuid
>
,
decode
:
LRUEvictor
<
Uuid
>
,
requests
:
HashMap
<
Uuid
,
Request
>
,
prefill_costs
:
HashMap
<
Uuid
,
Option
<
PrefillCost
>>
,
}
...
...
@@ -74,61 +76,70 @@ impl SchedulerState {
/// Registers an incoming request and appends it to the back of the waiting queue.
///
/// The request's own UUID is used when present; otherwise a fresh v4 UUID is
/// generated. Returns the UUID under which the request was stored.
fn receive(&mut self, request: DirectRequest) -> Uuid {
    let uuid = match request.uuid {
        Some(existing) => existing,
        None => Uuid::new_v4(),
    };

    self.waiting.push_back(uuid);
    self.requests.insert(uuid, Request::Direct(request));
    uuid
}
/// Get the next UUID from ready or waiting queue and its associated Request.
/// Returns from ready if not empty, otherwise from waiting, or None if both are empty.
/// Also removes the Request from the requests HashMap.
fn
next
(
&
mut
self
)
->
Option
<
(
Uuid
,
Request
)
>
{
let
uuid
=
self
.ready
.
pop_front
()
.
or_else
(||
self
.waiting
.pop_front
())
?
;
let
request
=
self
.requests
.remove
(
&
uuid
)
?
;
let
uuid
=
self
.waiting
.pop_front
()
?
;
let
request
=
self
.
requests
.
remove
(
&
uuid
)
.expect
(
"Request does not exist."
)
;
Some
((
uuid
,
request
))
}
/// Stores `request` under `uuid` and places the UUID at the front of the
/// waiting queue, so it is the next one returned from that queue.
fn first_in_line(&mut self, uuid: Uuid, request: Request) {
    self.waiting.push_front(uuid);
    self.requests.insert(uuid, request);
}
/// Move a UUID and its Request to the ready queue.
fn
make_ready
(
&
mut
self
,
uuid
:
Uuid
,
active_seq
:
ActiveSequence
)
{
fn
start_prefill
(
&
mut
self
,
uuid
:
Uuid
,
active_seq
:
ActiveSequence
,
cost
:
Option
<
PrefillCost
>
)
{
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_seq
));
self
.ready
.push_back
(
uuid
);
self
.prefill
.push_back
(
uuid
);
self
.prefill_costs
.insert
(
uuid
,
cost
);
}
/// Schedule the request with the given UUID.
/// Returns the creation signal from the ActiveSequence.
fn
run
(
&
mut
self
,
uuid
:
Uuid
,
active_seq
:
ActiveSequence
)
->
MoveBlock
{
// Insert the request into the map
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_seq
));
/// Pop from prefill queue and move to decode queue.
/// Returns the prefill_compute value if available.
fn
start_decode
(
&
mut
self
)
->
Option
<
(
f64
,
MoveBlock
)
>
{
let
uuid
=
self
.prefill
.pop_front
()
?
;
self
.decode
.insert
(
uuid
);
// Remove and extract prefill_compute from prefill_costs
let
prefill_cost
=
self
.prefill_costs
.remove
(
&
uuid
)
.flatten
()
.expect
(
"Expects valid prefill cost."
);
// Get the creation signal
let
Some
(
Request
::
Active
(
sequence
))
=
self
.requests
.get
(
&
uuid
)
else
{
panic!
(
"Failed to get ActiveSequence for UUID"
);
};
let
Some
(
signal
)
=
sequence
.creation_signal
()
else
{
panic!
(
"Failed to get creation signal from ActiveSequence"
);
panic!
(
"Request does not exist."
);
};
let
creation_signal
=
sequence
.creation_signal
()
.clone
()
.expect
(
"Must have creation signal."
);
// Add to running requests
self
.running
.insert
(
uuid
);
signal
.clone
()
Some
((
prefill_cost
.prefill_compute
,
creation_signal
))
}
/// Set the prefill cost for a UUID
fn
set_prefill_cost
(
&
mut
self
,
uuid
:
Uuid
,
cost
:
Option
<
PrefillCost
>
)
{
self
.prefill_costs
.insert
(
uuid
,
cost
);
/// Returns mutable access to the active sequence for `uuid`, or `None` when
/// the UUID is not in the decode set.
///
/// # Panics
/// Panics if the UUID is in the decode set but has no `Request::Active`
/// entry in `requests`.
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
    if !self.decode.contains(&uuid) {
        return None;
    }
    match self.requests.get_mut(&uuid) {
        Some(Request::Active(sequence)) => Some(sequence),
        _ => panic!("Request does not exist."),
    }
}
/// Get the prefill compute value for a UUID if available
fn
get_prefill_compute
(
&
self
,
uuid
:
&
Uuid
)
->
Option
<
f64
>
{
self
.prefill_costs
.get
(
uuid
)
.and_then
(|
cost
|
cost
.as_ref
())
.map
(|
cost
|
cost
.prefill_compute
)
/// Returns the combined number of requests in the prefill queue and the decode set.
fn num_active_requests(&self) -> usize {
    let prefilling = self.prefill.len();
    let decoding = self.decode.len();
    prefilling + decoding
}
/// Calculate the current running batched tokens
...
...
@@ -145,7 +156,7 @@ impl SchedulerState {
/// Remove a UUID and its associated Request from collections.
fn
complete
(
&
mut
self
,
uuid
:
&
Uuid
)
{
// println!("Request {} will complete", uuid);
self
.
running
.remove
(
uuid
);
self
.
decode
.remove
(
uuid
);
self
.requests
.remove
(
uuid
);
self
.prefill_costs
.remove
(
uuid
);
}
...
...
@@ -153,76 +164,93 @@ impl SchedulerState {
/// Preempt the oldest running request by evicting it from running, resetting the sequence,
/// and adding it back to the waiting queue.
/// Returns the signal from reset_with_signal or None if no requests are running.
fn
preempt
(
&
mut
self
)
->
Option
<
Vec
<
MoveBlock
>
>
{
fn
preempt
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
// Evict the oldest UUID from running
let
uuid
=
self
.running
.evict
()
?
;
eprintln!
(
"Request {} will be preempted"
,
uuid
);
// Remove the request from the requests HashMap and ensure it's an ActiveSequence
let
request
=
self
.requests
.remove
(
&
uuid
)
?
;
// Remove the prefill cost to force recomputation
let
uuid
=
self
.decode
.evict
()
.expect
(
"Nothing to evict for preemption."
);
let
request
=
self
.requests
.remove
(
&
uuid
)
.expect
(
"Request does not exist."
);
self
.prefill_costs
.remove
(
&
uuid
);
eprintln!
(
"Request {uuid} will be preempted"
);
// Extract the ActiveSequence from the Request enum
// Reset the sequence and get the new sequence and signal
// Insert the new sequence back into the requests map and add to waiting queue
let
Request
::
Active
(
mut
active_sequence
)
=
request
else
{
panic!
(
"Expected ActiveSequence in running queue"
)
};
// Reset the sequence and get the new sequence and signal
let
signals
=
active_sequence
.reset_with_signal
();
// Insert the new sequence back into the requests map and add to waiting queue
self
.requests
.insert
(
uuid
,
Request
::
Active
(
active_sequence
));
self
.waiting
.push_back
(
uuid
);
// Note: For preemption, we don't compute hit rate since we don't have access to new_tokens
// and the sequence is being reset anyway. Hit rate tracking is primarily for new scheduling attempts.
self
.first_in_line
(
uuid
,
Request
::
Active
(
active_sequence
));
Some
(
signals
)
signals
}
}
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub
struct
Scheduler
{
dp_rank
:
Option
<
u32
>
,
state
:
Arc
<
Mutex
<
SchedulerState
>>
,
kv_manager
:
Arc
<
Mutex
<
KvManager
>>
,
request_tx
:
mpsc
::
Sender
<
DirectRequest
>
,
request_tx
:
mpsc
::
UnboundedSender
<
DirectRequest
>
,
hit_rates
:
Arc
<
Mutex
<
VecDeque
<
f32
>>>
,
}
impl
Scheduler
{
/// Create a new Scheduler with the given parameters
pub
fn
new
(
kv_capacity
:
usize
,
watermark
:
f64
,
block_size
:
u32
,
chunk_size
:
Option
<
usize
>
,
output_tx
:
Option
<
mpsc
::
Sender
<
Uuid
>>
,
args
:
MockEngineArgs
,
dp_rank
:
Option
<
u32
>
,
output_tx
:
Option
<
mpsc
::
UnboundedSender
<
OutputSignal
>>
,
kv_events_tx
:
Option
<
mpsc
::
UnboundedSender
<
KvCacheEventData
>>
,
cancellation_token
:
Option
<
CancellationToken
>
,
)
->
Self
{
// Create KvManager internally
let
kv_manager
=
KvManager
::
new
(
kv_capacity
,
block_size
);
let
token_capacity
:
usize
=
8192
;
let
state
=
Arc
::
new
(
Mutex
::
new
(
SchedulerState
::
default
()));
let
kv_manager
=
Arc
::
new
(
Mutex
::
new
(
kv_manager
));
let
chunk_size
=
chunk_size
.unwrap_or
(
256
);
// Create internal channel for KV events only if needed
let
(
block_resp_tx
,
mut
block_resp_rx
)
=
if
kv_events_tx
.is_some
()
{
let
(
tx
,
rx
)
=
mpsc
::
unbounded_channel
::
<
MoveBlockResponse
>
();
(
Some
(
tx
),
Some
(
rx
))
}
else
{
(
None
,
None
)
};
let
kv_manager
=
Arc
::
new
(
Mutex
::
new
(
KvManager
::
new_with_sender
(
args
.num_gpu_blocks
,
args
.block_size
,
block_resp_tx
,
)));
let
hit_rates
=
Arc
::
new
(
Mutex
::
new
(
VecDeque
::
with_capacity
(
1000
)));
// Create channel for request handling
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
channel
::
<
DirectRequest
>
(
1024
);
// Assert speedup_ratio is greater than 0
assert
!
(
args
.speedup_ratio
>
0.0
,
"speedup_ratio must be greater than 0, got: {}"
,
args
.speedup_ratio
);
// Use provided cancellation token or create new one
let
cancellation_token
=
cancellation_token
.unwrap_or_default
();
let
token_clone
=
cancellation_token
.clone
();
// Create channel for request handling
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
DirectRequest
>
();
// Create a clone for the background task
let
state_clone
=
state
.clone
();
let
kv_manager_clone
=
kv_manager
.clone
();
let
output_tx_clone
=
output_tx
.clone
();
let
cancel_token_clone
=
cancellation_token
.unwrap_or_default
()
.clone
();
let
hit_rates_clone
=
hit_rates
.clone
();
// Spawn main background task with cancellation token
tokio
::
spawn
(
async
move
{
let
mut
schedule_interval
=
interval
(
Duration
::
from_millis
(
5
));
let
mut
simulate_interval
=
interval
(
Duration
::
from_millis
(
1
));
let
mut
schedule_interval
=
interval
(
Duration
::
from_secs_f64
(
1e-3
));
let
mut
simulate_interval
=
interval
(
Duration
::
from_secs_f64
(
1e-4
));
let
mut
should_schedule
=
true
;
loop
{
tokio
::
select!
{
...
...
@@ -234,35 +262,63 @@ impl Scheduler {
state
.receive
(
request
);
}
// Try Scheduling Requests
// Try Scheduling Requests
- runs on normal interval or after simulation
_
=
schedule_interval
.tick
()
=>
{
// Skip if we just ran scheduling after simulation to prevent consecutive runs
if
!
should_schedule
{
continue
;
}
let
mut
state_guard
=
state_clone
.lock
()
.await
;
let
mut
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
let
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
let
mut
current_blocks
=
kv_manager_guard
.num_active_blocks
();
let
mut
current_tokens
=
state_guard
.num_batched_tokens
();
let
mut
current_seqs
=
state_guard
.num_active_requests
();
while
let
Some
((
uuid
,
request
))
=
state_guard
.next
()
{
let
active_sequence
=
get_active_sequence
(
request
,
block_size
,
chunk_size
);
let
active_sequence
=
get_active_sequence
(
request
,
args
.
block_size
,
args
.enable_prefix_caching
);
// Calculate token budget using new_tokens from PrefillCost
let
total_prefill_tokens
=
state_guard
.num_batched_tokens
();
let
tokens_budget
=
token_capacity
.saturating_sub
(
total_prefill_tokens
);
// Update predictive budgets
let
prefill_cost
=
kv_manager_guard
.get_prefill_cost
(
&
active_sequence
);
let
total_tokens
=
active_sequence
.len
();
let
new_blocks
=
(
total_tokens
+
1
)
/
args
.block_size
;
// this is conservative, assumes no cache hit
let
new_tokens
=
prefill_cost
.new_tokens
;
current_blocks
+=
new_blocks
;
current_tokens
+=
new_tokens
;
current_seqs
+=
1
;
// Check if it can be scheduled
let
Some
(
prefill_cost
)
=
kv_manager_guard
.try_schedule
(
&
active_sequence
,
watermark
,
tokens_budget
)
else
{
state_guard
.make_ready
(
uuid
,
active_sequence
);
let
under_block_budget
=
current_blocks
as
f64
<=
(
1
.
-
args
.watermark
)
*
kv_manager_guard
.max_capacity
()
as
f64
;
let
under_token_budget
=
args
.max_num_batched_tokens
.is_none_or
(|
limit
|
current_tokens
<=
limit
);
let
under_seq_budget
=
args
.max_num_seqs
.is_none_or
(|
limit
|
current_seqs
<=
limit
);
// Cannot schedule, put first in line instead
if
!
(
under_block_budget
&&
under_token_budget
&&
under_seq_budget
)
{
state_guard
.first_in_line
(
uuid
,
Request
::
Active
(
active_sequence
));
break
;
}
;
}
// Get creation signal and schedule the request
let
signal
=
state_guard
.run
(
uuid
,
active_sequence
);
kv_manager_guard
.process
(
&
signal
);
state_guard
.set_prefill_cost
(
uuid
,
Some
(
prefill_cost
));
// Compute and store hit rate
let
hit_rate
=
if
!
active_sequence
.is_empty
()
{
1.0
-
(
new_tokens
as
f32
/
active_sequence
.len
()
as
f32
)
}
else
{
0.0
};
{
let
mut
hit_rates_guard
=
hit_rates_clone
.lock
()
.await
;
hit_rates_guard
.push_back
(
hit_rate
);
if
hit_rates_guard
.len
()
>
1000
{
hit_rates_guard
.pop_front
();
}
}
state_guard
.start_prefill
(
uuid
,
active_sequence
,
Some
(
prefill_cost
));
should_schedule
=
false
;
}
}
// Check for cancellation
_
=
token_clone
.cancelled
()
=>
{
_
=
cancel_
token_clone
.cancelled
()
=>
{
break
;
}
...
...
@@ -271,75 +327,84 @@ impl Scheduler {
let
mut
state_guard
=
state_clone
.lock
()
.await
;
let
mut
kv_manager_guard
=
kv_manager_clone
.lock
()
.await
;
// Base time needed for decoding
(assumed memory bound on KV cache)
let
active_
tokens
=
kv_manager_guard
.
num
_active_
blocks
()
*
(
block_size
as
usize
);
// TODO: 2 is a dummy / magic scaling factor
let
mut
generation
_time
=
Duration
::
from_
micros
((
active_tokens
/
2
)
as
u64
);
// Base time needed for decoding
using active percentage and quadratic formula
let
active_
perc
=
kv_manager_guard
.
get
_active_
perc
(
);
let
decoding_time
=
-
5.47
*
active_perc
.powi
(
2
)
+
43.88
*
active_perc
+
19.44
;
let
mut
total
_time
=
Duration
::
from_
secs_f64
(
decoding_time
/
1000.0
);
// Process each running request
let
uuids
:
Vec
<
Uuid
>
=
state_guard
.running
.keys
()
.cloned
()
.collect
();
for
uuid
in
uuids
{
// Check if UUID is still in running_requests, if not skip this iteration
if
!
state_guard
.running
.contains
(
&
uuid
)
{
continue
;
// Process prefilling
while
let
Some
((
prefill_compute
,
creation_signal
))
=
state_guard
.start_decode
()
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
total_time
+=
Duration
::
from_secs_f64
(
prefill_compute
/
1000.0
);
let
prefill_success
=
process_signals
(
&
mut
kv_manager_guard
,
std
::
slice
::
from_ref
(
&
creation_signal
));
if
!
prefill_success
{
panic!
(
"Block allocation for prefilling cannot fail."
);
}
//
Get prefill compute value first
let
prefill_compute
=
state_guard
.get_prefill_compute
(
&
uuid
);
// Get the active sequence for this UUID
let
sequence
=
state_guard
.requests
.get_mut
(
&
uuid
)
.and_then
(|
req
|
if
let
Request
::
Active
(
seq
)
=
req
{
Some
(
seq
)
}
else
{
None
})
.expect
(
"UUID in running_requests must have a corresponding active sequence"
);
//
Drain KV events and forward to relay after prefill signal processing
if
let
(
Some
(
ref
relay_tx
),
Some
(
ref
mut
rx
))
=
(
&
kv_events_tx
,
&
mut
block_resp_rx
)
{
while
let
Ok
(
event
)
=
rx
.try_recv
()
{
let
_
=
relay_tx
.send
(
block_response_to_kv_event
(
event
));
}
}
}
// Generate token and get signals
// Process decoding
let
uuids
:
Vec
<
Uuid
>
=
state_guard
.decode
.keys
()
.cloned
()
.collect
();
if
!
uuids
.is_empty
()
{
should_schedule
=
true
};
for
uuid
in
uuids
{
let
Some
(
sequence
)
=
state_guard
.run
(
uuid
)
else
{
continue
;
};
let
signals
=
sequence
.generate
();
// Accumulate sleep duration based on prefill_compute if available
// prefill compute = (cached_tokens + new_tokens) * new_tokens
let
sleep_ms
=
if
let
Some
(
compute
)
=
prefill_compute
{
// TODO: 1024 is a dummy / magic scaling factor
(
compute
/
1024.0
)
as
u64
}
else
{
0
};
generation_time
+=
Duration
::
from_micros
(
sleep_ms
);
// Process all signals with the KvManager
// Handling of preemption on failure
if
!
process_signals
(
&
mut
kv_manager_guard
,
&
signals
)
{
sequence
.pop
();
// revert the failed generation op
// free_signal derefs the preempted blocks
let
Some
(
free_signal
)
=
state_guard
.preempt
()
else
{
panic!
(
"Failed to acquire signal to free KV blocks from preemption"
);
};
for
signal
in
free_signal
{
for
signal
in
state_guard
.preempt
()
{
kv_manager_guard
.process
(
&
signal
);
}
continue
;
}
// Send UUID notification for each generated token
// TODO: hook this up to an AsyncEngine
if
let
Some
(
tx
)
=
&
output_tx_clone
{
let
_
=
tx
.try_send
(
uuid
);
// Drain KV events and forward to relay after decode signal processing
if
let
(
Some
(
ref
relay_tx
),
Some
(
ref
mut
rx
))
=
(
&
kv_events_tx
,
&
mut
block_resp_rx
)
{
while
let
Ok
(
event
)
=
rx
.try_recv
()
{
let
_
=
relay_tx
.send
(
block_response_to_kv_event
(
event
));
}
}
// Check if we're done after generating
if
sequence
.generated_tokens
()
>=
sequence
.max_output_tokens
()
{
state_guard
.complete
(
&
uuid
);
continue
;
// Check completion and send notification
let
is_complete
=
sequence
.generated_tokens
()
>=
sequence
.max_output_tokens
();
let
should_output
=
sequence
.generated_tokens
()
>
sequence
.already_generated_tokens
();
let
mut
send_failed
=
false
;
if
should_output
{
send_failed
=
output_tx_clone
.as_ref
()
.is_some_and
(|
tx
|
{
tx
.send
(
OutputSignal
{
uuid
,
completed
:
is_complete
})
.is_err
()
});
}
if
send_failed
{
for
signal
in
&
sequence
.free_signal
()
{
kv_manager_guard
.process
(
signal
);
}
}
// Transition to decode (no prefill cost)
if
sequence
.generated_tokens
()
==
1
{
state_guard
.set_prefill_cost
(
uuid
,
None
)
;
if
send_failed
||
is_complete
{
state_guard
.complete
(
&
uuid
);
continue
;
}
}
// Sleep once for the accumulated duration
if
generation_time
.as_millis
()
>
0
{
tokio
::
time
::
sleep
(
generation_time
)
.await
;
// Sleep once for the adjusted duration
drop
(
kv_manager_guard
);
drop
(
state_guard
);
let
adjusted_time
=
Duration
::
from_secs_f64
(
total_time
.as_secs_f64
()
/
args
.speedup_ratio
);
if
adjusted_time
.as_millis
()
>
0
{
tokio
::
time
::
sleep
(
adjusted_time
)
.await
;
}
}
}
...
...
@@ -347,15 +412,22 @@ impl Scheduler {
});
Self
{
dp_rank
,
state
,
kv_manager
,
request_tx
,
hit_rates
,
}
}
/// Add a new request to the waiting queue
pub
async
fn
receive
(
&
self
,
request
:
DirectRequest
)
{
let
_
=
self
.request_tx
.send
(
request
)
.await
;
let
_
=
self
.request_tx
.send
(
request
);
}
/// Hand out another handle to the scheduler's request channel.
///
/// The returned sender is a clone of the one stored in `self.request_tx`; the
/// scheduler keeps its own copy.
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
    Clone::clone(&self.request_tx)
}
/// Get the count of waiting requests
...
...
@@ -367,7 +439,7 @@ impl Scheduler {
/// Get the count of running requests
pub
async
fn
running_count
(
&
self
)
->
usize
{
let
state
=
self
.state
.lock
()
.await
;
state
.
running
.len
()
state
.
decode
.len
()
}
/// Get the current capacity of the KvManager
...
...
@@ -378,35 +450,53 @@ impl Scheduler {
/// Returns forward pass metrics for monitoring purposes
pub
async
fn
get_forward_pass_metrics
(
&
self
)
->
ForwardPassMetrics
{
// Acquire all locks in consistent order: state -> kv_manager -> hit_rates
let
state
=
self
.state
.lock
()
.await
;
let
kv_manager
=
self
.kv_manager
.lock
()
.await
;
let
hit_rates_guard
=
self
.hit_rates
.lock
()
.await
;
// Get the active blocks and total capacity from KvManager
// Get state metrics
let
request_active_slots
=
state
.decode
.len
()
as
u64
;
let
num_requests_waiting
=
state
.waiting
.len
()
as
u64
;
// Get KV manager metrics
let
active_blocks_count
=
kv_manager
.active_blocks
()
.len
()
as
u64
;
let
total_capacity
=
kv_manager
.max_capacity
()
as
u64
;
// Calculate GPU cache usage percentage
let
gpu_cache_usage_perc
=
if
total_capacity
>
0
{
active_blocks_count
as
f32
/
total_capacity
as
f32
}
else
{
0.0
};
// Get hit rate metrics
let
gpu_prefix_cache_hit_rate
=
if
hit_rates_guard
.is_empty
()
{
0.0
}
else
{
let
sum
:
f32
=
hit_rates_guard
.iter
()
.sum
();
sum
/
hit_rates_guard
.len
()
as
f32
};
ForwardPassMetrics
{
data_parallel_rank
:
None
,
// Default for backwards compatibility
request_active_slots
:
state
.running
.len
()
as
u64
,
request_total_slots
:
420
,
// Dummy value as specified
data_parallel_rank
:
self
.dp_rank
,
request_active_slots
,
// vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
request_total_slots
:
1024
,
kv_active_blocks
:
active_blocks_count
,
kv_total_blocks
:
total_capacity
,
num_requests_waiting
:
state
.waiting
.len
()
as
u64
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
:
0.0
,
// Placeholder value as specified
gpu_prefix_cache_hit_rate
,
}
// Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state
}
}
/// Convert a Request to an ActiveSequence
fn
get_active_sequence
(
request
:
Request
,
block_size
:
u32
,
chunk_size
:
usize
)
->
ActiveSequence
{
fn
get_active_sequence
(
request
:
Request
,
block_size
:
usize
,
enable_prefix_caching
:
bool
,
)
->
ActiveSequence
{
if
let
Request
::
Active
(
active_seq
)
=
request
{
return
active_seq
;
}
...
...
@@ -419,7 +509,7 @@ fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) ->
direct_request
.tokens
,
direct_request
.max_output_tokens
,
Some
(
block_size
),
Some
(
chunk_size
)
,
enable_prefix_caching
,
)
}
...
...
@@ -440,7 +530,7 @@ fn process_signals(
}
// Check we have a Use signal with blocks
let
MoveBlock
::
Use
(
blocks
,
_
)
=
signal
else
{
let
MoveBlock
::
Use
(
blocks
)
=
signal
else
{
panic!
(
"Failed signal is Invalid. Has to fail on generation signal."
);
};
...
...
@@ -467,32 +557,37 @@ mod tests {
use
std
::
time
::
Duration
;
#[rstest]
#[case::random(
false
)]
#[case::caching(
true
)]
#[case::random_no_prefix_caching(
false
,
false
)]
#[case::random_with_prefix_caching(
false
,
true
)]
#[case::caching_no_prefix_caching(
true
,
false
)]
#[case::caching_with_prefix_caching(
true
,
true
)]
#[tokio::test]
async
fn
test_scheduler_token_generation_patterns
(
#[case]
use_shared_tokens
:
bool
)
{
async
fn
test_scheduler_token_generation_patterns
(
#[case]
use_shared_tokens
:
bool
,
#[case]
enable_prefix_caching
:
bool
,
)
{
std
::
env
::
set_var
(
"RUST_LOG"
,
"debug"
);
let
kv_capacity
:
usize
=
500
;
let
watermark
:
f64
=
0.01
;
// 1% watermark
let
block_size
:
u32
=
64
;
let
chunk_size
:
usize
=
256
;
let
block_size
:
usize
=
64
;
let
num_requests
:
usize
=
100
;
let
input_len
:
usize
=
1000
;
let
max_output_tokens
:
usize
=
100
;
// Create channel for token output
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
channel
::
<
Uuid
>
(
1024
);
// Create scheduler with internal KvManager
let
scheduler
=
Scheduler
::
new
(
kv_capacity
,
watermark
,
block_size
,
Some
(
chunk_size
),
Some
(
output_tx
),
None
,
);
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
// Create scheduler args using builder - now including enable_prefix_caching
let
args
=
MockEngineArgs
::
builder
()
.num_gpu_blocks
(
kv_capacity
)
.block_size
(
block_size
)
.speedup_ratio
(
10.0
)
.enable_prefix_caching
(
enable_prefix_caching
)
.build
()
.unwrap
();
// Create scheduler with new args struct
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
// Create shared tokens for caching case
let
shared_tokens
=
if
use_shared_tokens
{
...
...
@@ -523,6 +618,7 @@ mod tests {
tokens
:
input_tokens
,
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
};
scheduler
.receive
(
request
)
.await
;
}
...
...
@@ -547,7 +643,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics
_
=
debug_interval
.tick
()
=>
{
let
_
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
//
println!("Forward Pass Metrics: {
:#?}",
_metrics);
println!
(
"Forward Pass Metrics: {_metrics
:#?}"
);
}
Some
(
_
)
=
output_rx
.recv
()
=>
{
...
...
@@ -566,21 +662,177 @@ mod tests {
// Calculate and print elapsed time
let
elapsed
=
start_time
.elapsed
();
println!
(
"Test completed in: {:?} for {} case"
,
"Test completed in: {:?} for {} case
with prefix_caching={}
"
,
elapsed
,
if
use_shared_tokens
{
"caching"
}
else
{
"random"
}
},
enable_prefix_caching
);
// Assert that we received the expected number of tokens
assert
!
(
received_tokens
>
expected_tokens
,
"Received {} tokens but expected more than {}"
,
received_tokens
,
expected_tokens
received_tokens
==
expected_tokens
,
"Received {received_tokens} tokens but expected exactly {expected_tokens}"
);
}
/// Submits the same 65-token prompt ten times, 100 ms apart, drains the output
/// channel until it has been silent for 500 ms, and then checks that no request
/// is still waiting and that the reported prefix-cache hit rate exceeds 0.8.
#[tokio::test]
async fn test_cache_hit_rate_with_identical_requests() {
    let block_size: usize = 64;
    let max_output_tokens: usize = 10;
    let speedup_ratio = 10.0;
    let num_requests = 10;

    // Channel on which the scheduler reports every generated token
    let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

    // Engine configuration
    let engine_args = MockEngineArgs::builder()
        .num_gpu_blocks(100) // Large enough to not be a constraint
        .block_size(block_size)
        .speedup_ratio(speedup_ratio)
        .build()
        .unwrap();

    let scheduler = Scheduler::new(engine_args, None, Some(output_tx), None, None);

    // Every request carries exactly the same prompt: tokens 0, 1, ..., 64
    let prompt: Vec<u32> = (0u32..65).collect();

    for _ in 0..num_requests {
        scheduler
            .receive(DirectRequest {
                tokens: prompt.clone(),
                max_output_tokens,
                uuid: None,
                dp_rank: None,
            })
            .await;

        // Space the submissions 0.1 second apart
        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    // Number of output signals seen so far
    let mut received_tokens = 0;

    // Idle timer: re-armed to 0.5 seconds each time a token arrives
    let idle_timer = tokio::time::sleep(Duration::from_millis(500));
    tokio::pin!(idle_timer);

    // Periodic ticker used only to print metrics while the test runs
    let mut metrics_ticker = interval(Duration::from_millis(500));

    loop {
        tokio::select! {
            biased;

            // Debug output: dump the forward pass metrics on every tick
            _ = metrics_ticker.tick() => {
                let _metrics = scheduler.get_forward_pass_metrics().await;
                println!("Forward Pass Metrics: {_metrics:#?}");
            }

            Some(_signal) = output_rx.recv() => {
                received_tokens += 1;
                // A token arrived, so restart the idle timer
                idle_timer.set(tokio::time::sleep(Duration::from_millis(500)));
            }

            _ = &mut idle_timer => {
                // Nothing received for 0.5 seconds: stop draining
                break;
            }
        }
    }

    // Final metrics snapshot used for the assertions
    let metrics = scheduler.get_forward_pass_metrics().await;

    assert_eq!(
        metrics.num_requests_waiting, 0,
        "Expected no waiting requests, got {}",
        metrics.num_requests_waiting
    );

    assert!(
        metrics.gpu_prefix_cache_hit_rate > 0.8,
        "Expected cache hit rate > 0.8, got {}",
        metrics.gpu_prefix_cache_hit_rate
    );

    println!(
        "Test passed! Cache hit rate: {:.3}",
        metrics.gpu_prefix_cache_hit_rate
    );
    println!("Received {received_tokens} tokens");
}
#[tokio::test]
async
fn
test_receiver_drop_cleans_up_resources
()
{
let
block_size
:
usize
=
64
;
let
input_tokens
=
256
;
let
max_output_tokens
=
200
;
// More than we'll receive
// Create channel for token output
let
(
output_tx
,
mut
output_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
// Create scheduler args
let
args
=
MockEngineArgs
::
builder
()
.num_gpu_blocks
(
10
)
// Enough for 256 tokens (4 blocks)
.block_size
(
block_size
)
.speedup_ratio
(
100.0
)
// Fast simulation
.build
()
.unwrap
();
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
// Create request with 256 tokens
let
tokens
:
Vec
<
u32
>
=
(
0
..
input_tokens
)
.map
(|
i
|
i
as
u32
)
.collect
();
let
request
=
DirectRequest
{
tokens
,
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
};
scheduler
.receive
(
request
)
.await
;
// Receive exactly 129 tokens
let
mut
received_count
=
0
;
while
received_count
<
129
{
if
let
Some
(
_
signal
)
=
output_rx
.recv
()
.await
{
received_count
+=
1
;
}
else
{
panic!
(
"Channel closed before receiving 129 tokens"
);
}
}
// Drop the receiver immediately
drop
(
output_rx
);
// Wait for 1 second to allow cleanup
tokio
::
time
::
sleep
(
Duration
::
from_secs
(
1
))
.await
;
// Check forward pass metrics
let
metrics
=
scheduler
.get_forward_pass_metrics
()
.await
;
assert_eq!
(
metrics
.gpu_cache_usage_perc
,
0.0
,
"Expected GPU cache usage to be 0%, got {}%"
,
metrics
.gpu_cache_usage_perc
*
100.0
);
assert_eq!
(
metrics
.kv_active_blocks
,
0
,
"Expected 0 active blocks, got {}"
,
metrics
.kv_active_blocks
);
}
}
lib/llm/src/mocker/sequence.rs
View file @
f0652d89
...
...
@@ -23,16 +23,23 @@ use uuid;
fn
create_unique_blocks_from_sequence
(
tokens
:
&
TokenBlockSequence
,
uuid
:
Option
<
uuid
::
Uuid
>
,
block_size
:
u32
,
block_size
:
usize
,
enable_prefix_caching
:
bool
,
)
->
Vec
<
UniqueBlock
>
{
let
mut
unique_blocks
:
Vec
<
UniqueBlock
>
=
tokens
.blocks
()
.iter
()
.map
(|
block
|
UniqueBlock
::
FullBlock
(
block
.sequence_hash
()))
.map
(|
block
|
{
if
enable_prefix_caching
{
UniqueBlock
::
FullBlock
(
block
.sequence_hash
())
}
else
{
UniqueBlock
::
FullBlock
(
random
::
<
u64
>
())
}
})
.collect
();
// Only push the partial block if tokens count isn't a multiple of block_size
if
tokens
.total_tokens
()
%
(
block_size
as
usize
)
!=
0
{
if
tokens
.total_tokens
()
%
block_size
!=
0
{
unique_blocks
.push
(
match
uuid
{
Some
(
uuid
)
=>
UniqueBlock
::
PartialBlock
(
uuid
),
None
=>
UniqueBlock
::
default
(),
...
...
@@ -50,10 +57,7 @@ pub struct ActiveSequence {
tokens
:
TokenBlockSequence
,
#[getter(copy)]
block_size
:
u32
,
#[getter(copy)]
chunk_size
:
usize
,
// TODO: not actually used
block_size
:
usize
,
#[getter(copy)]
max_output_tokens
:
usize
,
...
...
@@ -61,10 +65,16 @@ pub struct ActiveSequence {
#[getter(copy)]
generated_tokens
:
usize
,
#[getter(copy)]
already_generated_tokens
:
usize
,
#[getter(copy)]
num_input_tokens
:
usize
,
creation_signal
:
Option
<
MoveBlock
>
,
#[getter(copy)]
enable_prefix_caching
:
bool
,
}
impl
ActiveSequence
{
...
...
@@ -72,32 +82,33 @@ impl ActiveSequence {
pub
fn
new
(
tokens
:
Vec
<
u32
>
,
max_output_tokens
:
usize
,
block_size
:
Option
<
u
32
>
,
chunk_size
:
Option
<
usize
>
,
block_size
:
Option
<
u
size
>
,
enable_prefix_caching
:
bool
,
)
->
Self
{
let
block_size
=
block_size
.unwrap_or
(
64
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
let
chunk_size
=
chunk_size
.unwrap_or
(
256
);
let
num_input_tokens
=
tokens
.len
();
let
tokens
=
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
,
None
);
let
unique_blocks
=
create_unique_blocks_from_sequence
(
&
tokens
,
None
,
block_size
);
let
creation_signal
=
Some
(
MoveBlock
::
Use
(
unique_blocks
.clone
(),
None
));
let
tokens
=
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
let
unique_blocks
=
create_unique_blocks_from_sequence
(
&
tokens
,
None
,
block_size
,
enable_prefix_caching
);
let
creation_signal
=
Some
(
MoveBlock
::
Use
(
unique_blocks
.clone
()));
Self
{
unique_blocks
,
tokens
,
block_size
,
chunk_size
,
max_output_tokens
,
generated_tokens
:
0
,
already_generated_tokens
:
0
,
num_input_tokens
,
creation_signal
,
enable_prefix_caching
,
}
}
pub
fn
extra_tokens
(
&
self
)
->
u32
{
(
self
.len
()
%
self
.block_size
as
usize
)
as
u32
(
self
.len
()
%
self
.block_size
)
as
u32
}
pub
fn
len
(
&
self
)
->
usize
{
...
...
@@ -112,20 +123,31 @@ impl ActiveSequence {
pub
fn
new_with_signal
(
tokens
:
Vec
<
u32
>
,
max_output_tokens
:
usize
,
block_size
:
Option
<
u
32
>
,
chunk_size
:
Option
<
usize
>
,
block_size
:
Option
<
u
size
>
,
enable_prefix_caching
:
bool
,
)
->
(
Self
,
Option
<
MoveBlock
>
)
{
let
mut
sequence
=
Self
::
new
(
tokens
,
max_output_tokens
,
block_size
,
chunk_size
);
let
mut
sequence
=
Self
::
new
(
tokens
,
max_output_tokens
,
block_size
,
enable_prefix_caching
);
let
signal
=
sequence
.creation_signal
.take
();
(
sequence
,
signal
)
}
/// Look up the hash of the block that sits directly before the last entry of
/// `unique_blocks`.
///
/// Returns `None` when fewer than two blocks are tracked.
///
/// # Panics
///
/// Panics if the second-to-last entry is not a `UniqueBlock::FullBlock`.
fn get_parent_hash(&self) -> Option<u64> {
    // `checked_sub` is `None` for lengths 0 and 1, i.e. when there is no parent.
    let parent_index = self.unique_blocks.len().checked_sub(2)?;

    match &self.unique_blocks[parent_index] {
        UniqueBlock::FullBlock(parent_hash) => Some(*parent_hash),
        _ => panic!("Cannot have a partial block as parent"),
    }
}
/// Push a token to the sequence
pub
fn
push
(
&
mut
self
,
token
:
u32
)
->
Option
<
Vec
<
MoveBlock
>>
{
self
.tokens
.append
(
token
)
.expect
(
"Token push failed."
);
self
.generated_tokens
+=
1
;
if
self
.len
()
%
(
self
.block_size
as
usize
)
!=
1
{
if
self
.len
()
%
self
.block_size
!=
1
{
return
None
;
}
...
...
@@ -135,16 +157,24 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists
if
let
Some
(
UniqueBlock
::
PartialBlock
(
uuid
))
=
self
.unique_blocks
.last
()
.cloned
()
{
let
last_block_hash
=
self
.tokens
.last_complete_block
()
.unwrap
()
.sequence_hash
();
let
last_block_hash
=
if
self
.enable_prefix_caching
{
self
.tokens
.last_complete_block
()
.unwrap
()
.sequence_hash
()
}
else
{
random
::
<
u64
>
()
};
self
.unique_blocks
.pop
();
self
.unique_blocks
.push
(
UniqueBlock
::
FullBlock
(
last_block_hash
));
signals
.push
(
MoveBlock
::
Promote
(
uuid
,
last_block_hash
));
signals
.push
(
MoveBlock
::
Promote
(
uuid
,
last_block_hash
,
self
.get_parent_hash
(),
));
}
let
new_partial_block
=
UniqueBlock
::
default
();
self
.unique_blocks
.push
(
new_partial_block
.clone
());
signals
.push
(
MoveBlock
::
Use
(
vec!
[
new_partial_block
]
,
None
));
signals
.push
(
MoveBlock
::
Use
(
vec!
[
new_partial_block
]));
Some
(
signals
)
}
...
...
@@ -204,15 +234,19 @@ impl ActiveSequence {
}
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// maintaining the uuid of the last partial block
pub
fn
reset_with_signal
(
&
mut
self
)
->
Vec
<
MoveBlock
>
{
let
free_signal
=
self
.free_signal
();
self
.tokens
.truncate
(
self
.num_input_tokens
)
.unwrap
();
self
.unique_blocks
=
create_unique_blocks_from_sequence
(
&
self
.tokens
,
None
,
self
.block_size
);
self
.unique_blocks
=
create_unique_blocks_from_sequence
(
&
self
.tokens
,
None
,
self
.block_size
,
self
.enable_prefix_caching
,
);
self
.already_generated_tokens
=
self
.generated_tokens
.max
(
self
.already_generated_tokens
);
self
.generated_tokens
=
0
;
self
.creation_signal
=
Some
(
MoveBlock
::
Use
(
self
.unique_blocks
.clone
()
,
None
));
self
.creation_signal
=
Some
(
MoveBlock
::
Use
(
self
.unique_blocks
.clone
()));
free_signal
}
...
...
@@ -223,7 +257,7 @@ impl ActiveSequence {
self
.generated_tokens
=
self
.generated_tokens
.saturating_sub
(
1
);
// Reverts to the last full block
if
self
.tokens
.total_tokens
()
%
(
self
.block_size
as
usize
)
==
0
{
if
self
.tokens
.total_tokens
()
%
self
.block_size
==
0
{
self
.unique_blocks
.pop
();
}
}
...
...
@@ -238,14 +272,14 @@ mod tests {
// Create a sequence with block size 16 initialized with tokens [0..15]
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
15
)
.collect
();
let
(
mut
seq1
,
signal1
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
100
,
Some
(
16
),
Some
(
256
)
);
ActiveSequence
::
new_with_signal
(
initial_tokens
,
100
,
Some
(
16
),
true
);
assert_eq!
(
seq1
.num_input_tokens
(),
15
);
assert_eq!
(
seq1
.len
(),
15
);
// Check that we got a Use signal
assert
!
(
signal1
.is_some
());
match
&
signal1
{
Some
(
MoveBlock
::
Use
(
blocks
,
_
))
=>
{
Some
(
MoveBlock
::
Use
(
blocks
))
=>
{
assert_eq!
(
blocks
.len
(),
1
);
}
_
=>
panic!
(
"Expected Use signal"
),
...
...
@@ -264,33 +298,31 @@ mod tests {
let
signal_16
=
signal_16
.unwrap
();
assert_eq!
(
signal_16
.len
(),
2
);
// First signal should be Promote for the previous block
match
&
signal_16
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
assert_eq!
(
*
parent_hash
,
None
);
}
_
=>
panic!
(
"Expected Promote signal as second signal"
),
}
// Second signal should be Use for new partial block
match
&
signal_16
[
1
]
{
MoveBlock
::
Use
(
blocks
,
_
)
=>
{
MoveBlock
::
Use
(
blocks
)
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
_
=>
panic!
(
"Expected Use signal as first signal"
),
}
// First signal should be Promote for the previous block
match
&
signal_16
[
0
]
{
MoveBlock
::
Promote
(
uuid
,
_
)
=>
{
// The uuid is generated dynamically, so we just check it exists
let
_
=
uuid
;
}
_
=>
panic!
(
"Expected Promote signal as second signal"
),
}
// Verify state after pushing tokens
assert_eq!
(
seq1
.unique_blocks
()
.len
(),
2
);
// One full block and one partial block
assert_eq!
(
seq1
.len
(),
17
);
assert_eq!
(
seq1
.len
()
%
(
seq1
.block_size
()
as
usize
)
,
1
);
assert_eq!
(
seq1
.len
()
%
seq1
.block_size
(),
1
);
// Create another sequence with block size 16 initialized with tokens [0..16)
let
extended_tokens
:
Vec
<
u32
>
=
(
0
..
16
)
.collect
();
let
(
mut
seq2
,
_
)
=
ActiveSequence
::
new_with_signal
(
extended_tokens
,
100
,
Some
(
16
),
Some
(
256
));
let
(
mut
seq2
,
_
)
=
ActiveSequence
::
new_with_signal
(
extended_tokens
,
100
,
Some
(
16
),
true
);
seq2
.push
(
16
);
seq2
.pop
();
seq2
.push
(
16
);
...
...
@@ -335,12 +367,12 @@ mod tests {
"seq2 should have exactly 3 blocks"
);
assert_eq!
(
seq1
.len
()
%
(
seq1
.block_size
()
as
usize
)
,
seq1
.len
()
%
seq1
.block_size
(),
1
,
"seq1 should have 1 partial token"
);
assert_eq!
(
seq2
.len
()
%
(
seq2
.block_size
()
as
usize
)
,
seq2
.len
()
%
seq2
.block_size
(),
1
,
"seq2 should have 1 partial token"
);
...
...
@@ -352,9 +384,38 @@ mod tests {
"First two blocks should be identical"
);
// Push tokens 33..=47 to seq1
for
token
in
33
..
48
{
seq1
.push
(
token
);
}
// Push token 48 and get the signal - this completes the block and triggers signals
let
signal
=
seq1
.push
(
48
);
let
signal
=
signal
.unwrap
();
// Check that signal[0] is promote
match
&
signal
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
// Check that the parent_hash matches unique_blocks[1], which should be a full block
if
let
UniqueBlock
::
FullBlock
(
expected_hash
)
=
seq1
.unique_blocks
()[
1
]
{
assert_eq!
(
*
parent_hash
,
Some
(
expected_hash
),
"Parent hash should match unique_blocks[1]"
);
}
else
{
panic!
(
"unique_blocks[1] should be a full block"
);
}
}
_
=>
panic!
(
"Expected Promote signal as first signal"
),
}
// Reset seq1 and check that it equals the original clone
let
free_signals
=
seq1
.reset_with_signal
();
// 49 - 15 generated tokens
assert_eq!
(
seq1
.already_generated_tokens
,
34
);
// Verify the reset signals include proper cleanup events
assert
!
(
!
free_signals
.is_empty
());
}
...
...
@@ -363,13 +424,12 @@ mod tests {
fn
test_active_sequence_generate_signals
()
{
// Create a sequence with block size 16, max_output_tokens 5, initialized with tokens [0..14)
let
initial_tokens
:
Vec
<
u32
>
=
(
0
..
14
)
.collect
();
let
(
mut
seq
,
signal
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
5
,
Some
(
16
),
Some
(
256
));
let
(
mut
seq
,
signal
)
=
ActiveSequence
::
new_with_signal
(
initial_tokens
,
5
,
Some
(
16
),
true
);
// Initial signal - should have received a Use signal for the partial block
assert
!
(
signal
.is_some
());
match
signal
{
Some
(
MoveBlock
::
Use
(
blocks
,
_
))
=>
{
Some
(
MoveBlock
::
Use
(
blocks
))
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
...
...
@@ -385,25 +445,23 @@ mod tests {
let
signals_second
=
seq
.generate
();
assert_eq!
(
signals_second
.len
(),
2
);
// First signal should be Use for new partial block
// First signal should be Promote
match
&
signals_second
[
0
]
{
MoveBlock
::
Promote
(
_
,
_
,
parent_hash
)
=>
{
assert_eq!
(
*
parent_hash
,
None
);
}
_
=>
panic!
(
"Expected Promote signal as first signal after second token"
),
}
// Second signal should be Use for new partial block
match
&
signals_second
[
1
]
{
MoveBlock
::
Use
(
blocks
,
_
)
=>
{
MoveBlock
::
Use
(
blocks
)
=>
{
assert_eq!
(
blocks
.len
(),
1
);
assert
!
(
matches!
(
blocks
[
0
],
UniqueBlock
::
PartialBlock
(
_
)));
}
_
=>
panic!
(
"Expected Use signal as second signal after second token"
),
}
// Second signal should be Promote
match
&
signals_second
[
0
]
{
MoveBlock
::
Promote
(
uuid
,
hash
)
=>
{
// The uuid and hash values are generated dynamically, so we just check the event type
let
_
=
uuid
;
let
_
=
hash
;
}
_
=>
panic!
(
"Expected Promote signal as first signal after second token"
),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
let
signals_third
=
seq
.generate
();
assert_eq!
(
signals_third
.len
(),
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment