Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2be5e8f5
"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "093efb9a81af77aac0f396b157f11cd5a197fa74"
Unverified
Commit
2be5e8f5
authored
Apr 30, 2025
by
Yan Ru Pei
Committed by
GitHub
Apr 30, 2025
Browse files
chore: reduce code repetition in processor (#919)
parent
0086ebc6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
87 additions
and
91 deletions
+87
-91
examples/llm/components/kv_router.py
examples/llm/components/kv_router.py
+64
-10
examples/llm/components/processor.py
examples/llm/components/processor.py
+21
-79
examples/llm/configs/agg_router.yaml
examples/llm/configs/agg_router.yaml
+1
-1
examples/llm/configs/disagg_router.yaml
examples/llm/configs/disagg_router.yaml
+1
-1
No files found.
examples/llm/components/kv_router.py
View file @
2be5e8f5
...
...
@@ -18,17 +18,19 @@ import argparse
import
logging
import
random
from
argparse
import
Namespace
from
typing
import
AsyncIterator
from
typing
import
AsyncIterator
,
Tuple
from
components.worker
import
VllmWorker
from
utils.logging
import
check_required_workers
from
utils.protocol
import
Tokens
from
utils.vllm
import
RouterType
from
dynamo.llm
import
AggregatedMetrics
,
KvIndexer
,
KvMetricsAggregator
,
OverlapScores
from
dynamo.sdk
import
async_on_start
,
depends
,
dynamo_context
,
dynamo_endpoint
,
service
from
dynamo.sdk.lib.config
import
ServiceConfig
WorkerId
=
str
fallback_msg
=
"Will fallback to random routing."
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -60,6 +62,12 @@ def parse_args(service_name, prefix) -> Namespace:
default
=
False
,
help
=
"Whether to use custom router or not"
,
)
parser
.
add_argument
(
"--router"
,
type
=
str
,
default
=
"kv"
,
help
=
"The router type"
,
)
config
=
ServiceConfig
.
get_instance
()
config_args
=
config
.
as_args
(
service_name
,
prefix
=
prefix
)
args
=
parser
.
parse_args
(
config_args
)
...
...
@@ -101,11 +109,14 @@ class Router:
.
client
()
)
self
.
router_type
=
self
.
args
.
router
await
check_required_workers
(
self
.
workers_client
,
self
.
args
.
min_workers
)
kv_listener
=
self
.
runtime
.
namespace
(
"dynamo"
).
component
(
"VllmWorker"
)
await
kv_listener
.
create_service
()
self
.
indexer
=
KvIndexer
(
kv_listener
,
self
.
args
.
block_size
)
if
self
.
router_type
==
RouterType
.
KV
:
self
.
indexer
=
KvIndexer
(
kv_listener
,
self
.
args
.
block_size
)
self
.
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
logger
.
info
(
"KV Router initialized"
)
...
...
@@ -182,7 +193,8 @@ class Router:
f
"Formula for
{
worker_id
}
:
{
worker_logits
[
worker_id
]:.
3
f
}
= 2.0 *
{
score
:.
3
f
}
-
{
gpu_cache_usage
:.
3
f
}
-
{
normalized_waiting
:.
3
f
}
"
)
if
not
worker_logits
or
all
(
logit
==
0
for
logit
in
worker_logits
.
values
()):
if
not
worker_logits
or
not
any
(
worker_logits
.
values
()):
logger
.
warning
(
f
"All worker logits are zero.
{
fallback_msg
}
."
)
return
""
,
0.0
# Select the worker with the highest logit
...
...
@@ -211,8 +223,47 @@ class Router:
return
best_worker_id
,
worker_scores
.
get
(
best_worker_id
,
0.0
)
def
_get_underloaded_worker
(
self
,
metrics
:
AggregatedMetrics
|
None
):
if
not
metrics
:
logger
.
warning
(
f
"Cannot get metrics.
{
fallback_msg
}
"
)
return
""
,
0.0
kv_load
=
{
endpoint
.
worker_id
:
getattr
(
endpoint
,
"gpu_cache_usage_perc"
,
0.0
)
for
endpoint
in
metrics
.
endpoints
}
if
not
kv_load
or
not
any
(
kv_load
.
values
()):
logger
.
warning
(
f
"All KV loads are zero.
{
fallback_msg
}
"
)
return
""
,
0.0
min_load
=
min
(
kv_load
.
values
())
min_load_workers
=
[
worker_id
for
worker_id
,
load
in
kv_load
.
items
()
if
load
==
min_load
]
best_worker_id
=
random
.
choice
(
min_load_workers
)
logger
.
info
(
f
"Selected worker:
{
best_worker_id
}
, KV load:
{
kv_load
[
best_worker_id
]:.
3
f
}
"
)
return
best_worker_id
,
kv_load
[
best_worker_id
]
@
dynamo_endpoint
()
async
def
generate
(
self
,
request
:
Tokens
)
->
AsyncIterator
[
WorkerId
]:
async
def
generate
(
self
,
request
:
Tokens
)
->
AsyncIterator
[
Tuple
[
WorkerId
,
float
]]:
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
# Quick return for KV_LOAD mode
if
self
.
router_type
==
RouterType
.
KV_LOAD
:
try
:
yield
self
.
_get_underloaded_worker
(
metrics
)
except
Exception
as
e
:
logger
.
exception
(
f
"Error finding underloaded worker:
{
e
}
.
{
fallback_msg
}
"
)
yield
""
,
0.0
return
# Existing KV routing logic
lora_id
=
0
try
:
scores
=
await
self
.
indexer
.
find_matches_for_request
(
...
...
@@ -220,14 +271,17 @@ class Router:
)
except
Exception
as
e
:
scores
=
{}
logger
.
exception
(
f
"Error finding matches:
{
e
}
"
)
logger
.
exception
(
f
"Error finding matches:
{
e
}
.
{
fallback_msg
}
"
)
yield
""
,
0.0
return
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
worker_id
,
prefix_hit_rate
=
self
.
_cost_function
(
scores
,
metrics
,
len
(
request
.
tokens
)
)
logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
)
yield
f
"
{
worker_id
}
_
{
prefix_hit_rate
}
"
if
worker_id
:
logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
)
yield
worker_id
,
prefix_hit_rate
examples/llm/components/processor.py
View file @
2be5e8f5
...
...
@@ -95,7 +95,8 @@ class Processor(ProcessMixIn):
.
client
()
)
if
self
.
engine_args
.
router
==
RouterType
.
KV
:
self
.
use_router
=
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
)
if
self
.
use_router
:
router_ns
,
router_name
=
Router
.
dynamo_address
()
# type: ignore
self
.
router_client
=
(
await
runtime
.
namespace
(
router_ns
)
...
...
@@ -116,22 +117,6 @@ class Processor(ProcessMixIn):
{
"router"
:
self
.
engine_args
.
router
},
)
async
def
_get_kv_load
(
self
):
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
kv_load
=
{}
for
endpoint
in
metrics
.
endpoints
:
worker_id
=
endpoint
.
worker_id
kv_load
[
worker_id
]
=
getattr
(
endpoint
,
"gpu_cache_usage_perc"
,
0.0
)
return
kv_load
async
def
_get_pending_requests
(
self
):
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
pending_requests
=
{}
for
endpoint
in
metrics
.
endpoints
:
worker_id
=
endpoint
.
worker_id
pending_requests
[
worker_id
]
=
getattr
(
endpoint
,
"num_requests_waiting"
,
0
)
return
pending_requests
async
def
_generate
(
self
,
raw_request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
...
...
@@ -146,81 +131,38 @@ class Processor(ProcessMixIn):
engine_prompt
,
sampling_params
,
)
=
await
self
.
_parse_raw_request
(
raw_request
)
# TODO: queue request at processor when engines are full
router_mode
=
(
await
self
.
etcd_kv_cache
.
get
(
"router"
)).
decode
()
if
router_mode
==
RouterType
.
KV
:
prefix_hit_rate
=
0.0
if
self
.
use_router
:
router_generator
=
await
self
.
router_client
.
generate
(
Tokens
(
tokens
=
engine_prompt
[
"prompt_token_ids"
]).
model_dump_json
()
)
decision
=
await
router_generator
.
__anext__
()
decision
=
decision
.
data
()
worker_id
,
prefix_hit_rate
=
decision
.
split
(
"_"
)
worker_id
,
prefix_hit_rate
=
decision
.
data
()
prefix_hit_rate
=
float
(
prefix_hit_rate
)
logger
.
info
(
f
"Worker ID:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
)
# Create request object once with default prefix_hit_rate
request_obj
=
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
prefix_hit_rate
=
prefix_hit_rate
,
).
model_dump_json
()
if
self
.
use_router
:
if
worker_id
==
""
:
engine_generator
=
await
self
.
worker_client
.
generate
(
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
prefix_hit_rate
=
prefix_hit_rate
,
).
model_dump_json
()
)
engine_generator
=
await
self
.
worker_client
.
generate
(
request_obj
)
else
:
engine_generator
=
await
self
.
worker_client
.
direct
(
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
prefix_hit_rate
=
prefix_hit_rate
,
).
model_dump_json
(),
int
(
worker_id
),
request_obj
,
int
(
worker_id
)
)
elif
router_mode
==
RouterType
.
RANDOM
:
engine_generator
=
await
self
.
worker_client
.
generate
(
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
).
model_dump_json
()
)
engine_generator
=
await
self
.
worker_client
.
generate
(
request_obj
)
elif
router_mode
==
RouterType
.
ROUND_ROBIN
:
engine_generator
=
await
self
.
worker_client
.
round_robin
(
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
).
model_dump_json
()
)
elif
router_mode
==
RouterType
.
KV_LOAD
:
# route to worker with least kv load
# TODO: move the router to a separate file and clean up processor.py
try
:
kv_load
=
await
self
.
_get_kv_load
()
best_worker_id
=
min
(
kv_load
,
key
=
kv_load
.
get
)
logger
.
info
(
f
"Routing to worker
{
best_worker_id
}
(kv load:
{
kv_load
}
)"
)
engine_generator
=
await
self
.
worker_client
.
direct
(
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
).
model_dump_json
(),
int
(
best_worker_id
),
)
except
Exception
as
e
:
logger
.
info
(
f
"Error finding worker with least kv load:
{
e
}
, fallback to random"
)
engine_generator
=
await
self
.
worker_client
.
generate
(
vLLMGenerateRequest
(
engine_prompt
=
engine_prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
).
model_dump_json
()
)
engine_generator
=
await
self
.
worker_client
.
round_robin
(
request_obj
)
output
=
self
.
_generate_responses
(
engine_generator
,
request_type
)
async
for
response
in
await
self
.
_stream_response
(
...
...
examples/llm/configs/agg_router.yaml
View file @
2be5e8f5
...
...
@@ -29,7 +29,7 @@ Processor:
Router
:
min-workers
:
1
common-configs
:
[
model
]
common-configs
:
[
model
,
router
]
VllmWorker
:
enforce-eager
:
true
...
...
examples/llm/configs/disagg_router.yaml
View file @
2be5e8f5
...
...
@@ -29,7 +29,7 @@ Processor:
Router
:
min-workers
:
1
common-configs
:
[
model
]
common-configs
:
[
model
,
router
]
VllmWorker
:
max-num-batched-tokens
:
16384
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment