Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
7d604dd3
Unverified
Commit
7d604dd3
authored
Nov 12, 2025
by
Yan Ru Pei
Committed by
GitHub
Nov 12, 2025
Browse files
feat: tie break on tree size when routing (#4257)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
f8219b12
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
103 additions
and
25 deletions
+103
-25
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+1
-0
lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/indexer.rs
+14
-0
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+88
-25
No files found.
lib/llm/src/kv_router.rs
View file @
7d604dd3
...
@@ -182,6 +182,7 @@ impl Indexer {
...
@@ -182,6 +182,7 @@ impl Indexer {
Indexer
::
None
=>
Ok
(
OverlapScores
{
Indexer
::
None
=>
Ok
(
OverlapScores
{
scores
:
HashMap
::
new
(),
scores
:
HashMap
::
new
(),
frequencies
:
Vec
::
new
(),
frequencies
:
Vec
::
new
(),
tree_sizes
:
HashMap
::
new
(),
}),
}),
}
}
}
}
...
...
lib/llm/src/kv_router/indexer.rs
View file @
7d604dd3
...
@@ -326,6 +326,16 @@ impl RadixTree {
...
@@ -326,6 +326,16 @@ impl RadixTree {
tracing
::
trace!
(
"RadixTree::find_matches: final scores={:?}"
,
scores
.scores
);
tracing
::
trace!
(
"RadixTree::find_matches: final scores={:?}"
,
scores
.scores
);
// Populate tree sizes for all workers that have scores
for
worker
in
scores
.scores
.keys
()
{
let
tree_size
=
self
.lookup
.get
(
worker
)
.expect
(
"worker in scores must exist in lookup table"
)
.len
();
scores
.tree_sizes
.insert
(
*
worker
,
tree_size
);
}
scores
scores
}
}
...
@@ -680,6 +690,8 @@ pub struct OverlapScores {
...
@@ -680,6 +690,8 @@ pub struct OverlapScores {
pub
scores
:
HashMap
<
WorkerWithDpRank
,
u32
>
,
pub
scores
:
HashMap
<
WorkerWithDpRank
,
u32
>
,
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
// List of frequencies that the blocks have been accessed. Entries with value 0 are omitted.
pub
frequencies
:
Vec
<
usize
>
,
pub
frequencies
:
Vec
<
usize
>
,
// Map of worker to their tree size (number of blocks in the tree for that worker)
pub
tree_sizes
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
}
}
impl
Default
for
OverlapScores
{
impl
Default
for
OverlapScores
{
...
@@ -698,6 +710,7 @@ impl OverlapScores {
...
@@ -698,6 +710,7 @@ impl OverlapScores {
Self
{
Self
{
scores
:
HashMap
::
new
(),
scores
:
HashMap
::
new
(),
frequencies
:
Vec
::
with_capacity
(
32
),
frequencies
:
Vec
::
with_capacity
(
32
),
tree_sizes
:
HashMap
::
new
(),
}
}
}
}
...
@@ -1225,6 +1238,7 @@ impl KvIndexerInterface for KvIndexerSharded {
...
@@ -1225,6 +1238,7 @@ impl KvIndexerInterface for KvIndexerSharded {
match
match_rx
.recv
()
.await
{
match
match_rx
.recv
()
.await
{
Some
(
response
)
=>
{
Some
(
response
)
=>
{
scores
.scores
.extend
(
response
.scores
);
scores
.scores
.extend
(
response
.scores
);
scores
.tree_sizes
.extend
(
response
.tree_sizes
);
if
response_num
==
0
{
if
response_num
==
0
{
scores
.frequencies
=
response
.frequencies
;
scores
.frequencies
=
response
.frequencies
;
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
7d604dd3
...
@@ -386,12 +386,16 @@ impl KvScheduler {
...
@@ -386,12 +386,16 @@ impl KvScheduler {
}
}
// Helper function for softmax sampling
// Helper function for softmax sampling
fn
softmax_sample
(
logits
:
&
HashMap
<
WorkerWithDpRank
,
f64
>
,
temperature
:
f64
)
->
WorkerWithDpRank
{
// Returns a vec of workers: multiple if tied, single if sampled
fn
softmax_sample
(
logits
:
&
HashMap
<
WorkerWithDpRank
,
f64
>
,
temperature
:
f64
,
)
->
Vec
<
WorkerWithDpRank
>
{
if
logits
.is_empty
()
{
if
logits
.is_empty
()
{
panic!
(
"Empty logits for softmax sampling"
);
panic!
(
"Empty logits for softmax sampling"
);
}
}
// Guard: if temperature is 0, return
the
key with the smallest logit value
// Guard: if temperature is 0, return
all
key
s
with the smallest logit value
(ties)
if
temperature
==
0.0
{
if
temperature
==
0.0
{
// Find the minimum logit value
// Find the minimum logit value
let
min_logit
=
logits
.values
()
.fold
(
f64
::
INFINITY
,
|
a
,
&
b
|
a
.min
(
b
));
let
min_logit
=
logits
.values
()
.fold
(
f64
::
INFINITY
,
|
a
,
&
b
|
a
.min
(
b
));
...
@@ -403,10 +407,7 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
...
@@ -403,10 +407,7 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
.map
(|(
k
,
_
)|
*
k
)
.map
(|(
k
,
_
)|
*
k
)
.collect
();
.collect
();
// Randomly select from the minimum keys (handles single key case naturally)
return
min_keys
;
let
mut
rng
=
rand
::
rng
();
let
index
=
rng
.random_range
(
0
..
min_keys
.len
());
return
min_keys
[
index
];
}
}
let
keys
:
Vec
<
_
>
=
logits
.keys
()
.copied
()
.collect
();
let
keys
:
Vec
<
_
>
=
logits
.keys
()
.copied
()
.collect
();
...
@@ -449,12 +450,12 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
...
@@ -449,12 +450,12 @@ fn softmax_sample(logits: &HashMap<WorkerWithDpRank, f64>, temperature: f64) ->
for
(
i
,
&
prob
)
in
probabilities
.iter
()
.enumerate
()
{
for
(
i
,
&
prob
)
in
probabilities
.iter
()
.enumerate
()
{
cumsum
+=
prob
;
cumsum
+=
prob
;
if
sample
<=
cumsum
{
if
sample
<=
cumsum
{
return
keys
[
i
];
return
vec!
[
keys
[
i
]
]
;
}
}
}
}
// Fallback to last key (shouldn't normally reach here)
// Fallback to last key (shouldn't normally reach here)
keys
[
keys
.len
()
-
1
]
vec!
[
keys
[
keys
.len
()
-
1
]
]
}
}
// Default implementation matching the Python _cost_function
// Default implementation matching the Python _cost_function
...
@@ -542,14 +543,34 @@ impl WorkerSelector for DefaultWorkerSelector {
...
@@ -542,14 +543,34 @@ impl WorkerSelector for DefaultWorkerSelector {
}
}
}
}
// Use softmax sampling to select worker
// Use softmax sampling to select worker
(s)
// Use override if provided, otherwise use default config
// Use override if provided, otherwise use default config
let
temperature
=
request
let
temperature
=
request
.router_config_override
.router_config_override
.as_ref
()
.as_ref
()
.and_then
(|
cfg
|
cfg
.router_temperature
)
.and_then
(|
cfg
|
cfg
.router_temperature
)
.unwrap_or
(
self
.kv_router_config.router_temperature
);
.unwrap_or
(
self
.kv_router_config.router_temperature
);
let
best_worker
=
softmax_sample
(
&
worker_logits
,
temperature
);
let
candidates
=
softmax_sample
(
&
worker_logits
,
temperature
);
// If multiple candidates (tied), use tree size as tie-breaker
// If tree sizes are also equal, min_by_key uses HashMap iteration order (pseudo-random)
let
best_worker
=
if
candidates
.len
()
>
1
{
tracing
::
info!
(
"Multiple workers tied with same logit, using tree size as tie-breaker"
);
*
candidates
.iter
()
.min_by_key
(|
worker
|
{
request
.overlaps
.tree_sizes
.get
(
worker
)
.copied
()
.unwrap_or
(
0
)
})
.expect
(
"candidates should not be empty"
)
}
else
{
candidates
[
0
]
};
let
best_logit
=
worker_logits
[
&
best_worker
];
let
best_logit
=
worker_logits
[
&
best_worker
];
let
best_overlap
=
*
overlaps
.get
(
&
best_worker
)
.unwrap_or
(
&
0
);
let
best_overlap
=
*
overlaps
.get
(
&
best_worker
)
.unwrap_or
(
&
0
);
...
@@ -562,12 +583,20 @@ impl WorkerSelector for DefaultWorkerSelector {
...
@@ -562,12 +583,20 @@ impl WorkerSelector for DefaultWorkerSelector {
.map
(|
blocks
|
format!
(
", total blocks: {}"
,
blocks
))
.map
(|
blocks
|
format!
(
", total blocks: {}"
,
blocks
))
.unwrap_or_default
();
.unwrap_or_default
();
let
tree_size
=
request
.overlaps
.tree_sizes
.get
(
&
best_worker
)
.copied
()
.unwrap_or
(
0
);
tracing
::
info!
(
tracing
::
info!
(
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}{}"
,
"Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks:
{}, tree size:
{}{}"
,
best_worker
.worker_id
,
best_worker
.worker_id
,
best_worker
.dp_rank
,
best_worker
.dp_rank
,
best_logit
,
best_logit
,
best_overlap
,
best_overlap
,
tree_size
,
total_blocks_info
total_blocks_info
);
);
...
@@ -593,26 +622,33 @@ mod tests {
...
@@ -593,26 +622,33 @@ mod tests {
// Test with different temperatures
// Test with different temperatures
for
temperature
in
&
[
0.1
,
1.0
,
10.0
]
{
for
temperature
in
&
[
0.1
,
1.0
,
10.0
]
{
let
result
=
softmax_sample
(
&
logits
,
*
temperature
);
let
result
=
softmax_sample
(
&
logits
,
*
temperature
);
assert_eq!
(
result
,
worker
,
"Should return the only available worker"
);
assert_eq!
(
result
.len
(),
1
,
"Should return exactly one worker"
);
assert_eq!
(
result
[
0
],
worker
,
"Should return the only available worker"
);
}
}
// Test with different logit values
// Test with different logit values
logits
.clear
();
logits
.clear
();
logits
.insert
(
worker
,
-
100.0
);
// Very negative value
logits
.insert
(
worker
,
-
100.0
);
// Very negative value
assert_eq!
(
softmax_sample
(
&
logits
,
1.0
),
worker
);
let
result
=
softmax_sample
(
&
logits
,
1.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker
);
logits
.clear
();
logits
.clear
();
logits
.insert
(
worker
,
100.0
);
// Very positive value
logits
.insert
(
worker
,
100.0
);
// Very positive value
assert_eq!
(
softmax_sample
(
&
logits
,
1.0
),
worker
);
let
result
=
softmax_sample
(
&
logits
,
1.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker
);
logits
.clear
();
logits
.clear
();
logits
.insert
(
worker
,
0.0
);
// Zero value
logits
.insert
(
worker
,
0.0
);
// Zero value
assert_eq!
(
softmax_sample
(
&
logits
,
1.0
),
worker
);
let
result
=
softmax_sample
(
&
logits
,
1.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker
);
}
}
#[test]
#[test]
fn
test_softmax_sample_zero_temperature
()
{
fn
test_softmax_sample_zero_temperature
()
{
// Test that with temperature 0, softmax_sample returns
the
key with smallest logit
// Test that with temperature 0, softmax_sample returns
all
key
s
with smallest logit
let
mut
logits
=
HashMap
::
new
();
let
mut
logits
=
HashMap
::
new
();
let
worker1
=
WorkerWithDpRank
::
from_worker_id
(
1
);
let
worker1
=
WorkerWithDpRank
::
from_worker_id
(
1
);
let
worker2
=
WorkerWithDpRank
::
from_worker_id
(
2
);
let
worker2
=
WorkerWithDpRank
::
from_worker_id
(
2
);
...
@@ -623,14 +659,37 @@ mod tests {
...
@@ -623,14 +659,37 @@ mod tests {
logits
.insert
(
worker3
,
7.0
);
logits
.insert
(
worker3
,
7.0
);
logits
.insert
(
worker4
,
3.5
);
logits
.insert
(
worker4
,
3.5
);
// With temperature 0, should always return worker 2 (smallest logit)
// With temperature 0, should always return only worker2 (smallest logit)
for
_
in
0
..
10
{
let
result
=
softmax_sample
(
&
logits
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
assert_eq!
(
assert_eq!
(
result
.len
(),
result
,
worker2
,
1
,
"Should return worker with smallest logit when temperature is 0"
"Should return one worker when there's no tie"
);
);
}
assert_eq!
(
result
[
0
],
worker2
,
"Should return worker with smallest logit when temperature is 0"
);
// Test with tied minimum logits
logits
.clear
();
let
worker5
=
WorkerWithDpRank
::
from_worker_id
(
5
);
let
worker6
=
WorkerWithDpRank
::
from_worker_id
(
6
);
logits
.insert
(
worker1
,
5.0
);
logits
.insert
(
worker2
,
3.0
);
// Tied for smallest
logits
.insert
(
worker5
,
3.0
);
// Tied for smallest
logits
.insert
(
worker6
,
7.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
assert_eq!
(
result
.len
(),
2
,
"Should return all workers with smallest logit when tied"
);
assert
!
(
result
.contains
(
&
worker2
)
&&
result
.contains
(
&
worker5
),
"Should contain both tied workers"
);
// Test with negative values
// Test with negative values
logits
.clear
();
logits
.clear
();
...
@@ -642,6 +701,10 @@ mod tests {
...
@@ -642,6 +701,10 @@ mod tests {
logits
.insert
(
worker30
,
0.0
);
logits
.insert
(
worker30
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
assert_eq!
(
result
,
worker20
,
"Should handle negative logits correctly"
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker20
,
"Should handle negative logits correctly"
);
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment