Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
c4106e6a
Commit
c4106e6a
authored
Apr 02, 2025
by
Ryan Olson
Committed by
GitHub
Apr 02, 2025
Browse files
feat: kv aware router executable (#399)
parent
183941fa
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
439 additions
and
228 deletions
+439
-228
Cargo.lock
Cargo.lock
+14
-0
components/metrics/src/lib.rs
components/metrics/src/lib.rs
+2
-2
components/router/Cargo.toml
components/router/Cargo.toml
+37
-0
components/router/src/main.rs
components/router/src/main.rs
+95
-0
dynamo.code-workspace
dynamo.code-workspace
+1
-3
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+7
-11
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+76
-55
lib/llm/src/kv_router/metrics_aggregator.rs
lib/llm/src/kv_router/metrics_aggregator.rs
+13
-39
lib/llm/src/kv_router/protocols.rs
lib/llm/src/kv_router/protocols.rs
+24
-0
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+142
-117
lib/llm/src/kv_router/scoring.rs
lib/llm/src/kv_router/scoring.rs
+4
-1
lib/llm/src/tokens.rs
lib/llm/src/tokens.rs
+7
-0
lib/runtime/src/lib.rs
lib/runtime/src/lib.rs
+1
-0
lib/runtime/src/prelude.rs
lib/runtime/src/prelude.rs
+16
-0
No files found.
Cargo.lock
View file @
c4106e6a
...
@@ -5166,6 +5166,20 @@ dependencies = [
...
@@ -5166,6 +5166,20 @@ dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.52.0",
]
]
[[package]]
name = "router"
version = "0.1.0"
dependencies = [
"clap",
"dynamo-llm",
"dynamo-runtime",
"rand 0.9.0",
"serde",
"serde_json",
"tokio",
"tracing",
]
[[package]]
[[package]]
name = "rstest"
name = "rstest"
version = "0.18.2"
version = "0.18.2"
...
...
components/metrics/src/lib.rs
View file @
c4106e6a
...
@@ -465,8 +465,8 @@ impl PrometheusMetrics {
...
@@ -465,8 +465,8 @@ impl PrometheusMetrics {
/// Update metrics with current values
/// Update metrics with current values
fn
update
(
&
self
,
config
:
&
LLMWorkerLoadCapacityConfig
,
processed
:
&
ProcessedEndpoints
)
{
fn
update
(
&
self
,
config
:
&
LLMWorkerLoadCapacityConfig
,
processed
:
&
ProcessedEndpoints
)
{
// Update per-worker metrics
// Update per-worker metrics
for
endpoint
in
processed
.endpoints
.iter
()
{
for
(
worker_id
,
endpoint
)
in
processed
.endpoints
.iter
()
{
let
worker_id
=
endpoint
.
worker_id
()
.to_string
();
let
worker_id
=
worker_id
.to_string
();
let
metrics
=
endpoint
.data
.clone
();
let
metrics
=
endpoint
.data
.clone
();
self
.set_worker_gauge
(
self
.set_worker_gauge
(
...
...
components/router/Cargo.toml
0 → 100644
View file @
c4106e6a
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name
=
"router"
version.workspace
=
true
edition.workspace
=
true
description.workspace
=
true
authors.workspace
=
true
license.workspace
=
true
homepage.workspace
=
true
repository.workspace
=
true
keywords.workspace
=
true
[dependencies]
dynamo-runtime
=
{
workspace
=
true
}
dynamo-llm
=
{
workspace
=
true
}
rand
=
{
workspace
=
true
}
serde
=
{
workspace
=
true
}
serde_json
=
{
workspace
=
true
}
tokio
=
{
workspace
=
true
}
tracing
=
{
workspace
=
true
}
clap
=
{
version
=
"4.5"
,
features
=
["derive"]
}
components/router/src/main.rs
0 → 100644
View file @
c4106e6a
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO(#400):
// Instead of passing in a block_size, we should get this data from the backend component's config.
// What changes need to be made:
// 1. Take as an argument the name of the backend component.
// 2. Update the backend component to produce a config in a standard location.
// 3. Update the KvRouter to read the config from the backend component.
use
clap
::
Parser
;
use
dynamo_llm
::
kv_router
::{
protocols
::
WorkerSelectionResult
,
scheduler
::{
DefaultWorkerSelector
,
KvSchedulerError
,
SchedulingRequest
},
scoring
::
ProcessedEndpoints
,
KvRouter
,
WorkerSelector
,
};
use
dynamo_runtime
::{
logging
,
pipeline
::
network
::
Ingress
,
DistributedRuntime
,
Result
,
Runtime
,
Worker
,
};
// Command-line arguments of the router executable, parsed through clap's derive API.
// `app` reads `namespace` and `component` to look up the distributed component and
// hands `block_size` to `KvRouter::new`.
// NOTE(review): the `///` comments on the fields are presumably surfaced by clap as
// `--help` text, so treat them as user-facing strings rather than plain comments.
#[derive(Parser)]
#[command(author,
version,
about,
long_about
=
None)]
struct
Args
{
/// Namespace for the distributed component
#[arg(long)]
namespace
:
String
,
/// Component name for the service
// Optional on the command line; falls back to "kv_aware_router".
#[arg(long,
default_value
=
"kv_aware_router"
)]
component
:
String
,
/// Block size for the router
// Required flag with no default; see the TODO(#400) note at the top of this file
// about sourcing this value from the backend component's config instead.
#[arg(long)]
block_size
:
usize
,
}
/// Process entry point.
///
/// Initializes logging, builds a [`Worker`] from settings and hands the async
/// [`app`] function to it for execution. Any error from building the worker or
/// from `app` itself is returned to the caller of `main`.
fn main() -> Result<()> {
    logging::init();

    // Build the worker and immediately drive the application on it.
    Worker::from_settings()?.execute(app)
}
async
fn
app
(
runtime
:
Runtime
)
->
Result
<
()
>
{
let
args
=
Args
::
parse
();
let
runtime
=
DistributedRuntime
::
from_settings
(
runtime
)
.await
?
;
let
component
=
runtime
.namespace
(
&
args
.namespace
)
?
.component
(
&
args
.component
)
?
;
let
selector
=
Box
::
new
(
CustomWorkerSelector
::
default
());
let
router
=
KvRouter
::
new
(
component
.clone
(),
args
.block_size
,
Some
(
selector
))
.await
?
;
let
router
=
Ingress
::
for_engine
(
router
)
?
;
component
.service_builder
()
.create
()
.await
?
.endpoint
(
"generate"
)
.endpoint_builder
()
.handler
(
router
)
.start
()
.await
}
/// Example worker selector for this executable.
///
/// It is a newtype around [`DefaultWorkerSelector`] and currently forwards every
/// selection to it unchanged; it exists as the place to plug in custom logic.
#[derive(Default)]
pub struct CustomWorkerSelector(DefaultWorkerSelector);

impl WorkerSelector for CustomWorkerSelector {
    /// Chooses a worker for `request` among `workers`.
    ///
    /// All three arguments are passed straight through to the wrapped
    /// [`DefaultWorkerSelector`], and its result (or error) is returned as-is.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        // customize logic here
        // F12 into [DefaultWorkerSelector] to see the original logic
        let Self(default_selector) = self;
        default_selector.select_worker(workers, request, block_size)
    }
}
dynamo.code-workspace
View file @
c4106e6a
...
@@ -6,10 +6,8 @@
...
@@ -6,10 +6,8 @@
],
],
"settings": {
"settings": {
"rust-analyzer.linkedProjects": [
"rust-analyzer.linkedProjects": [
"
components/metrics/
Cargo.toml",
"Cargo.toml",
"launch/dynamo-run/Cargo.toml",
"launch/dynamo-run/Cargo.toml",
"lib/llm/Cargo.toml",
"lib/runtime/Cargo.toml",
"lib/bindings/python/Cargo.toml"
"lib/bindings/python/Cargo.toml"
],
],
"rust-analyzer.procMacro.enable": true,
"rust-analyzer.procMacro.enable": true,
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
c4106e6a
...
@@ -30,17 +30,13 @@ pub(crate) struct KvRouter {
...
@@ -30,17 +30,13 @@ pub(crate) struct KvRouter {
#[pymethods]
#[pymethods]
impl
KvRouter
{
impl
KvRouter
{
#[new]
#[new]
// [FIXME] 'drt' can be obtained from 'component'
fn
new
(
component
:
Component
,
kv_block_size
:
usize
)
->
PyResult
<
Self
>
{
fn
new
(
drt
:
DistributedRuntime
,
component
:
Component
,
kv_block_size
:
usize
)
->
PyResult
<
Self
>
{
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
runtime
.block_on
(
async
{
runtime
.block_on
(
async
{
let
inner
=
llm_rs
::
kv_router
::
KvRouter
::
from_runtime
(
let
inner
=
drt
.inner
.clone
(),
llm_rs
::
kv_router
::
KvRouter
::
new
(
component
.inner
.clone
(),
kv_block_size
,
None
)
component
.inner
.clone
(),
.await
kv_block_size
,
.map_err
(
to_pyerr
)
?
;
)
.await
.map_err
(
to_pyerr
)
?
;
Ok
(
Self
{
inner
})
Ok
(
Self
{
inner
})
})
})
}
}
...
@@ -376,8 +372,8 @@ impl KvMetricsAggregator {
...
@@ -376,8 +372,8 @@ impl KvMetricsAggregator {
let
endpoint_kv_metrics
=
endpoints
let
endpoint_kv_metrics
=
endpoints
.endpoints
.endpoints
.iter
()
.iter
()
.map
(|
x
|
EndpointKvMetrics
{
.map
(|
(
worker_id
,
x
)
|
EndpointKvMetrics
{
worker_id
:
x
.
worker_id
()
,
worker_id
:
*
worker_id
,
request_active_slots
:
x
.data.request_active_slots
,
request_active_slots
:
x
.data.request_active_slots
,
request_total_slots
:
x
.data.request_total_slots
,
request_total_slots
:
x
.data.request_total_slots
,
kv_active_blocks
:
x
.data.kv_active_blocks
,
kv_active_blocks
:
x
.data.kv_active_blocks
,
...
...
lib/llm/src/kv_router.rs
View file @
c4106e6a
...
@@ -14,11 +14,17 @@
...
@@ -14,11 +14,17 @@
// limitations under the License.
// limitations under the License.
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
dynamo_runtime
::{
component
::
Component
,
component
::
Namespace
,
DistributedRuntime
};
use
dynamo_runtime
::{
use
futures
::
stream
::
StreamExt
;
component
::
Component
,
pipeline
::{
async_trait
,
AsyncEngine
,
AsyncEngineContextProvider
,
Error
,
ManyOut
,
ResponseStream
,
SingleIn
,
},
prelude
::
*
,
protocols
::
annotated
::
Annotated
,
};
use
futures
::
stream
::{
self
,
StreamExt
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tracing
;
pub
mod
indexer
;
pub
mod
indexer
;
pub
mod
metrics_aggregator
;
pub
mod
metrics_aggregator
;
...
@@ -27,14 +33,18 @@ pub mod publisher;
...
@@ -27,14 +33,18 @@ pub mod publisher;
pub
mod
scheduler
;
pub
mod
scheduler
;
pub
mod
scoring
;
pub
mod
scoring
;
use
crate
::
kv_router
::{
use
crate
::{
indexer
::{
KvIndexer
,
KvIndexerInterface
,
RouterEvent
},
kv_router
::{
metrics_aggregator
::
collect_endpoints_task
,
indexer
::{
KvIndexer
,
KvIndexerInterface
,
RouterEvent
},
scheduler
::
KvScheduler
,
metrics_aggregator
::
KvMetricsAggregator
,
scoring
::
ProcessedEndpoints
,
protocols
::{
LocalBlockHash
,
RouterRequest
,
RouterResponse
,
WorkerSelectionResult
},
scheduler
::{
KvScheduler
,
KvSchedulerError
,
SchedulingRequest
},
scoring
::
ProcessedEndpoints
,
},
tokens
::
Tokens
,
};
};
use
dynamo_runtime
::
traits
::
events
::
{
EventPublisher
,
EventSubscriber
}
;
use
dynamo_runtime
::
traits
::
events
::
EventSubscriber
;
// [gluo TODO] shouldn't need to be public
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
// this should be discovered from the component
...
@@ -42,49 +52,40 @@ pub const KV_EVENT_SUBJECT: &str = "kv_events";
...
@@ -42,49 +52,40 @@ pub const KV_EVENT_SUBJECT: &str = "kv_events";
pub
const
KV_HIT_RATE_SUBJECT
:
&
str
=
"kv-hit-rate"
;
pub
const
KV_HIT_RATE_SUBJECT
:
&
str
=
"kv-hit-rate"
;
pub
const
KV_METRICS_ENDPOINT
:
&
str
=
"load_metrics"
;
pub
const
KV_METRICS_ENDPOINT
:
&
str
=
"load_metrics"
;
pub
struct
KvRouter
{
/// A trait that users can implement to define custom selection logic
// properties of request plane
pub
trait
WorkerSelector
{
// maybe rolled up into the generic object or not
fn
select_worker
(
service_name
:
String
,
&
self
,
workers
:
&
ProcessedEndpoints
,
cancellation_token
:
CancellationToken
,
request
:
&
SchedulingRequest
,
block_size
:
usize
,
#[allow(dead_code)]
)
->
Result
<
WorkerSelectionResult
,
KvSchedulerError
>
;
scheduler
:
KvScheduler
,
}
pub
struct
KvRouter
{
indexer
:
KvIndexer
,
indexer
:
KvIndexer
,
scheduler
:
KvScheduler
,
block_size
:
usize
,
}
}
impl
KvRouter
{
impl
KvRouter
{
pub
async
fn
from_runtime
(
runtime
:
DistributedRuntime
,
component
:
Component
,
kv_block_size
:
usize
,
)
->
Result
<
Arc
<
Self
>>
{
let
namespace
=
runtime
.namespace
(
component
.namespace
()
.name
())
?
;
tracing
::
info!
(
"Component Namespace {}"
,
component
.namespace
());
tracing
::
info!
(
"Component Service Name {}"
,
component
.service_name
());
tracing
::
info!
(
"KV Subject {}.{}"
,
component
.subject
(),
KV_EVENT_SUBJECT
);
Self
::
new
(
component
,
namespace
,
kv_block_size
)
.await
}
pub
async
fn
new
(
pub
async
fn
new
(
component
:
Component
,
component
:
Component
,
namespace
:
Namespac
e
,
block_size
:
usiz
e
,
kv_block_size
:
usize
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
)
->
Result
<
Arc
<
Self
>>
{
)
->
Result
<
Arc
<
Self
>>
{
let
cancellation_token
=
CancellationToken
::
new
();
let
cancellation_token
=
component
.drt
()
.primary_lease
()
.primary_token
();
let
(
ep_tx
,
ep_rx
)
=
tokio
::
sync
::
mpsc
::
channel
(
128
);
let
metrics_aggregator
=
tokio
::
spawn
(
collect_endpoints_task
(
KvMetricsAggregator
::
new
(
component
.clone
(),
cancellation_token
.clone
())
.await
;
component
.clone
(),
let
indexer
=
KvIndexer
::
new
(
cancellation_token
.clone
(),
block_size
);
ep_tx
,
let
scheduler
=
KvScheduler
::
start
(
cancellation_token
.clone
(),
component
.namespace
()
.clone
(),
));
block_size
,
metrics_aggregator
.endpoints_watcher
(),
let
indexer
=
KvIndexer
::
new
(
cancellation_token
.clone
(),
kv_block_size
);
selector
,
let
scheduler
=
KvScheduler
::
start
(
ep_rx
,
namespace
,
kv_block_size
)
.await
?
;
)
.await
?
;
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
// error checking below will be different.
...
@@ -112,21 +113,12 @@ impl KvRouter {
...
@@ -112,21 +113,12 @@ impl KvRouter {
});
});
Ok
(
Arc
::
new
(
Self
{
Ok
(
Arc
::
new
(
Self
{
service_name
:
component
.service_name
(),
cancellation_token
,
scheduler
,
scheduler
,
indexer
,
indexer
,
block_size
,
}))
}))
}
}
pub
fn
cancellation_token
(
&
self
)
->
CancellationToken
{
self
.cancellation_token
.clone
()
}
pub
fn
service_name
(
&
self
)
->
&
str
{
&
self
.service_name
}
// [TODO] indexer needs to take 'lora_id' as parameter
// [TODO] indexer needs to take 'lora_id' as parameter
pub
async
fn
schedule
(
&
self
,
token_ids
:
&
Vec
<
u32
>
,
_
lora_id
:
u64
)
->
Result
<
i64
>
{
pub
async
fn
schedule
(
&
self
,
token_ids
:
&
Vec
<
u32
>
,
_
lora_id
:
u64
)
->
Result
<
i64
>
{
// Extracting part of the code in KvRouter::generate() for only
// Extracting part of the code in KvRouter::generate() for only
...
@@ -141,3 +133,32 @@ impl KvRouter {
...
@@ -141,3 +133,32 @@ impl KvRouter {
Ok
(
worker_id
)
Ok
(
worker_id
)
}
}
}
}
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
RouterRequest
>
,
ManyOut
<
Annotated
<
RouterResponse
>>
,
Error
>
for
KvRouter
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
RouterRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
RouterResponse
>>>
{
let
(
request
,
ctx
)
=
request
.into_parts
();
let
isl_tokens
=
request
.tokens
.len
();
let
block_size
=
self
.block_size
;
// Compute the block hashes in a blocking task
let
local_block_hashes
:
Vec
<
LocalBlockHash
>
=
tokio
::
task
::
spawn_blocking
(
move
||
{
Tokens
::
compute_block_hash
(
&
request
.tokens
,
block_size
)
.into_iter
()
.map
(
LocalBlockHash
)
.collect
()
})
.await
?
;
let
overlap_scores
=
self
.indexer
.find_matches
(
local_block_hashes
)
.await
?
;
let
worker_id
=
self
.scheduler
.schedule
(
overlap_scores
,
isl_tokens
)
.await
?
;
let
response
=
RouterResponse
{
worker_id
};
let
response
=
Annotated
::
from_data
(
response
);
let
stream
=
stream
::
iter
(
vec!
[
response
]);
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
stream
),
ctx
.context
()))
}
}
lib/llm/src/kv_router/metrics_aggregator.rs
View file @
c4106e6a
...
@@ -13,8 +13,6 @@
...
@@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
std
::
sync
::{
Arc
,
Mutex
};
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
use
crate
::
kv_router
::
KV_METRICS_ENDPOINT
;
use
crate
::
kv_router
::
KV_METRICS_ENDPOINT
;
...
@@ -22,61 +20,36 @@ use crate::kv_router::scheduler::Endpoint;
...
@@ -22,61 +20,36 @@ use crate::kv_router::scheduler::Endpoint;
use
crate
::
kv_router
::
ProcessedEndpoints
;
use
crate
::
kv_router
::
ProcessedEndpoints
;
use
dynamo_runtime
::
component
::
Component
;
use
dynamo_runtime
::
component
::
Component
;
use
dynamo_runtime
::{
service
::
EndpointInfo
,
utils
::
Duration
,
Result
};
use
dynamo_runtime
::{
service
::
EndpointInfo
,
utils
::
Duration
,
Result
};
use
tokio
::
sync
::
watch
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
pub
struct
KvMetricsAggregator
{
pub
struct
KvMetricsAggregator
{
pub
service_name
:
String
,
pub
service_name
:
String
,
pub
endpoints
:
Arc
<
Mutex
<
ProcessedEndpoints
>
>
,
pub
endpoints
_rx
:
watch
::
Receiver
<
ProcessedEndpoints
>
,
}
}
impl
KvMetricsAggregator
{
impl
KvMetricsAggregator
{
pub
async
fn
new
(
component
:
Component
,
cancellation_token
:
CancellationToken
)
->
Self
{
pub
async
fn
new
(
component
:
Component
,
cancellation_token
:
CancellationToken
)
->
Self
{
let
(
ep_tx
,
mut
ep_rx
)
=
tokio
::
sync
::
mpsc
::
channel
(
128
);
let
(
watch_tx
,
watch_rx
)
=
watch
::
channel
(
ProcessedEndpoints
::
default
()
);
tokio
::
spawn
(
collect_endpoints_task
(
tokio
::
spawn
(
collect_endpoints_task
(
component
.clone
(),
component
.clone
(),
ep
_tx
,
watch
_tx
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
));
));
tracing
::
trace!
(
"awaiting the start of the background endpoint subscriber"
);
let
endpoints
=
Arc
::
new
(
Mutex
::
new
(
ProcessedEndpoints
::
default
()));
let
endpoints_clone
=
endpoints
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
debug!
(
"scheduler background task started"
);
loop
{
match
ep_rx
.recv
()
.await
{
Some
(
endpoints
)
=>
match
endpoints_clone
.lock
()
{
Ok
(
mut
shared_endpoint
)
=>
{
*
shared_endpoint
=
endpoints
;
}
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to acquire lock on endpoints: {:?}"
,
e
);
}
},
None
=>
{
tracing
::
warn!
(
"endpoint subscriber shutdown"
);
break
;
}
};
}
tracing
::
trace!
(
"background endpoint subscriber shutting down"
);
});
Self
{
Self
{
service_name
:
component
.service_name
(),
service_name
:
component
.service_name
(),
endpoints
,
endpoints
_rx
:
watch_rx
,
}
}
}
}
pub
fn
get_endpoints
(
&
self
)
->
ProcessedEndpoints
{
pub
fn
get_endpoints
(
&
self
)
->
ProcessedEndpoints
{
match
self
.endpoints
.lock
()
{
self
.endpoints_rx
.borrow
()
.clone
()
Ok
(
endpoints
)
=>
endpoints
.clone
(),
}
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to acquire lock on endpoints: {:?}"
,
e
);
pub
fn
endpoints_watcher
(
&
self
)
->
watch
::
Receiver
<
ProcessedEndpoints
>
{
ProcessedEndpoints
::
default
()
self
.endpoints_rx
.clone
()
}
}
}
}
}
}
...
@@ -108,7 +81,7 @@ pub async fn collect_endpoints(
...
@@ -108,7 +81,7 @@ pub async fn collect_endpoints(
pub
async
fn
collect_endpoints_task
(
pub
async
fn
collect_endpoints_task
(
component
:
Component
,
component
:
Component
,
ep_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
ProcessedEndpoints
>
,
watch_tx
:
watch
::
Sender
<
ProcessedEndpoints
>
,
cancel
:
CancellationToken
,
cancel
:
CancellationToken
,
)
{
)
{
let
backoff_delay
=
Duration
::
from_millis
(
100
);
let
backoff_delay
=
Duration
::
from_millis
(
100
);
...
@@ -161,7 +134,8 @@ pub async fn collect_endpoints_task(
...
@@ -161,7 +134,8 @@ pub async fn collect_endpoints_task(
);
);
let
processed
=
ProcessedEndpoints
::
new
(
endpoints
);
let
processed
=
ProcessedEndpoints
::
new
(
endpoints
);
if
ep_tx
.send
(
processed
)
.await
.is_err
()
{
if
watch_tx
.send
(
processed
)
.is_err
()
{
tracing
::
trace!
(
"failed to send processed endpoints; shutting down"
);
tracing
::
trace!
(
"failed to send processed endpoints; shutting down"
);
break
;
break
;
}
}
...
...
lib/llm/src/kv_router/protocols.rs
View file @
c4106e6a
...
@@ -13,8 +13,32 @@
...
@@ -13,8 +13,32 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
crate
::
tokens
::
Token
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
/// Request payload accepted by the KV router engine.
///
/// Serializable with serde; `Default` yields an empty token list.
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default)]
pub
struct
RouterRequest
{
/// Token sequence of the request to be routed.
pub
tokens
:
Vec
<
Token
>
,
}
/// Response payload produced by the KV router engine for one request.
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default)]
pub
struct
RouterResponse
{
/// Id of the worker the request was scheduled onto.
pub
worker_id
:
i64
,
}
/// Value returned by a worker selector's `select_worker` call: which worker was
/// picked, plus the block counts the scheduler uses to update that worker's state.
#[derive(Debug)]
pub
struct
WorkerSelectionResult
{
/// The worker id of the selected worker
pub
worker_id
:
i64
,
/// The total number of blocks required to prefill the request
pub
required_blocks
:
u64
,
/// The number of blocks that the selected worker may already have cached.
/// This is not a guarantee, but an estimate.
pub
overlap_blocks
:
usize
,
}
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default)]
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default)]
pub
struct
ForwardPassMetrics
{
pub
struct
ForwardPassMetrics
{
pub
request_active_slots
:
u64
,
pub
request_active_slots
:
u64
,
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
c4106e6a
...
@@ -15,15 +15,19 @@
...
@@ -15,15 +15,19 @@
use
dynamo_runtime
::
component
::
Namespace
;
use
dynamo_runtime
::
component
::
Namespace
;
use
dynamo_runtime
::
traits
::
events
::
EventPublisher
;
use
dynamo_runtime
::
traits
::
events
::
EventPublisher
;
use
rand
::
Rng
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
borrow
::
BorrowMut
;
use
std
::
borrow
::
BorrowMut
;
use
std
::
c
mp
::
min
;
use
std
::
c
ollections
::
HashMap
;
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
use
crate
::
kv_router
::
scoring
::
ProcessedEndpoints
;
use
crate
::
kv_router
::
scoring
::
ProcessedEndpoints
;
use
crate
::
kv_router
::
KV_HIT_RATE_SUBJECT
;
use
crate
::
kv_router
::
KV_HIT_RATE_SUBJECT
;
use
super
::
protocols
::
WorkerSelectionResult
;
use
super
::
WorkerSelector
;
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
KVHitRateEvent
{
pub
struct
KVHitRateEvent
{
pub
worker_id
:
i64
,
pub
worker_id
:
i64
,
...
@@ -68,8 +72,8 @@ impl Endpoint {
...
@@ -68,8 +72,8 @@ impl Endpoint {
}
}
pub
struct
SchedulingRequest
{
pub
struct
SchedulingRequest
{
isl_tokens
:
usize
,
pub
isl_tokens
:
usize
,
overlap
:
OverlapScores
,
pub
overlap
:
OverlapScores
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
i64
>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
i64
>
,
}
}
...
@@ -87,24 +91,16 @@ pub struct KvScheduler {
...
@@ -87,24 +91,16 @@ pub struct KvScheduler {
impl
KvScheduler
{
impl
KvScheduler
{
pub
async
fn
start
(
pub
async
fn
start
(
endpoints_rx
:
tokio
::
sync
::
mpsc
::
Receiver
<
ProcessedEndpoints
>
,
ns
:
Namespace
,
ns
:
Namespace
,
kv_block_size
:
usize
,
block_size
:
usize
,
endpoints_rx
:
tokio
::
sync
::
watch
::
Receiver
<
ProcessedEndpoints
>
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
)
->
Result
<
Self
,
KvSchedulerError
>
{
)
->
Result
<
Self
,
KvSchedulerError
>
{
let
selector
=
selector
.unwrap_or
(
Box
::
new
(
DefaultWorkerSelector
));
let
mut
endpoints_rx
=
endpoints_rx
;
let
mut
endpoints_rx
=
endpoints_rx
;
let
mut
endpoints
:
ProcessedEndpoints
=
endpoints_rx
.borrow_and_update
()
.clone
();
tracing
::
trace!
(
"awaiting the start of the background endpoint subscriber"
);
let
mut
endpoints
=
match
endpoints_rx
.recv
()
.await
{
Some
(
endpoints
)
=>
endpoints
,
None
=>
{
return
Err
(
KvSchedulerError
::
SubscriberShutdown
);
}
};
// Channel to asynchronously publish metric events on
let
(
event_tx
,
event_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
::
<
KVHitRateEvent
>
();
let
(
event_tx
,
event_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
::
<
KVHitRateEvent
>
();
// Publisher task
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
let
mut
event_rx
=
event_rx
;
let
mut
event_rx
=
event_rx
;
while
let
Some
(
event
)
=
event_rx
.recv
()
.await
{
while
let
Some
(
event
)
=
event_rx
.recv
()
.await
{
...
@@ -115,7 +111,7 @@ impl KvScheduler {
...
@@ -115,7 +111,7 @@ impl KvScheduler {
});
});
// Channel to accept new scheduling requests
// Channel to accept new scheduling requests
let
(
request_tx
,
request_rx
)
=
tokio
::
sync
::
mpsc
::
channel
::
<
SchedulingRequest
>
(
1
6
);
let
(
request_tx
,
request_rx
)
=
tokio
::
sync
::
mpsc
::
channel
::
<
SchedulingRequest
>
(
1
024
);
tracing
::
debug!
(
"scheduler starting"
);
tracing
::
debug!
(
"scheduler starting"
);
// Background task to handle scheduling requests
// Background task to handle scheduling requests
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
...
@@ -140,37 +136,33 @@ impl KvScheduler {
...
@@ -140,37 +136,33 @@ impl KvScheduler {
}
}
}
}
new_endpoints
=
endpoints_rx
.recv
()
=>
{
_
=
endpoints_rx
.changed
()
=>
{
match
new_endpoints
{
endpoints
=
endpoints_rx
.borrow_and_update
()
.clone
();
Some
(
new_endpoints
)
=>
{
continue
'outer
;
tracing
::
trace!
(
"updated endpoints"
);
endpoints
=
new_endpoints
;
continue
'outer
;
}
None
=>
{
tracing
::
trace!
(
"endpoint subscriber shutdown"
);
break
'outer
;
}
}
}
}
};
};
tracing
::
debug!
(
"selected"
);
tracing
::
debug!
(
"selected"
);
loop
{
loop
{
match
select_worker
(
endpoints
.borrow_mut
(),
&
request
,
&
event_tx
,
kv_block_size
)
match
selector
.select_worker
(
&
endpoints
,
&
request
,
block_size
)
{
{
Ok
(
selection
)
=>
{
Ok
(
worker_id
)
=>
{
let
worker_id
=
process_worker_selection
(
endpoints
.borrow_mut
(),
selection
,
&
event_tx
,
);
request
.respond
(
worker_id
);
request
.respond
(
worker_id
);
continue
'outer
;
continue
'outer
;
}
}
Err
(
KvSchedulerError
::
AllWorkersBusy
)
=>
{
Err
(
KvSchedulerError
::
AllWorkersBusy
)
=>
{
tracing
::
trace!
(
"all workers busy; waiting for more capacity"
);
tracing
::
trace!
(
"all workers busy; waiting for more capacity"
);
endpoints
=
match
endpoints_rx
.
recv
()
.await
{
match
endpoints_rx
.
changed
()
.await
{
Some
(
endpoints
)
=>
endpoints
,
Ok
(
_
)
=>
{}
None
=>
{
Err
(
e
)
=>
{
tracing
::
trace!
(
"endpoint subscriber shutdown"
);
tracing
::
error!
(
"error waiting for endpoints change: {:?}"
,
e
);
break
'outer
;
break
'outer
;
}
}
};
};
endpoints
=
endpoints_rx
.borrow_and_update
()
.clone
();
}
}
Err
(
e
)
=>
{
Err
(
e
)
=>
{
tracing
::
error!
(
"error scheduling request: {:?}"
,
e
);
tracing
::
error!
(
"error scheduling request: {:?}"
,
e
);
...
@@ -212,104 +204,137 @@ impl KvScheduler {
...
@@ -212,104 +204,137 @@ impl KvScheduler {
}
}
}
}
pub
fn
select_worker
(
// This becomes the driver function that handles the selection result
pub
fn
process_worker_selection
(
workers
:
&
mut
ProcessedEndpoints
,
workers
:
&
mut
ProcessedEndpoints
,
request
:
&
SchedulingReques
t
,
selection
:
WorkerSelectionResul
t
,
event_tx
:
&
tokio
::
sync
::
mpsc
::
UnboundedSender
<
KVHitRateEvent
>
,
event_tx
:
&
tokio
::
sync
::
mpsc
::
UnboundedSender
<
KVHitRateEvent
>
,
kv_block_size
:
usize
,
)
->
i64
{
)
->
Result
<
i64
,
KvSchedulerError
>
{
let
worker
=
workers
// balance mode prioritizes balancing load across workers
.endpoints
let
balance_threshold
:
f64
=
0.1
;
.get_mut
(
&
selection
.worker_id
)
let
balance_mode
=
workers
.load_std
>
balance_threshold
*
workers
.load_avg
;
.expect
(
"worker not found"
);
// Determine alpha based on mode
// Update worker state
let
alpha
=
if
balance_mode
{
0.7
}
else
{
0.3
};
worker
.data.request_active_slots
+=
1
;
let
gamma
=
0.1
;
// example tuning param
worker
.data.kv_active_blocks
+=
selection
.required_blocks
-
selection
.overlap_blocks
as
u64
;
// Compute each worker's score
// Emit event
let
mut
best_index
=
None
;
if
let
Err
(
e
)
=
event_tx
.send
(
KVHitRateEvent
{
let
mut
best_cost
=
f64
::
INFINITY
;
worker_id
:
selection
.worker_id
,
// [FIXME] REMOVE ONLY FOR TESTING
isl_blocks
:
selection
.required_blocks
as
usize
,
if
workers
.endpoints
.is_empty
()
{
overlap_blocks
:
selection
.overlap_blocks
,
return
Err
(
KvSchedulerError
::
NoEndpoints
);
})
{
tracing
::
warn!
(
"Failed to send KV hit rate event: {:?}"
,
e
);
}
}
for
(
i
,
w
)
in
workers
.endpoints
.iter
()
.enumerate
()
{
selection
.worker_id
// Exclude workers that are at capacity
}
if
w
.data.request_active_slots
>=
w
.data.request_total_slots
||
w
.data.kv_active_blocks
>=
w
.data.kv_total_blocks
// Default implementation matching the Python _cost_function
{
#[derive(Default)]
continue
;
pub
struct
DefaultWorkerSelector
;
impl
WorkerSelector
for
DefaultWorkerSelector
{
fn
select_worker
(
&
self
,
workers
:
&
ProcessedEndpoints
,
request
:
&
SchedulingRequest
,
block_size
:
usize
,
)
->
Result
<
WorkerSelectionResult
,
KvSchedulerError
>
{
assert
!
(
request
.isl_tokens
>
0
);
let
mut
worker_scores
=
HashMap
::
new
();
let
mut
max_active
=
0.0
;
// Calculate worker scores and find max waiting requests
for
(
worker_id
,
ep
)
in
workers
.endpoints
.iter
()
{
// Calculate score similar to Python version
if
let
Some
(
score
)
=
request
.overlap.scores
.get
(
worker_id
)
{
let
score
=
*
score
as
f64
*
block_size
as
f64
/
request
.isl_tokens
as
f64
;
worker_scores
.insert
(
worker_id
,
score
);
}
// Track the maximum number of active request slots across workers
max_active
=
f64
::
max
(
max_active
,
ep
.data.request_active_slots
as
f64
);
}
}
let
kv_load_ratio
=
w
.data.kv_active_blocks
as
f64
/
w
.data.kv_total_blocks
as
f64
;
if
max_active
==
0.0
{
let
load_deviation
=
kv_load_ratio
-
workers
.load_avg
;
return
Err
(
KvSchedulerError
::
NoEndpoints
);
}
let
worker_id
=
w
.worker_id
();
// make immutable
let
overlap
_score
=
request
.overlap.scores
.get
(
&
worker_id
)
.map_or
(
0
,
|
x
|
*
x
)
;
let
worker
_score
s
=
worker_scores
;
let
overlap_score
=
overlap_score
as
usize
*
kv_block_siz
e
;
let
max_active
=
max_activ
e
;
let
new_tokens
=
request
.isl_tokens
.saturating_sub
(
overlap_score
);
// Calculate logits for each worker
let
normalized_new_tokens
=
new_tokens
as
f64
/
request
.isl_tokens
as
f64
;
let
mut
best_logit
=
f64
::
NEG_INFINITY
;
let
mut
best_workers
=
Vec
::
new
();
let
request_load_ratio
=
for
(
worker_id
,
ep
)
in
workers
.endpoints
.iter
()
{
w
.data.request_active_slots
as
f64
/
w
.data.request_total_slots
as
f64
;
let
worker_id
=
*
worker_id
;
// cost = alpha * load_deviation + (1 - alpha)*normalized_new_tokens + gamma * request_load_ratio
// Get score or default to 0.0
let
cost
=
alpha
*
load_deviation
let
score
=
worker_scores
.get
(
&
worker_id
)
.copied
()
.unwrap_or
(
0.0
);
+
(
1.0
-
alpha
)
*
normalized_new_tokens
+
gamma
*
request_load_ratio
;
tracing
::
debug!
(
"worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}"
,
// Calculate normalized metrics
assert
!
(
ep
.data.kv_total_blocks
>
0
);
let
gpu_cache_usage
=
ep
.data.kv_active_blocks
as
f64
/
ep
.data.kv_total_blocks
as
f64
;
let
normalized_active
=
if
max_active
>
0.0
{
ep
.data.request_active_slots
as
f64
/
max_active
}
else
{
0.0
};
// Calculate logit using same formula as Python
let
logit
=
2.0
*
score
-
gpu_cache_usage
-
normalized_active
;
tracing
::
info!
(
"Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}"
,
worker_id
,
worker_id
,
lo
ad_deviation
,
lo
git
,
normalized_new_tokens
,
score
,
request_load_ratio
,
gpu_cache_usage
,
cost
normalized_active
);
);
if
cost
<
best_cost
{
// Track best workers
best_cost
=
cost
;
match
logit
.partial_cmp
(
&
best_logit
)
{
best_index
=
Some
(
i
);
Some
(
std
::
cmp
::
Ordering
::
Greater
)
=>
{
best_logit
=
logit
;
best_workers
.clear
();
best_workers
.push
(
worker_id
);
}
Some
(
std
::
cmp
::
Ordering
::
Equal
)
=>
{
best_workers
.push
(
worker_id
);
}
_
=>
{}
}
}
}
}
if
let
Some
(
best_index
)
=
best_index
{
// Return early if no valid workers found
let
total_blocks
=
min
(
request
.isl_tokens
/
kv_block_size
,
1
);
if
best_workers
.is_empty
()
||
best_logit
==
0.0
{
return
Err
(
KvSchedulerError
::
NoEndpoints
);
workers
.endpoints
[
best_index
]
.data.request_active_slots
+=
1
;
workers
.endpoints
[
best_index
]
.data.kv_active_blocks
+=
total_blocks
as
u64
;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let
best_worker_id
=
workers
.endpoints
[
best_index
]
.worker_id
();
let
isl_blocks
=
request
.isl_tokens
/
kv_block_size
;
let
overlap_blocks
=
request
.overlap
.scores
.get
(
&
best_worker_id
)
.copied
()
.unwrap_or
(
0
);
if
let
Err
(
e
)
=
event_tx
.send
(
KVHitRateEvent
{
worker_id
:
best_worker_id
,
isl_blocks
,
overlap_blocks
:
overlap_blocks
as
usize
,
})
{
tracing
::
warn!
(
"Failed to send KV hit rate event: {:?}"
,
e
);
}
}
}
match
best_index
{
let
worker_id
=
if
best_workers
.len
()
==
1
{
Some
(
i
)
=>
{
best_workers
[
0
]
tracing
::
info!
(
}
else
{
"selected worker: {}; cost: {}"
,
// Randomly select from best workers
workers
.endpoints
[
i
]
.worker_id
(),
let
mut
rng
=
rand
::
rng
();
best_cost
best_workers
[
rng
.random_range
(
0
..
best_workers
.len
())]
);
};
Ok
(
workers
.endpoints
[
i
]
.worker_id
())
}
// Log selection metrics
None
=>
{
tracing
::
info!
(
"Selected worker: {}, logit: {:.3}"
,
worker_id
,
best_logit
);
tracing
::
debug!
(
"all workers busy"
);
Err
(
KvSchedulerError
::
AllWorkersBusy
)
let
total_blocks
=
std
::
cmp
::
min
(
request
.isl_tokens
/
block_size
,
1
)
as
u64
;
}
let
overlap_blocks
=
request
.overlap.scores
.get
(
&
worker_id
)
.copied
()
.unwrap_or
(
0
)
as
usize
;
Ok
(
WorkerSelectionResult
{
worker_id
,
required_blocks
:
total_blocks
,
overlap_blocks
,
})
}
}
}
}
lib/llm/src/kv_router/scoring.rs
View file @
c4106e6a
...
@@ -16,12 +16,13 @@
...
@@ -16,12 +16,13 @@
//! Scoring functions for the KV router.
//! Scoring functions for the KV router.
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
collections
::
HashMap
;
use
crate
::
kv_router
::
scheduler
::
Endpoint
;
use
crate
::
kv_router
::
scheduler
::
Endpoint
;
#[derive(Debug,
Default,
Serialize,
Deserialize,
Clone)]
#[derive(Debug,
Default,
Serialize,
Deserialize,
Clone)]
pub
struct
ProcessedEndpoints
{
pub
struct
ProcessedEndpoints
{
pub
endpoints
:
Vec
<
Endpoint
>
,
pub
endpoints
:
HashMap
<
i64
,
Endpoint
>
,
pub
load_avg
:
f64
,
pub
load_avg
:
f64
,
pub
load_std
:
f64
,
pub
load_std
:
f64
,
}
}
...
@@ -41,6 +42,8 @@ impl ProcessedEndpoints {
...
@@ -41,6 +42,8 @@ impl ProcessedEndpoints {
/
load_values
.len
()
as
f64
;
/
load_values
.len
()
as
f64
;
let
load_std
=
variance
.sqrt
();
let
load_std
=
variance
.sqrt
();
let
endpoints
=
endpoints
.into_iter
()
.map
(|
e
|
(
e
.worker_id
(),
e
))
.collect
();
ProcessedEndpoints
{
ProcessedEndpoints
{
endpoints
,
endpoints
,
load_avg
,
load_avg
,
...
...
lib/llm/src/tokens.rs
View file @
c4106e6a
...
@@ -105,6 +105,13 @@ impl Tokens {
...
@@ -105,6 +105,13 @@ impl Tokens {
pub
fn
into_sequence
(
self
,
block_size
:
usize
)
->
TokenSequence
{
pub
fn
into_sequence
(
self
,
block_size
:
usize
)
->
TokenSequence
{
TokenSequence
::
new
(
self
,
block_size
)
TokenSequence
::
new
(
self
,
block_size
)
}
}
/// Computes one [`BlockHash`] per `block_size`-sized chunk of `tokens`.
///
/// Each chunk is reinterpreted with `cast_slice` and passed to
/// `compute_hash`; the resulting hashes are collected into a `Vec`.
///
/// NOTE(review): `par_chunks_exact` is presumably rayon's
/// `ParallelSlice::par_chunks_exact`. If so, trailing tokens that do not
/// fill a whole chunk are not hashed, and a `block_size` of zero panics.
/// Confirm against the file's imports.
pub fn compute_block_hash(tokens: &[Token], block_size: usize) -> Vec<BlockHash> {
    // Hash of a single chunk, computed over its reinterpreted contents.
    let hash_block = |block: &[Token]| compute_hash(cast_slice(block));

    let blocks = tokens.par_chunks_exact(block_size);
    blocks.map(hash_block).collect()
}
}
}
pub
struct
PartialTokenBlock
{
pub
struct
PartialTokenBlock
{
...
...
lib/runtime/src/lib.rs
View file @
c4106e6a
...
@@ -34,6 +34,7 @@ pub mod discovery;
...
@@ -34,6 +34,7 @@ pub mod discovery;
pub
mod
engine
;
pub
mod
engine
;
pub
mod
logging
;
pub
mod
logging
;
pub
mod
pipeline
;
pub
mod
pipeline
;
pub
mod
prelude
;
pub
mod
protocols
;
pub
mod
protocols
;
pub
mod
runnable
;
pub
mod
runnable
;
pub
mod
runtime
;
pub
mod
runtime
;
...
...
lib/runtime/src/prelude.rs
0 → 100644
View file @
c4106e6a
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub
use
crate
::
traits
::
*
;
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment