Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8dcc159
Unverified
Commit
c8dcc159
authored
Jun 04, 2025
by
jmswen
Committed by
GitHub
Jun 04, 2025
Browse files
Allow AsyncLLMEngine.generate to target a specific DP rank (#19102)
Signed-off-by:
Jon Swenson
<
jmswen@gmail.com
>
parent
8f4ffbd3
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
97 additions
and
5 deletions
+97
-5
examples/online_serving/multi_instance_data_parallel.py
examples/online_serving/multi_instance_data_parallel.py
+58
-0
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+2
-1
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+1
-0
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+1
-0
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+5
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+11
-1
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+1
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+4
-1
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+12
-2
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+2
-0
No files found.
examples/online_serving/multi_instance_data_parallel.py
0 → 100644
View file @
c8dcc159
# SPDX-License-Identifier: Apache-2.0
import
asyncio
from
typing
import
Optional
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
"""
To run this example, run the following commands simultaneously with
different CUDA_VISIBLE_DEVICES:
python examples/online_serving/multi_instance_data_parallel.py
vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1
\
--data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300
\
--data-parallel-size-local 1 --enforce-eager --headless
Once both instances have completed the handshake, this example will
send a request to the instance with DP rank 1.
"""
async def main():
    """Send a single prompt to the engine at data-parallel rank 1.

    Builds an ``AsyncLLMEngine`` configured for a two-way data-parallel
    deployment reachable at 127.0.0.1:62300, submits one generation
    request pinned to ``data_parallel_rank=1``, drains the resulting
    output stream, and prints the text of the first completion of the
    last output received (nothing is printed if no output arrived).
    """
    # Engine configuration: two DP ranks in total, one of them local.
    dp_engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
    )
    client = AsyncLLMEngine.from_engine_args(dp_engine_args)

    decoding = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=100)
    question = "Who won the 2004 World Series?"

    # Only the most recent item from the stream is retained.
    # NOTE(review): presumably the last item carries the finished
    # completion -- confirm against the engine's output semantics.
    last_seen: Optional[RequestOutput] = None
    stream = client.generate(
        prompt=question,
        sampling_params=decoding,
        request_id="abcdef",
        data_parallel_rank=1,  # route explicitly to DP rank 1
    )
    async for chunk in stream:
        last_seen = chunk

    if not last_seen:
        return
    print(last_seen.outputs[0].text)
# Script entry point: drive the async example to completion on a fresh
# event loop when this file is executed directly (not when imported).
if __name__ == "__main__":
    asyncio.run(main())
tests/tokenization/test_detokenize.py
View file @
c8dcc159
...
...
@@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
None
,
0.0
,
None
,
cache_salt
=
None
)
cache_salt
=
None
,
data_parallel_rank
=
None
)
if
fast
is
None
:
detokenizer
=
IncrementalDetokenizer
.
from_new_request
(
...
...
tests/v1/engine/test_engine_core.py
View file @
c8dcc159
...
...
@@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
arrival_time
=
time
.
time
(),
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
)
...
...
tests/v1/engine/test_engine_core_client.py
View file @
c8dcc159
...
...
@@ -56,6 +56,7 @@ def make_request(
arrival_time
=
time
.
time
(),
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
)
...
...
tests/v1/engine/test_output_processor.py
View file @
c8dcc159
...
...
@@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
eos_token_id
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
False
,
spaces_between_special_tokens
=
False
,
...
...
@@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
eos_token_id
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
False
,
spaces_between_special_tokens
=
False
,
...
...
@@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
eos_token_id
=
eos_token_id
,
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
False
,
spaces_between_special_tokens
=
False
,
...
...
@@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
eos_token_id
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
sampling_params
=
SamplingParams
(
skip_special_tokens
=
False
,
spaces_between_special_tokens
=
False
,
...
...
@@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
eos_token_id
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
sampling_params
=
SamplingParams
(),
)
for
idx
,
prompt_tokens
in
enumerate
(
dummy_test_vectors
.
prompt_tokens
)
]
...
...
vllm/engine/async_llm_engine.py
View file @
c8dcc159
...
...
@@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
None
:
...
...
...
@@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
None
:
...
...
...
@@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
...
...
@@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
PoolingRequestOutput
],
None
]]:
...
...
...
@@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
PoolingRequestOutput
],
None
]]:
...
...
...
@@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
PoolingRequestOutput
],
None
]:
...
...
@@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
return
stream
.
generator
()
...
...
@@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -999,7 +1007,8 @@ class AsyncLLMEngine(EngineClient):
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
handle this request. Only applicable if DP is enabled.
Yields:
The output `RequestOutput` objects from the LLMEngine
for the request.
...
...
@@ -1057,6 +1066,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
data_parallel_rank
=
data_parallel_rank
,
):
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
except
asyncio
.
CancelledError
:
...
...
vllm/v1/engine/__init__.py
View file @
c8dcc159
...
...
@@ -55,6 +55,7 @@ class EngineCoreRequest(
arrival_time
:
float
lora_request
:
Optional
[
LoRARequest
]
cache_salt
:
Optional
[
str
]
data_parallel_rank
:
Optional
[
int
]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
...
...
vllm/v1/engine/async_llm.py
View file @
c8dcc159
...
...
@@ -229,6 +229,7 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
RequestOutputCollector
:
"""Add new request to the AsyncLLM."""
...
...
@@ -245,7 +246,7 @@ class AsyncLLM(EngineClient):
prompt_str
,
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
tokenization_kwargs
,
trace_headers
,
prompt_adapter_request
,
priority
)
priority
,
data_parallel_rank
)
if
params
.
n
==
1
:
await
self
.
_add_request
(
request
,
prompt_str
,
None
,
0
,
queue
)
...
...
@@ -291,6 +292,7 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""
Main function called by the API server to kick off a request
...
...
@@ -321,6 +323,7 @@ class AsyncLLM(EngineClient):
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
# The output_handler task pushes items into the queue.
...
...
vllm/v1/engine/core_client.py
View file @
c8dcc159
...
...
@@ -982,7 +982,16 @@ class DPAsyncMPClient(AsyncMPClient):
resources
.
stats_update_task
=
asyncio
.
create_task
(
run_engine_stats_update_task
())
def
get_core_engine_for_request
(
self
)
->
CoreEngine
:
def
get_core_engine_for_request
(
self
,
dp_rank
:
Optional
[
int
]
=
None
)
->
CoreEngine
:
if
dp_rank
is
not
None
:
# engines are already in rank order
if
dp_rank
<
0
or
dp_rank
>=
len
(
self
.
core_engines
):
raise
ValueError
(
f
"Requested DP rank
{
dp_rank
}
is out of "
f
"range [0,
{
len
(
self
.
core_engines
)
}
)"
)
return
self
.
core_engines
[
dp_rank
]
if
not
self
.
lb_engines
:
return
self
.
core_engines
[
0
]
# TODO use P2C alg for larger DP sizes
...
...
@@ -1018,7 +1027,8 @@ class DPAsyncMPClient(AsyncMPClient):
request
.
current_wave
=
self
.
current_wave
request
.
client_index
=
self
.
client_index
chosen_engine
=
self
.
get_core_engine_for_request
()
chosen_engine
=
self
.
get_core_engine_for_request
(
request
.
data_parallel_rank
)
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
to_await
=
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
,
...
...
vllm/v1/engine/processor.py
View file @
c8dcc159
...
...
@@ -212,6 +212,7 @@ class Processor:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
)
->
tuple
[
Optional
[
str
],
EngineCoreRequest
]:
# TODO(woosuk): Support pooling models.
...
...
@@ -328,6 +329,7 @@ class Processor:
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
cache_salt
=
decoder_inputs
.
get
(
"cache_salt"
),
data_parallel_rank
=
data_parallel_rank
,
)
def
_validate_model_inputs
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment