Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
a544d823
Commit
a544d823
authored
Mar 25, 2025
by
Yan Ru Pei
Committed by
GitHub
Mar 25, 2025
Browse files
chore: more Pythonic kv router cleanups in examples (#396)
parent
cce0c028
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
48 deletions
+29
-48
examples/llm/components/kv_router.py
examples/llm/components/kv_router.py
+29
-48
No files found.
examples/llm/components/kv_router.py
View file @
a544d823
...
...
@@ -83,6 +83,12 @@ class Router:
vllm_logger
.
info
(
"Initializing Custom Router"
)
self
.
args
=
parse_args
(
self
.
__class__
.
__name__
,
""
)
self
.
default_metrics
=
{
"gpu_cache_usage_perc"
:
0.0
,
"num_requests_waiting"
:
0.0
,
"gpu_prefix_cache_hit_rate"
:
0.0
,
}
@
async_on_start
async
def
async_init
(
self
):
self
.
runtime
=
dynamo_context
[
"runtime"
]
...
...
@@ -140,21 +146,13 @@ class Router:
)
worker_metrics
=
{}
# pull metrics for each worker
max_waiting
=
0.0
if
metrics
:
for
endpoint
in
metrics
.
endpoints
:
worker_id
=
endpoint
.
worker_id
worker_metrics
[
worker_id
]
=
{
"gpu_cache_usage_perc"
:
getattr
(
endpoint
,
"gpu_cache_usage_perc"
,
0.0
),
"num_requests_waiting"
:
getattr
(
endpoint
,
"num_requests_waiting"
,
0.0
),
"gpu_prefix_cache_hit_rate"
:
getattr
(
endpoint
,
"gpu_prefix_cache_hit_rate"
,
0.0
),
key
:
getattr
(
endpoint
,
key
,
self
.
default_metrics
[
key
])
for
key
in
self
.
default_metrics
.
keys
()
}
max_waiting
=
max
(
max_waiting
,
worker_metrics
[
worker_id
][
"num_requests_waiting"
]
...
...
@@ -168,14 +166,8 @@ class Router:
for
worker_id
in
worker_ids
:
# Use default values if worker not in scores or metrics
score
=
worker_scores
.
get
(
worker_id
,
0.0
)
metrics_dict
=
worker_metrics
.
get
(
worker_id
,
{
"gpu_cache_usage_perc"
:
0.0
,
"num_requests_waiting"
:
0.0
,
"gpu_prefix_cache_hit_rate"
:
0.0
,
},
)
metrics_dict
=
worker_metrics
.
get
(
worker_id
,
self
.
default_metrics
)
gpu_cache_usage
=
metrics_dict
[
"gpu_cache_usage_perc"
]
normalized_waiting
=
(
metrics_dict
[
"num_requests_waiting"
]
/
max_waiting
...
...
@@ -185,15 +177,13 @@ class Router:
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits
[
worker_id
]
=
(
2
*
score
-
metrics_dict
[
"gpu_cache_usage_perc"
]
-
normalized_waiting
)
worker_logits
[
worker_id
]
=
2
*
score
-
gpu_cache_usage
-
normalized_waiting
vllm_logger
.
info
(
f
"Formula for
{
worker_id
}
:
{
worker_logits
[
worker_id
]:.
3
f
}
= 2.0 *
{
score
:.
3
f
}
-
{
metrics_dict
[
'
gpu_cache_usage
_perc'
]
:.
3
f
}
-
{
normalized_waiting
:.
3
f
}
"
f
"Formula for
{
worker_id
}
:
{
worker_logits
[
worker_id
]:.
3
f
}
= 2.0 *
{
score
:.
3
f
}
-
{
gpu_cache_usage
:.
3
f
}
-
{
normalized_waiting
:.
3
f
}
"
)
if
not
worker_logits
or
all
(
logit
==
0
for
logit
in
worker_logits
.
values
()):
return
""
return
""
,
0.0
# Select the worker with the highest logit
max_logit
=
max
(
worker_logits
.
values
())
...
...
@@ -204,30 +194,26 @@ class Router:
# Log the metrics for the selected worker
if
best_worker_id
:
vllm_logger
.
info
(
f
"Selected worker:
{
best_worker_id
}
, logit:
{
worker_logits
[
best_worker_id
]:.
3
f
}
"
)
vllm_logger
.
info
(
f
"Score:
{
scores
.
scores
.
get
(
best_worker_id
,
0.0
)
if
scores
else
0.0
:.
3
f
}
"
)
metrics_dict
=
worker_metrics
.
get
(
best_worker_id
,
self
.
default_metrics
)
metrics_dict
=
worker_metrics
.
get
(
best_worker_id
,
{})
vllm_logger
.
info
(
f
"GPU Cache Hit Rate:
{
metrics_dict
.
get
(
'gpu_prefix_cache_hit_rate'
,
0.0
):.
3
f
}
"
)
vllm_logger
.
info
(
f
"GPU Cache Usage:
{
metrics_dict
.
get
(
'gpu_cache_usage_perc'
,
0.0
):.
3
f
}
"
)
vllm_logger
.
info
(
f
"Requests Waiting:
{
metrics_dict
.
get
(
'num_requests_waiting'
,
0.0
)
/
max_waiting
if
max_waiting
>
0
else
0.0
:.
3
f
}
"
)
# Create log messages
log_messages
=
[
f
"Selected worker:
{
best_worker_id
}
, logit:
{
worker_logits
[
best_worker_id
]:.
3
f
}
"
,
f
"Score:
{
scores
.
scores
.
get
(
best_worker_id
,
0.0
)
if
scores
else
0.0
:.
3
f
}
"
,
f
"GPU Cache Hit Rate:
{
metrics_dict
[
'gpu_prefix_cache_hit_rate'
]:.
3
f
}
"
,
f
"GPU Cache Usage:
{
metrics_dict
[
'gpu_cache_usage_perc'
]:.
3
f
}
"
,
f
"Requests Waiting:
{
metrics_dict
[
'num_requests_waiting'
]
}
"
,
]
# Log to vllm_logger
for
message
in
log_messages
:
vllm_logger
.
info
(
message
)
return
best_worker_id
,
worker_scores
.
get
(
best_worker_id
,
0.0
)
@
dynamo_endpoint
()
async
def
generate
(
self
,
request
:
Tokens
)
->
AsyncIterator
[
WorkerId
]:
lora_id
=
0
worker_id
=
""
try
:
scores
=
await
self
.
indexer
.
find_matches_for_request
(
request
.
tokens
,
lora_id
...
...
@@ -236,17 +222,12 @@ class Router:
scores
=
{}
vllm_logger
.
exception
(
f
"Error finding matches:
{
e
}
"
)
token_length
=
len
(
request
.
tokens
)
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
schedule_result
=
self
.
_cost_function
(
scores
,
metrics
,
token_length
)
if
schedule_result
==
""
:
worker_id
=
""
prefix_hit_rate
=
0.0
else
:
worker_id
,
prefix_hit_rate
=
schedule_result
worker_id
,
prefix_hit_rate
=
self
.
_cost_function
(
scores
,
metrics
,
len
(
request
.
tokens
)
)
vllm_logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
)
yield
f
"
{
worker_id
}
_
{
prefix_hit_rate
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment