Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
d29f7fcc
Commit
d29f7fcc
authored
Mar 24, 2025
by
Hongkuan Zhou
Committed by
GitHub
Mar 24, 2025
Browse files
feat: conditional disagg based on prefill queue size (#303)
parent
d7165149
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
45 additions
and
4 deletions
+45
-4
examples/llm/components/disagg_router.py
examples/llm/components/disagg_router.py
+12
-3
examples/llm/components/worker.py
examples/llm/components/worker.py
+10
-1
examples/llm/configs/disagg.yaml
examples/llm/configs/disagg.yaml
+1
-0
examples/llm/configs/disagg_router.yaml
examples/llm/configs/disagg_router.yaml
+2
-0
examples/llm/utils/nats_queue.py
examples/llm/utils/nats_queue.py
+13
-0
examples/llm/utils/vllm.py
examples/llm/utils/vllm.py
+7
-0
No files found.
examples/llm/components/disagg_router.py
View file @
d29f7fcc
...
@@ -23,14 +23,23 @@ class PyDisaggregatedRouter:
...
@@ -23,14 +23,23 @@ class PyDisaggregatedRouter:
runtime
,
runtime
,
served_model_name
,
served_model_name
,
max_local_prefill_length
=
1000
,
max_local_prefill_length
=
1000
,
max_prefill_queue_size
=
2
,
):
):
self
.
runtime
=
runtime
self
.
runtime
=
runtime
self
.
served_model_name
=
served_model_name
self
.
served_model_name
=
served_model_name
self
.
max_local_prefill_length
=
max_local_prefill_length
self
.
max_local_prefill_length
=
max_local_prefill_length
self
.
max_prefill_queue_size
=
max_prefill_queue_size
def prefill_remote(self, prompt_length: int, prefix_hit_rate: float, queue_size: int) -> bool:
    """Decide whether a request's prefill should run on a remote prefill worker.

    Prefill goes remote only when the effective (non-cached) prefill length
    exceeds ``max_local_prefill_length`` AND the remote prefill queue still
    has capacity (``queue_size < max_prefill_queue_size``).

    Args:
        prompt_length: Total number of prompt tokens in the request.
        prefix_hit_rate: Fraction of the prompt already covered by the
            prefix cache (0.0-1.0); only the miss portion must be prefilled.
        queue_size: Current number of requests waiting in the remote
            prefill queue.

    Returns:
        True if the request should be prefilled remotely, False otherwise.
    """
    # Tokens that actually need prefilling after prefix-cache hits.
    absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
    # TODO: consider size of each request in the queue when making the decision
    decision = (
        absolute_prefill_length > self.max_local_prefill_length
        and queue_size < self.max_prefill_queue_size
    )
    # Lazy %-style arguments defer formatting until the record is actually
    # emitted (avoids f-string work on every call when INFO is disabled).
    vllm_logger.info(
        "Remote prefill: %s (prefill length: %s / %s, prefill queue size: %s / %s)",
        decision,
        absolute_prefill_length,
        prompt_length,
        queue_size,
        self.max_prefill_queue_size,
    )
    return decision
examples/llm/components/worker.py
View file @
d29f7fcc
...
@@ -125,6 +125,7 @@ class VllmWorker:
...
@@ -125,6 +125,7 @@ class VllmWorker:
runtime
,
runtime
,
self
.
model_name
,
self
.
model_name
,
max_local_prefill_length
=
self
.
engine_args
.
max_local_prefill_length
,
max_local_prefill_length
=
self
.
engine_args
.
max_local_prefill_length
,
max_prefill_queue_size
=
self
.
engine_args
.
max_prefill_queue_size
,
)
)
else
:
else
:
self
.
disaggregated_router
=
None
self
.
disaggregated_router
=
None
...
@@ -148,9 +149,17 @@ class VllmWorker:
...
@@ -148,9 +149,17 @@ class VllmWorker:
@
dynamo_endpoint
()
@
dynamo_endpoint
()
async
def
generate
(
self
,
request
:
vLLMGenerateRequest
):
async
def
generate
(
self
,
request
:
vLLMGenerateRequest
):
# TODO: consider prefix hit when deciding prefill locally or remotely
# TODO: consider prefix hit when deciding prefill locally or remotely
if
self
.
disaggregated_router
is
not
None
:
if
self
.
disaggregated_router
is
not
None
:
async
with
PrefillQueue
.
get_instance
(
nats_server
=
self
.
_prefill_queue_nats_server
,
stream_name
=
self
.
_prefill_queue_stream_name
,
)
as
prefill_queue
:
prefill_queue_size
=
await
prefill_queue
.
get_queue_size
()
disagg_router_decision
=
self
.
disaggregated_router
.
prefill_remote
(
disagg_router_decision
=
self
.
disaggregated_router
.
prefill_remote
(
len
(
request
.
engine_prompt
[
"prompt_token_ids"
]),
request
.
prefix_hit_rate
len
(
request
.
engine_prompt
[
"prompt_token_ids"
]),
request
.
prefix_hit_rate
,
prefill_queue_size
,
)
)
else
:
else
:
# always prefill remotely if no disaggregated router is provided
# always prefill remotely if no disaggregated router is provided
...
...
examples/llm/configs/disagg.yaml
View file @
d29f7fcc
...
@@ -30,6 +30,7 @@ VllmWorker:
...
@@ -30,6 +30,7 @@ VllmWorker:
remote-prefill
:
true
remote-prefill
:
true
conditional-disagg
:
true
conditional-disagg
:
true
max-local-prefill-length
:
10
max-local-prefill-length
:
10
max-prefill-queue-size
:
2
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
resources
:
resources
:
...
...
examples/llm/configs/disagg_router.yaml
View file @
d29f7fcc
...
@@ -36,6 +36,8 @@ VllmWorker:
...
@@ -36,6 +36,8 @@ VllmWorker:
max-model-len
:
16384
max-model-len
:
16384
max-num-batched-tokens
:
16384
max-num-batched-tokens
:
16384
conditional-disagg
:
true
conditional-disagg
:
true
max-local-prefill-length
:
10
max-prefill-queue-size
:
2
tensor-parallel-size
:
1
tensor-parallel-size
:
1
router
:
kv
router
:
kv
enable-prefix-caching
:
true
enable-prefix-caching
:
true
...
...
examples/llm/utils/nats_queue.py
View file @
d29f7fcc
...
@@ -140,3 +140,16 @@ class NATSQueue:
...
@@ -140,3 +140,16 @@ class NATSQueue:
return
None
return
None
except
NatsError
as
e
:
except
NatsError
as
e
:
raise
RuntimeError
(
f
"Failed to dequeue task:
{
e
}
"
)
raise
RuntimeError
(
f
"Failed to dequeue task:
{
e
}
"
)
async def get_queue_size(self) -> int:
    """Return the number of messages currently pending in the queue.

    Queries JetStream consumer info for the "worker-group" durable consumer
    on this queue's stream and reports its pending-message count, i.e. the
    real-time queue depth.

    Returns:
        Number of messages not yet delivered to the consumer.

    Raises:
        RuntimeError: If the NATS consumer-info request fails; the original
            ``NatsError`` is attached as the cause.
    """
    await self.ensure_connection()
    try:
        # Consumer info carries the pending-message count for the stream.
        consumer_info = await self._js.consumer_info(  # type: ignore
            self._stream_name, "worker-group"
        )
        # num_pending = undelivered messages, i.e. the current queue size.
        return consumer_info.num_pending
    except NatsError as e:
        # Chain with `from e` so the root NATS failure stays visible in
        # tracebacks instead of being masked by the RuntimeError.
        raise RuntimeError(f"Failed to get queue size: {e}") from e
examples/llm/utils/vllm.py
View file @
d29f7fcc
...
@@ -45,6 +45,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
...
@@ -45,6 +45,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default
=
1000
,
default
=
1000
,
help
=
"Maximum length of local prefill"
,
help
=
"Maximum length of local prefill"
,
)
)
parser
.
add_argument
(
"--max-prefill-queue-size"
,
type
=
int
,
default
=
3
,
help
=
"Do not send remote prefill requests (prefill locally) if the queue size is greater than this value"
,
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
(
vllm_args
)
args
=
parser
.
parse_args
(
vllm_args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
...
@@ -52,4 +58,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
...
@@ -52,4 +58,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
engine_args
.
remote_prefill
=
args
.
remote_prefill
engine_args
.
remote_prefill
=
args
.
remote_prefill
engine_args
.
conditional_disagg
=
args
.
conditional_disagg
engine_args
.
conditional_disagg
=
args
.
conditional_disagg
engine_args
.
max_local_prefill_length
=
args
.
max_local_prefill_length
engine_args
.
max_local_prefill_length
=
args
.
max_local_prefill_length
engine_args
.
max_prefill_queue_size
=
args
.
max_prefill_queue_size
return
engine_args
return
engine_args
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment