Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
1368ccd6
Unverified
Commit
1368ccd6
authored
Mar 24, 2026
by
jthomson04
Committed by
GitHub
Mar 25, 2026
Browse files
perf: Miscellaneous router perf improvements (#7477)
Signed-off-by:
jthomson04
<
jwillthomson19@gmail.com
>
parent
0adfd98d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
298 additions
and
183 deletions
+298
-183
lib/kv-router/src/protocols.rs
lib/kv-router/src/protocols.rs
+62
-15
lib/kv-router/src/scheduling/config.rs
lib/kv-router/src/scheduling/config.rs
+31
-4
lib/kv-router/src/scheduling/selector.rs
lib/kv-router/src/scheduling/selector.rs
+157
-108
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+36
-47
lib/llm/src/kv_router/metrics.rs
lib/llm/src/kv_router/metrics.rs
+7
-7
lib/llm/src/kv_router/worker_query.rs
lib/llm/src/kv_router/worker_query.rs
+2
-2
lib/mocker/src/replay/router/offline.rs
lib/mocker/src/replay/router/offline.rs
+2
-0
lib/mocker/src/replay/router/online.rs
lib/mocker/src/replay/router/online.rs
+1
-0
No files found.
lib/kv-router/src/protocols.rs
View file @
1368ccd6
...
@@ -31,6 +31,29 @@ pub struct BlockHashOptions<'a> {
...
@@ -31,6 +31,29 @@ pub struct BlockHashOptions<'a> {
pub
is_eagle
:
Option
<
bool
>
,
pub
is_eagle
:
Option
<
bool
>
,
}
}
/// Hash a single block of tokens that carries no multimodal metadata.
///
/// The digest is `xxh3_64_with_seed` over the little-endian byte encoding of
/// every token in `chunk`, so the result does not depend on the host's
/// endianness.
///
/// `scratch_bytes` is a caller-owned buffer that is only written on
/// non-little-endian targets, where the tokens have to be re-encoded before
/// hashing. On little-endian targets it is left untouched.
#[inline]
fn hash_block_no_mm(chunk: &[u32], seed: u64, scratch_bytes: &mut Vec<u8>) -> LocalBlockHash {
    #[cfg(target_endian = "little")]
    let digest = {
        // The scratch buffer is not needed here; mark it as intentionally unused.
        let _ = scratch_bytes;

        let first_byte = chunk.as_ptr().cast::<u8>();
        let byte_len = std::mem::size_of_val(chunk);
        // SAFETY: `u32` is plain-old-data, and on little-endian targets its in-memory
        // representation matches the `to_le_bytes()` sequence used for hashing. The
        // pointer and length describe exactly the memory borrowed through `chunk`.
        let token_bytes = unsafe { std::slice::from_raw_parts(first_byte, byte_len) };
        xxh3::xxh3_64_with_seed(token_bytes, seed)
    };

    #[cfg(not(target_endian = "little"))]
    let digest = {
        // Re-encode every token as little-endian bytes into the reusable buffer.
        scratch_bytes.clear();
        scratch_bytes.extend(chunk.iter().flat_map(|token| token.to_le_bytes()));
        xxh3::xxh3_64_with_seed(scratch_bytes, seed)
    };

    LocalBlockHash(digest)
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata
/// and LoRA adapter identity.
/// and LoRA adapter identity.
///
///
...
@@ -56,35 +79,42 @@ pub fn compute_block_hash_for_seq(
...
@@ -56,35 +79,42 @@ pub fn compute_block_hash_for_seq(
Some
(
name
)
=>
XXH3_SEED
.wrapping_add
(
xxh3
::
xxh3_64
(
name
.as_bytes
())),
Some
(
name
)
=>
XXH3_SEED
.wrapping_add
(
xxh3
::
xxh3_64
(
name
.as_bytes
())),
None
=>
XXH3_SEED
,
None
=>
XXH3_SEED
,
};
};
let
is_eagle_flag
=
options
.is_eagle
.unwrap_or
(
false
);
let
is_eagle_flag
=
options
.is_eagle
.unwrap_or
(
false
);
let
stride
=
kv_block_size
as
usize
;
let
stride
=
kv_block_size
as
usize
;
let
window_size
=
if
is_eagle_flag
{
stride
+
1
}
else
{
stride
};
let
window_size
=
if
is_eagle_flag
{
stride
+
1
}
else
{
stride
};
let
estimated_blocks
=
if
is_eagle_flag
{
let
mut
hashes
=
Vec
::
new
();
tokens
.len
()
.saturating_sub
(
1
)
/
stride
}
else
{
tokens
.len
()
/
stride
};
let
mut
hashes
=
Vec
::
with_capacity
(
estimated_blocks
);
let
mut
bytes
=
Vec
::
with_capacity
(
window_size
*
std
::
mem
::
size_of
::
<
u32
>
());
let
mut
mm_hashes
=
Vec
::
new
();
let
mut
block_idx
=
0
;
let
mut
block_idx
=
0
;
let
mut
start
=
0
;
let
mut
start
=
0
;
while
start
+
window_size
<=
tokens
.len
()
{
while
start
+
window_size
<=
tokens
.len
()
{
let
chunk
=
&
tokens
[
start
..
start
+
window_size
];
let
chunk
=
&
tokens
[
start
..
start
+
window_size
];
let
mut
bytes
:
Vec
<
u8
>
=
chunk
.iter
()
.flat_map
(|
&
num
|
num
.to_le_bytes
())
.collect
();
if
let
Some
(
mm_infos
)
=
options
.block_mm_infos
if
let
Some
(
mm_infos
)
=
options
.block_mm_infos
&&
let
Some
(
Some
(
block_mm_info
))
=
mm_infos
.get
(
block_idx
)
&&
let
Some
(
Some
(
block_mm_info
))
=
mm_infos
.get
(
block_idx
)
{
{
let
mut
mm_hashes
:
Vec
<
u64
>
=
block_mm_info
bytes
.clear
();
.mm_objects
for
&
token
in
chunk
{
.iter
()
bytes
.extend_from_slice
(
&
token
.to_le_bytes
());
.map
(|
obj
|
obj
.mm_hash
)
}
.collect
();
mm_hashes
.clear
();
mm_hashes
.extend
(
block_mm_info
.mm_objects
.iter
()
.map
(|
obj
|
obj
.mm_hash
));
mm_hashes
.sort_unstable
();
mm_hashes
.sort_unstable
();
for
mm_hash
in
mm_hashes
{
for
&
mm_hash
in
&
mm_hashes
{
bytes
.extend_from_slice
(
&
mm_hash
.to_le_bytes
());
bytes
.extend_from_slice
(
&
mm_hash
.to_le_bytes
());
}
}
}
hashes
.push
(
LocalBlockHash
(
xxh3
::
xxh3_64_with_seed
(
&
bytes
,
seed
)));
hashes
.push
(
LocalBlockHash
(
xxh3
::
xxh3_64_with_seed
(
&
bytes
,
seed
)));
}
else
{
hashes
.push
(
hash_block_no_mm
(
chunk
,
seed
,
&
mut
bytes
));
}
start
+=
stride
;
start
+=
stride
;
block_idx
+=
1
;
block_idx
+=
1
;
...
@@ -110,8 +140,25 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen
...
@@ -110,8 +140,25 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen
let
current_block_hash
=
block_hashes
[
i
]
.0
;
let
current_block_hash
=
block_hashes
[
i
]
.0
;
let
combined
=
[
parent_seq_hash
,
current_block_hash
];
let
combined
=
[
parent_seq_hash
,
current_block_hash
];
let
bytes
:
Vec
<
u8
>
=
combined
.iter
()
.flat_map
(|
&
num
|
num
.to_le_bytes
())
.collect
();
#[cfg(target_endian
=
"little"
)]
let
seq_hash
=
compute_hash
(
&
bytes
);
let
seq_hash
=
{
// SAFETY: `u64` is plain-old-data, and on little-endian targets its in-memory
// representation matches the `to_le_bytes()` sequence used by the previous code.
let
bytes
=
unsafe
{
std
::
slice
::
from_raw_parts
(
combined
.as_ptr
()
.cast
::
<
u8
>
(),
std
::
mem
::
size_of_val
(
&
combined
),
)
};
compute_hash
(
bytes
)
};
#[cfg(not(target_endian
=
"little"
))]
let
seq_hash
=
{
let
mut
bytes
=
[
0_u8
;
std
::
mem
::
size_of
::
<
u64
>
()
*
2
];
bytes
[
..
8
]
.copy_from_slice
(
&
parent_seq_hash
.to_le_bytes
());
bytes
[
8
..
]
.copy_from_slice
(
&
current_block_hash
.to_le_bytes
());
compute_hash
(
&
bytes
)
};
sequence_hashes
.push
(
seq_hash
);
sequence_hashes
.push
(
seq_hash
);
}
}
...
...
lib/kv-router/src/scheduling/config.rs
View file @
1368ccd6
...
@@ -9,7 +9,9 @@ use rand::Rng;
...
@@ -9,7 +9,9 @@ use rand::Rng;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::{
Validate
,
ValidationError
};
use
validator
::{
Validate
,
ValidationError
};
use
crate
::
protocols
::{
BlockHashOptions
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
};
use
crate
::
protocols
::{
BlockHashOptions
,
LocalBlockHash
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
,
};
const
fn
default_min_initial_workers
()
->
usize
{
const
fn
default_min_initial_workers
()
->
usize
{
1
1
...
@@ -218,6 +220,7 @@ impl KvRouterConfig {
...
@@ -218,6 +220,7 @@ impl KvRouterConfig {
block_size
:
u32
,
block_size
:
u32
,
config_override
:
Option
<&
RouterConfigOverride
>
,
config_override
:
Option
<&
RouterConfigOverride
>
,
hash_options
:
BlockHashOptions
<
'_
>
,
hash_options
:
BlockHashOptions
<
'_
>
,
precomputed_block_hashes
:
Option
<&
[
LocalBlockHash
]
>
,
)
->
Option
<
Vec
<
u64
>>
{
)
->
Option
<
Vec
<
u64
>>
{
if
!
self
.router_track_active_blocks
{
if
!
self
.router_track_active_blocks
{
return
None
;
return
None
;
...
@@ -233,8 +236,14 @@ impl KvRouterConfig {
...
@@ -233,8 +236,14 @@ impl KvRouterConfig {
.unwrap_or
(
self
.router_assume_kv_reuse
);
.unwrap_or
(
self
.router_assume_kv_reuse
);
if
assume_kv_reuse
{
if
assume_kv_reuse
{
let
block_hashes
=
compute_block_hash_for_seq
(
tokens
,
block_size
,
hash_options
);
let
block_hashes
=
match
precomputed_block_hashes
{
Some
(
compute_seq_hash_for_block
(
&
block_hashes
))
Some
(
block_hashes
)
=>
block_hashes
,
None
=>
{
let
computed
=
compute_block_hash_for_seq
(
tokens
,
block_size
,
hash_options
);
return
Some
(
compute_seq_hash_for_block
(
&
computed
));
}
};
Some
(
compute_seq_hash_for_block
(
block_hashes
))
}
else
{
}
else
{
let
mut
rng
=
rand
::
rng
();
let
mut
rng
=
rand
::
rng
();
Some
((
0
..
num_blocks
)
.map
(|
_
|
rng
.random
::
<
u64
>
())
.collect
())
Some
((
0
..
num_blocks
)
.map
(|
_
|
rng
.random
::
<
u64
>
())
.collect
())
...
@@ -305,7 +314,7 @@ mod tests {
...
@@ -305,7 +314,7 @@ mod tests {
];
];
let
without_mm
=
cfg
let
without_mm
=
cfg
.compute_seq_hashes_for_tracking
(
&
tokens
,
2
,
None
,
BlockHashOptions
::
default
())
.compute_seq_hashes_for_tracking
(
&
tokens
,
2
,
None
,
BlockHashOptions
::
default
()
,
None
)
.unwrap
();
.unwrap
();
let
with_mm
=
cfg
let
with_mm
=
cfg
.compute_seq_hashes_for_tracking
(
.compute_seq_hashes_for_tracking
(
...
@@ -316,9 +325,27 @@ mod tests {
...
@@ -316,9 +325,27 @@ mod tests {
block_mm_infos
:
Some
(
&
mm_infos
),
block_mm_infos
:
Some
(
&
mm_infos
),
..
Default
::
default
()
..
Default
::
default
()
},
},
None
,
)
)
.unwrap
();
.unwrap
();
assert_ne!
(
without_mm
,
with_mm
);
assert_ne!
(
without_mm
,
with_mm
);
}
}
/// When block hashes are supplied by the caller, the sequence hashes must be
/// derived from those hashes rather than recomputed from the tokens.
#[test]
fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() {
    let router_config = KvRouterConfig::default();
    let token_ids = (0..8).collect::<Vec<u32>>();

    // Arbitrary values that are not derived from `token_ids`.
    let supplied_hashes = vec![LocalBlockHash(11), LocalBlockHash(29)];
    let expected = compute_seq_hash_for_block(&supplied_hashes);

    let actual = router_config.compute_seq_hashes_for_tracking(
        &token_ids,
        4,
        None,
        BlockHashOptions::default(),
        Some(supplied_hashes.as_slice()),
    );

    assert_eq!(actual, Some(expected));
}
}
}
lib/kv-router/src/scheduling/selector.rs
View file @
1368ccd6
...
@@ -22,36 +22,52 @@ pub trait WorkerSelector<C: WorkerConfigLike> {
...
@@ -22,36 +22,52 @@ pub trait WorkerSelector<C: WorkerConfigLike> {
}
}
/// Helper function for softmax sampling.
/// Helper function for softmax sampling.
/// Returns
a vec of workers: multiple if tied, single if sampled
.
/// Returns
the selected worker and its logit
.
fn
softmax_sample
(
fn
softmax_sample
(
logits
:
&
HashMap
<
WorkerWithDpRank
,
f64
>
,
logits
:
&
HashMap
<
WorkerWithDpRank
,
f64
>
,
temperature
:
f64
,
temperature
:
f64
,
)
->
Vec
<
WorkerWithDpRank
>
{
)
->
(
WorkerWithDpRank
,
f64
)
{
let
mut
rng
=
rand
::
rng
();
softmax_sample_with_sample
(
logits
,
temperature
,
rng
.random
())
}
fn
softmax_sample_with_sample
(
logits
:
&
HashMap
<
WorkerWithDpRank
,
f64
>
,
temperature
:
f64
,
sample
:
f64
,
)
->
(
WorkerWithDpRank
,
f64
)
{
if
logits
.is_empty
()
{
if
logits
.is_empty
()
{
panic!
(
"Empty logits for softmax sampling"
);
panic!
(
"Empty logits for softmax sampling"
);
}
}
// Guard:
if
temperature
is 0
, return a
ll keys with the smallest logit value (ties)
// Guard:
at zero
temperature, return a
minimum-logit worker directly.
if
temperature
==
0.0
{
if
temperature
==
0.0
{
let
min_logit
=
logits
.values
()
.fold
(
f64
::
INFINITY
,
|
a
,
&
b
|
a
.min
(
b
));
let
mut
logit_iter
=
logits
.iter
();
let
(
first_key
,
first_logit
)
=
logit_iter
.next
()
.unwrap
();
let
min_keys
:
Vec
<
_
>
=
logits
.iter
()
let
mut
min_logit
=
first_logit
;
.filter
(|
&
(
_
,
&
v
)|
v
==
min_logit
)
let
mut
min_key
=
first_key
;
.map
(|(
k
,
_
)|
*
k
)
for
(
key
,
logit
)
in
logit_iter
{
.collect
();
if
logit
<
min_logit
{
min_logit
=
logit
;
min_key
=
key
;
}
}
return
min_key
s
;
return
(
*
min_key
,
*
min_logit
)
;
}
}
let
keys
:
Vec
<
_
>
=
logits
.keys
()
.copied
()
.collect
();
let
entries
:
Vec
<
_
>
=
logits
let
values
:
Vec
<
_
>
=
logits
.values
()
.copied
()
.collect
();
.iter
()
.map
(|(
worker
,
logit
)|
(
*
worker
,
*
logit
))
.collect
();
let
values
:
Vec
<
_
>
=
entries
.iter
()
.map
(|(
_
,
logit
)|
*
logit
)
.collect
();
let
min_val
=
values
.iter
()
.fold
(
f64
::
INFINITY
,
|
a
,
&
b
|
a
.min
(
b
));
let
min_val
=
values
.iter
()
.fold
(
f64
::
INFINITY
,
|
a
,
&
b
|
a
.min
(
b
));
let
max_val
=
values
.iter
()
.fold
(
f64
::
NEG_INFINITY
,
|
a
,
&
b
|
a
.max
(
b
));
let
max_val
=
values
.iter
()
.fold
(
f64
::
NEG_INFINITY
,
|
a
,
&
b
|
a
.max
(
b
));
let
probabilities
=
if
min_val
==
max_val
{
let
probabilities
=
if
min_val
==
max_val
{
vec!
[
1.0
/
key
s
.len
()
as
f64
;
key
s
.len
()]
vec!
[
1.0
/
entrie
s
.len
()
as
f64
;
entrie
s
.len
()]
}
else
{
}
else
{
// Fused normalize -> negate -> scale -> exp, then normalize probabilities
// Fused normalize -> negate -> scale -> exp, then normalize probabilities
let
range
=
max_val
-
min_val
;
let
range
=
max_val
-
min_val
;
...
@@ -63,19 +79,16 @@ fn softmax_sample(
...
@@ -63,19 +79,16 @@ fn softmax_sample(
probs
probs
};
};
let
mut
rng
=
rand
::
rng
();
let
sample
:
f64
=
rng
.random
();
let
mut
cumsum
=
0.0
;
let
mut
cumsum
=
0.0
;
for
(
i
,
&
prob
)
in
probabilities
.iter
()
.enumerate
()
{
for
(
i
,
&
prob
)
in
probabilities
.iter
()
.enumerate
()
{
cumsum
+=
prob
;
cumsum
+=
prob
;
if
sample
<=
cumsum
{
if
sample
<=
cumsum
{
return
vec!
[
key
s
[
i
]
]
;
return
entrie
s
[
i
];
}
}
}
}
// Fallback to last key (shouldn't normally reach here)
// Fallback to last key (shouldn't normally reach here)
vec!
[
keys
[
key
s
.len
()
-
1
]
]
entries
[
entrie
s
.len
()
-
1
]
}
}
/// Default implementation matching the Python _cost_function.
/// Default implementation matching the Python _cost_function.
...
@@ -118,76 +131,92 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
...
@@ -118,76 +131,92 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
let
decode_blocks
=
&
request
.decode_blocks
;
let
decode_blocks
=
&
request
.decode_blocks
;
let
prefill_tokens
=
&
request
.prefill_tokens
;
let
prefill_tokens
=
&
request
.prefill_tokens
;
let
mut
worker_logits
=
HashMap
::
new
();
let
overlap_weight
=
request
let
overlap_weight
=
request
.router_config_override
.router_config_override
.as_ref
()
.as_ref
()
.and_then
(|
cfg
|
cfg
.overlap_score_weight
)
.and_then
(|
cfg
|
cfg
.overlap_score_weight
)
.unwrap_or
(
self
.kv_router_config.overlap_score_weight
);
.unwrap_or
(
self
.kv_router_config.overlap_score_weight
);
for
(
worker_id
,
config
)
in
workers
let
temperature
=
request
.iter
()
.router_config_override
.filter
(|(
wid
,
_
)|
allowed_ids
.is_none_or
(|
ids
|
ids
.contains
(
wid
)))
.as_ref
()
{
.and_then
(|
cfg
|
cfg
.router_temperature
)
let
data_parallel_size
=
config
.data_parallel_size
();
.unwrap_or
(
self
.kv_router_config.router_temperature
);
let
data_parallel_start_rank
=
config
.data_parallel_start_rank
();
for
dp_rank
in
data_parallel_start_rank
..
(
data_parallel_start_rank
+
data_parallel_size
)
let
get_score
=
|
worker
:
WorkerWithDpRank
|
->
f64
{
{
let
overlap
=
*
overlaps
.get
(
&
worker
)
.unwrap_or
(
&
0
);
let
worker
=
WorkerWithDpRank
::
new
(
*
worker_id
,
dp_rank
);
let
prefill_token
=
*
prefill_tokens
.get
(
&
worker
)
.unwrap_or
(
&
isl
);
let
potential_prefill_block
=
(
prefill_token
as
f64
)
/
(
block_size
as
f64
);
let
decode_block
=
*
decode_blocks
.get
(
&
worker
)
.unwrap_or
(
&
(
potential_prefill_block
.floor
()
as
usize
))
as
f64
;
let
overlap
=
*
overlap
s
.get
(
&
worker
)
.unwrap_or
(
&
0
)
;
let
logit
=
overlap
_weight
*
potential_prefill_block
+
decode_block
;
let
prefill_token
=
*
prefill_tokens
.get
(
&
worker
)
.unwrap_or
(
&
isl
);
tracing
::
debug!
(
let
potential_prefill_block
=
(
prefill_token
as
f64
)
/
(
block_size
as
f64
);
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3}
\
= {overlap_weight:.1} * prefill_blocks + decode_blocks
\
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
,
worker
.worker_id
,
worker
.dp_rank
);
let
decode_block
=
*
decode_blocks
logit
.get
(
&
worker
)
};
.unwrap_or
(
&
(
potential_prefill_block
.floor
()
as
usize
))
as
f64
;
let
logit
=
overlap_weight
*
potential_prefill_block
+
decode_block
;
let
worker_iter
=
workers
.iter
()
.filter
(
move
|(
wid
,
_
)|
allowed_ids
.is_none_or
(|
ids
|
ids
.contains
(
wid
)))
.flat_map
(|(
worker_id
,
config
)|
{
let
data_parallel_size
=
config
.data_parallel_size
();
let
data_parallel_start_rank
=
config
.data_parallel_start_rank
();
(
data_parallel_start_rank
..
(
data_parallel_start_rank
+
data_parallel_size
))
.map
(
move
|
dp_rank
|
WorkerWithDpRank
::
new
(
*
worker_id
,
dp_rank
))
});
worker_logits
.insert
(
worker
,
logit
);
let
(
best_worker
,
best_logit
)
=
if
temperature
==
0.0
{
let
mut
min_workers
=
Vec
::
new
();
let
mut
min_score
=
f64
::
INFINITY
;
for
worker
in
worker_iter
{
let
score
=
get_score
(
worker
);
if
score
<
min_score
{
min_workers
.clear
();
min_workers
.push
(
worker
);
min_score
=
score
;
}
else
if
score
==
min_score
{
min_workers
.push
(
worker
);
}
}
if
min_workers
.len
()
>
1
{
tracing
::
debug!
(
tracing
::
debug!
(
"Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3}
\
"Multiple workers tied with same logit, using tree size as tie-breaker"
= {overlap_weight:.1} * prefill_blocks + decode_blocks
\
= {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}"
,
worker
.worker_id
,
worker
.dp_rank
);
);
}
let
tree_sizes
:
Vec
<
(
usize
,
&
WorkerWithDpRank
)
>
=
min_workers
}
.iter
()
.map
(|
w
|
(
request
.overlaps.tree_sizes
.get
(
w
)
.copied
()
.unwrap_or
(
0
),
w
))
let
temperature
=
request
.collect
();
.router_config_override
.as_ref
()
if
tree_sizes
.iter
()
.all
(|(
s
,
_
)|
*
s
==
tree_sizes
[
0
]
.0
)
{
.and_then
(|
cfg
|
cfg
.router_temperature
)
let
idx
=
rand
::
rng
()
.random_range
(
0
..
min_workers
.len
());
.unwrap_or
(
self
.kv_router_config.router_temperature
);
(
min_workers
[
idx
],
min_score
)
let
candidates
=
softmax_sample
(
&
worker_logits
,
temperature
);
}
else
{
let
(
_
,
worker
)
=
*
tree_sizes
.iter
()
.min_by_key
(|(
s
,
_
)|
*
s
)
.unwrap
();
let
best_worker
=
if
candidates
.len
()
>
1
{
(
*
worker
,
min_score
)
tracing
::
debug!
(
}
"Multiple workers tied with same logit, using tree size as tie-breaker"
);
let
tree_sizes
:
Vec
<
(
usize
,
&
WorkerWithDpRank
)
>
=
candidates
.iter
()
.map
(|
w
|
(
request
.overlaps.tree_sizes
.get
(
w
)
.copied
()
.unwrap_or
(
0
),
w
))
.collect
();
if
tree_sizes
.iter
()
.all
(|(
s
,
_
)|
*
s
==
tree_sizes
[
0
]
.0
)
{
let
idx
=
rand
::
rng
()
.random_range
(
0
..
candidates
.len
());
candidates
[
idx
]
}
else
{
}
else
{
*
tree_sizes
.iter
()
.min_by_key
(|(
s
,
_
)|
*
s
)
.unwrap
()
.1
(
min_workers
[
0
],
min_score
)
}
}
}
else
{
}
else
{
candidates
[
0
]
let
mut
worker_logits
=
HashMap
::
new
();
};
for
worker
in
worker_iter
{
let
score
=
get_score
(
worker
);
worker_logits
.insert
(
worker
,
score
);
}
let
best_logit
=
worker_logits
[
&
best_worker
];
softmax_sample
(
&
worker_logits
,
temperature
)
};
if
self
.worker_type
==
"decode"
{
if
self
.worker_type
==
"decode"
{
tracing
::
info!
(
tracing
::
info!
(
...
@@ -246,31 +275,22 @@ mod tests {
...
@@ -246,31 +275,22 @@ mod tests {
fn
test_softmax_sample_single_key
()
{
fn
test_softmax_sample_single_key
()
{
let
mut
logits
=
HashMap
::
new
();
let
mut
logits
=
HashMap
::
new
();
let
worker
=
WorkerWithDpRank
::
from_worker_id
(
42
);
let
worker
=
WorkerWithDpRank
::
from_worker_id
(
42
);
logits
.insert
(
worker
,
0.5
);
for
(
logit
,
temperature
)
in
[
(
0.5
,
0.1
),
for
temperature
in
&
[
0.1
,
1.0
,
10.0
]
{
(
0.5
,
1.0
),
let
result
=
softmax_sample
(
&
logits
,
*
temperature
);
(
0.5
,
10.0
),
assert_eq!
(
result
.len
(),
1
,
"Should return exactly one worker"
);
(
-
100.0
,
1.0
),
assert_eq!
(
result
[
0
],
worker
,
"Should return the only available worker"
);
(
100.0
,
1.0
),
(
0.0
,
1.0
),
(
0.0
,
0.0
),
]
{
logits
.clear
();
logits
.insert
(
worker
,
logit
);
let
result
=
softmax_sample
(
&
logits
,
temperature
);
assert_eq!
(
result
.0
,
worker
,
"Should return the only available worker"
);
assert_eq!
(
result
.1
,
logit
,
"Should return the selected worker's logit"
);
}
}
logits
.clear
();
logits
.insert
(
worker
,
-
100.0
);
let
result
=
softmax_sample
(
&
logits
,
1.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker
);
logits
.clear
();
logits
.insert
(
worker
,
100.0
);
let
result
=
softmax_sample
(
&
logits
,
1.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker
);
logits
.clear
();
logits
.insert
(
worker
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
1.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
result
[
0
],
worker
);
}
}
#[test]
#[test]
...
@@ -287,13 +307,12 @@ mod tests {
...
@@ -287,13 +307,12 @@ mod tests {
let
result
=
softmax_sample
(
&
logits
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
assert_eq!
(
assert_eq!
(
result
.len
(),
result
.0
,
worker2
,
1
,
"Should return worker with smallest logit when temperature is 0"
"Should return one worker when there's no tie"
);
);
assert_eq!
(
assert_eq!
(
result
[
0
],
worker2
,
result
.1
,
3.0
,
"Should return
worker wi
th smallest logit when temperature is 0"
"Should return th
e
smallest logit when temperature is 0"
);
);
logits
.clear
();
logits
.clear
();
...
@@ -305,15 +324,11 @@ mod tests {
...
@@ -305,15 +324,11 @@ mod tests {
logits
.insert
(
worker6
,
7.0
);
logits
.insert
(
worker6
,
7.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
assert_eq!
(
result
.len
(),
2
,
"Should return all workers with smallest logit when tied"
);
assert
!
(
assert
!
(
result
.
contains
(
&
worker2
)
&&
result
.
contains
(
&
worker5
)
,
result
.
0
==
worker2
||
result
.
0
==
worker5
,
"Should
contain both
tied
w
or
kers
"
"Should
return one of the workers
tied
f
or
the smallest logit
"
);
);
assert_eq!
(
result
.1
,
3.0
,
"Should return the tied minimum logit"
);
logits
.clear
();
logits
.clear
();
let
worker10
=
WorkerWithDpRank
::
from_worker_id
(
10
);
let
worker10
=
WorkerWithDpRank
::
from_worker_id
(
10
);
...
@@ -324,10 +339,44 @@ mod tests {
...
@@ -324,10 +339,44 @@ mod tests {
logits
.insert
(
worker30
,
0.0
);
logits
.insert
(
worker30
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
let
result
=
softmax_sample
(
&
logits
,
0.0
);
assert_eq!
(
result
.len
(),
1
);
assert_eq!
(
assert_eq!
(
result
[
0
]
,
worker20
,
result
.0
,
worker20
,
"Should handle negative logits correctly"
"Should handle negative logits correctly"
);
);
assert_eq!
(
result
.1
,
-
5.0
,
"Should return the minimum negative logit"
);
}
/// Feeding a hand-picked sample into `softmax_sample_with_sample` must return
/// the worker whose probability interval contains that sample, together with
/// that worker's own logit.
#[test]
fn test_softmax_sample_with_sample_returns_selected_logit() {
    let logits = HashMap::from([
        (WorkerWithDpRank::from_worker_id(1), 0.0),
        (WorkerWithDpRank::from_worker_id(2), 3.0),
        (WorkerWithDpRank::from_worker_id(3), 9.0),
    ]);
    let temperature = 1.0;

    // Snapshot the map in its iteration order; the same unmodified map is
    // passed to the function under test below.
    let entries: Vec<_> = logits.iter().map(|(w, l)| (*w, *l)).collect();
    let raw_logits: Vec<f64> = entries.iter().map(|&(_, l)| l).collect();

    let lowest = raw_logits.iter().copied().fold(f64::INFINITY, f64::min);
    let highest = raw_logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let spread = highest - lowest;

    // Recompute the expected probabilities: normalize, negate, scale, exp.
    let scaled: Vec<f64> = raw_logits
        .iter()
        .map(|&l| -(l / spread) / temperature)
        .collect();
    let scaled_peak = scaled.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut probabilities: Vec<f64> = scaled.iter().map(|&s| (s - scaled_peak).exp()).collect();
    let total: f64 = probabilities.iter().sum();
    for p in &mut probabilities {
        *p /= total;
    }

    // Target the first non-minimum entry so that returning the minimum-logit
    // worker would fail the assertion.
    let target_idx = entries
        .iter()
        .position(|&(_, l)| l > lowest)
        .expect("expected at least one non-minimum logit");

    // Place the sample in the middle of the target's cumulative interval.
    let mass_before: f64 = probabilities[..target_idx].iter().sum();
    let sample = mass_before + probabilities[target_idx] / 2.0;

    let picked = softmax_sample_with_sample(&logits, temperature, sample);
    assert_eq!(picked, entries[target_idx]);
}
}
}
}
lib/llm/src/kv_router.rs
View file @
1368ccd6
...
@@ -429,19 +429,26 @@ where
...
@@ -429,19 +429,26 @@ where
}
}
let
isl_tokens
=
tokens
.len
();
let
isl_tokens
=
tokens
.len
();
let
hash_options
=
BlockHashOptions
{
block_mm_infos
,
lora_name
:
lora_name
.as_deref
(),
is_eagle
:
Some
(
self
.is_eagle
),
};
let
block_hashes
=
tracing
::
info_span!
(
"kv_router.compute_block_hashes"
)
.in_scope
(||
{
let
block_hashes
=
tracing
::
info_span!
(
"kv_router.compute_block_hashes"
)
compute_block_hash_for_seq
(
.in_scope
(||
compute_block_hash_for_seq
(
tokens
,
self
.block_size
,
hash_options
));
let
hash_elapsed
=
start
.elapsed
();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let
maybe_seq_hashes
=
tracing
::
info_span!
(
"kv_router.compute_seq_hashes"
)
.in_scope
(||
{
self
.kv_router_config
.compute_seq_hashes_for_tracking
(
tokens
,
tokens
,
self
.block_size
,
self
.block_size
,
BlockHashOptions
{
router_config_override
,
block_mm_infos
,
hash_options
,
lora_name
:
lora_name
.as_deref
(),
Some
(
&
block_hashes
),
is_eagle
:
Some
(
self
.is_eagle
),
},
)
)
});
});
let
hash_elapsed
=
start
.elapsed
();
let
seq_
hash_elapsed
=
start
.elapsed
();
let
overlap_scores
=
self
let
overlap_scores
=
self
.indexer
.indexer
...
@@ -450,21 +457,6 @@ where
...
@@ -450,21 +457,6 @@ where
.await
?
;
.await
?
;
let
find_matches_elapsed
=
start
.elapsed
();
let
find_matches_elapsed
=
start
.elapsed
();
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let
maybe_seq_hashes
=
tracing
::
info_span!
(
"kv_router.compute_seq_hashes"
)
.in_scope
(||
{
self
.kv_router_config
.compute_seq_hashes_for_tracking
(
tokens
,
self
.block_size
,
router_config_override
,
BlockHashOptions
{
block_mm_infos
,
lora_name
:
lora_name
.as_deref
(),
is_eagle
:
Some
(
self
.is_eagle
),
},
)
});
let
seq_hash_elapsed
=
start
.elapsed
();
let
response
=
self
let
response
=
self
.scheduler
.scheduler
.schedule
(
.schedule
(
...
@@ -486,8 +478,8 @@ where
...
@@ -486,8 +478,8 @@ where
if
let
Some
(
m
)
=
metrics
::
RoutingOverheadMetrics
::
get
()
{
if
let
Some
(
m
)
=
metrics
::
RoutingOverheadMetrics
::
get
()
{
m
.observe
(
m
.observe
(
hash_elapsed
,
hash_elapsed
,
find_matches_elapsed
,
seq_hash_elapsed
,
seq_hash_elapsed
,
find_matches_elapsed
,
total_elapsed
,
total_elapsed
,
);
);
}
}
...
@@ -496,9 +488,9 @@ where
...
@@ -496,9 +488,9 @@ where
tracing
::
info!
(
tracing
::
info!
(
isl_tokens
,
isl_tokens
,
hash_us
=
hash_elapsed
.as_micros
()
as
u64
,
hash_us
=
hash_elapsed
.as_micros
()
as
u64
,
find_matches_us
=
(
find_matches
_elapsed
-
hash_elapsed
)
.as_micros
()
as
u64
,
seq_hash_us
=
(
seq_hash
_elapsed
-
hash_elapsed
)
.as_micros
()
as
u64
,
seq_hash_us
=
(
seq_hash_elapsed
-
find_matches
_elapsed
)
.as_micros
()
as
u64
,
find_matches_us
=
(
find_matches_elapsed
-
seq_hash
_elapsed
)
.as_micros
()
as
u64
,
schedule_us
=
(
total_elapsed
-
seq_hash
_elapsed
)
.as_micros
()
as
u64
,
schedule_us
=
(
total_elapsed
-
find_matches
_elapsed
)
.as_micros
()
as
u64
,
total_us
=
total_elapsed
.as_micros
()
as
u64
,
total_us
=
total_elapsed
.as_micros
()
as
u64
,
"find_best_match completed"
"find_best_match completed"
);
);
...
@@ -524,16 +516,18 @@ where
...
@@ -524,16 +516,18 @@ where
router_config_override
:
Option
<&
RouterConfigOverride
>
,
router_config_override
:
Option
<&
RouterConfigOverride
>
,
)
{
)
{
let
isl_tokens
=
tokens
.len
();
let
isl_tokens
=
tokens
.len
();
let
hash_options
=
BlockHashOptions
{
block_mm_infos
,
lora_name
:
lora_name
.as_deref
(),
is_eagle
:
Some
(
self
.is_eagle
),
};
let
maybe_seq_hashes
=
self
.kv_router_config
.compute_seq_hashes_for_tracking
(
let
maybe_seq_hashes
=
self
.kv_router_config
.compute_seq_hashes_for_tracking
(
tokens
,
tokens
,
self
.block_size
,
self
.block_size
,
router_config_override
,
router_config_override
,
BlockHashOptions
{
hash_options
,
block_mm_infos
,
None
,
lora_name
:
lora_name
.as_deref
(),
is_eagle
:
Some
(
self
.is_eagle
),
},
);
);
if
let
Err
(
e
)
=
self
if
let
Err
(
e
)
=
self
...
@@ -615,28 +609,23 @@ where
...
@@ -615,28 +609,23 @@ where
lora_name
:
Option
<&
str
>
,
lora_name
:
Option
<&
str
>
,
)
->
Result
<
Vec
<
PotentialLoad
>>
{
)
->
Result
<
Vec
<
PotentialLoad
>>
{
let
isl_tokens
=
tokens
.len
();
let
isl_tokens
=
tokens
.len
();
let
block_hashes
=
compute_block_hash_for_seq
(
let
hash_options
=
BlockHashOptions
{
tokens
,
block_mm_infos
,
self
.block_size
,
lora_name
,
BlockHashOptions
{
is_eagle
:
Some
(
self
.is_eagle
),
block_mm_infos
,
};
lora_name
,
let
block_hashes
=
compute_block_hash_for_seq
(
tokens
,
self
.block_size
,
hash_options
);
is_eagle
:
Some
(
self
.is_eagle
),
},
);
let
overlap_scores
=
self
.indexer
.find_matches
(
block_hashes
.clone
())
.await
?
;
let
maybe_seq_hashes
=
self
.kv_router_config
.compute_seq_hashes_for_tracking
(
let
maybe_seq_hashes
=
self
.kv_router_config
.compute_seq_hashes_for_tracking
(
tokens
,
tokens
,
self
.block_size
,
self
.block_size
,
router_config_override
,
router_config_override
,
BlockHashOptions
{
hash_options
,
block_mm_infos
,
Some
(
&
block_hashes
),
lora_name
,
is_eagle
:
Some
(
self
.is_eagle
),
},
);
);
let
overlap_scores
=
self
.indexer
.find_matches
(
block_hashes
)
.await
?
;
Ok
(
self
Ok
(
self
.scheduler
.scheduler
.get_potential_loads
(
maybe_seq_hashes
,
isl_tokens
,
overlap_scores
))
.get_potential_loads
(
maybe_seq_hashes
,
isl_tokens
,
overlap_scores
))
...
...
lib/llm/src/kv_router/metrics.rs
View file @
1368ccd6
...
@@ -263,26 +263,26 @@ impl RoutingOverheadMetrics {
...
@@ -263,26 +263,26 @@ impl RoutingOverheadMetrics {
pub
fn
observe
(
pub
fn
observe
(
&
self
,
&
self
,
hash_elapsed
:
Duration
,
hash_elapsed
:
Duration
,
find_matches_elapsed
:
Duration
,
seq_hash_elapsed
:
Duration
,
seq_hash_elapsed
:
Duration
,
find_matches_elapsed
:
Duration
,
total_elapsed
:
Duration
,
total_elapsed
:
Duration
,
)
{
)
{
self
.block_hashing
self
.block_hashing
.observe
(
hash_elapsed
.as_secs_f64
()
*
1000.0
);
.observe
(
hash_elapsed
.as_secs_f64
()
*
1000.0
);
self
.seq_hashing
.observe
(
seq_hash_elapsed
.saturating_sub
(
hash_elapsed
)
.as_secs_f64
()
*
1000.0
);
self
.indexer_find_matches
.observe
(
self
.indexer_find_matches
.observe
(
find_matches_elapsed
find_matches_elapsed
.saturating_sub
(
hash_elapsed
)
.saturating_sub
(
seq_
hash_elapsed
)
.as_secs_f64
()
.as_secs_f64
()
*
1000.0
,
*
1000.0
,
);
);
self
.s
eq_hash
ing
.observe
(
self
.s
chedul
ing
.observe
(
seq_hash
_elapsed
total
_elapsed
.saturating_sub
(
find_matches_elapsed
)
.saturating_sub
(
find_matches_elapsed
)
.as_secs_f64
()
.as_secs_f64
()
*
1000.0
,
*
1000.0
,
);
);
self
.scheduling
.observe
(
total_elapsed
.saturating_sub
(
seq_hash_elapsed
)
.as_secs_f64
()
*
1000.0
);
self
.total
.observe
(
total_elapsed
.as_secs_f64
()
*
1000.0
);
self
.total
.observe
(
total_elapsed
.as_secs_f64
()
*
1000.0
);
}
}
}
}
...
@@ -557,7 +557,7 @@ dynamo_frontend_router_queue_pending_requests{worker_type=\"decode\"} 5
...
@@ -557,7 +557,7 @@ dynamo_frontend_router_queue_pending_requests{worker_type=\"decode\"} 5
total
:
make
(
"test_total_ms"
),
total
:
make
(
"test_total_ms"
),
};
};
// Out-of-order durations: each phase < previous (would panic without saturating_sub)
// Out-of-order
cumulative
durations: each phase < previous (would panic without saturating_sub)
metrics
.observe
(
metrics
.observe
(
Duration
::
from_millis
(
10
),
Duration
::
from_millis
(
10
),
Duration
::
from_millis
(
5
),
Duration
::
from_millis
(
5
),
...
...
lib/llm/src/kv_router/worker_query.rs
View file @
1368ccd6
...
@@ -518,8 +518,8 @@ impl WorkerQueryClient {
...
@@ -518,8 +518,8 @@ impl WorkerQueryClient {
events
.len
(),
events
.len
(),
last_event_id
last_event_id
);
);
for
event
in
&
events
{
for
event
in
events
{
self
.indexer
.apply_event
(
event
.clone
()
)
.await
;
self
.indexer
.apply_event
(
event
)
.await
;
}
}
new_cursor
=
new_cursor
.advance_to
(
last_event_id
);
new_cursor
=
new_cursor
.advance_to
(
last_event_id
);
successful_response
=
true
;
successful_response
=
true
;
...
...
lib/mocker/src/replay/router/offline.rs
View file @
1368ccd6
...
@@ -348,6 +348,7 @@ impl OfflineReplayRouter {
...
@@ -348,6 +348,7 @@ impl OfflineReplayRouter {
self
.block_size
,
self
.block_size
,
None
,
None
,
BlockHashOptions
::
default
(),
BlockHashOptions
::
default
(),
None
,
)
)
};
};
(
overlaps
,
token_seq
)
(
overlaps
,
token_seq
)
...
@@ -359,6 +360,7 @@ impl OfflineReplayRouter {
...
@@ -359,6 +360,7 @@ impl OfflineReplayRouter {
self
.block_size
,
self
.block_size
,
None
,
None
,
BlockHashOptions
::
default
(),
BlockHashOptions
::
default
(),
None
,
);
);
(
overlaps
,
token_seq
)
(
overlaps
,
token_seq
)
}
}
...
...
lib/mocker/src/replay/router/online.rs
View file @
1368ccd6
...
@@ -188,6 +188,7 @@ impl KvReplayRouter {
...
@@ -188,6 +188,7 @@ impl KvReplayRouter {
self
.block_size
,
self
.block_size
,
None
,
None
,
BlockHashOptions
::
default
(),
BlockHashOptions
::
default
(),
None
,
);
);
let
response
=
self
let
response
=
self
.scheduler
.scheduler
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment