Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a2981c42
Unverified
Commit
a2981c42
authored
Oct 30, 2025
by
cong-meta
Committed by
GitHub
Oct 30, 2025
Browse files
[EP/DP][API Server] Enable DP-aware routing in OpenAI API requests (#24945)
Co-authored-by:
Cong Chen
<
prowindy@gmail.com
>
parent
4574d48b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
0 deletions
+99
-0
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+76
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+4
-0
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+4
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+15
-0
No files found.
tests/entrypoints/openai/test_serving_chat.py
View file @
a2981c42
...
@@ -651,3 +651,79 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
...
@@ -651,3 +651,79 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
await
serving_chat
.
create_chat_completion
(
req
)
await
serving_chat
.
create_chat_completion
(
req
)
engine_prompt
=
serving_chat
.
_process_inputs
.
await_args_list
[
1
].
args
[
1
]
engine_prompt
=
serving_chat
.
_process_inputs
.
await_args_list
[
1
].
args
[
1
]
assert
engine_prompt
.
get
(
"cache_salt"
)
==
"test_salt"
assert
engine_prompt
.
get
(
"cache_salt"
)
==
"test_salt"
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_data_parallel_rank_extraction
():
"""Test that data_parallel_rank is properly extracted from header and
passed to engine."""
mock_engine
=
MagicMock
(
spec
=
AsyncLLM
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
model_config
=
MockModelConfig
()
mock_engine
.
processor
=
MagicMock
()
mock_engine
.
io_processor
=
MagicMock
()
# Mock the generate method to return an async generator
async
def
mock_generate
(
*
args
,
**
kwargs
):
# Yield a fake RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
yield
RequestOutput
(
request_id
=
"test-request"
,
prompt
=
"test prompt"
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_logprobs
=
None
,
outputs
=
[
CompletionOutput
(
index
=
0
,
text
=
"test response"
,
token_ids
=
[
4
,
5
,
6
],
cumulative_logprob
=
0.0
,
logprobs
=
None
,
finish_reason
=
"stop"
,
stop_reason
=
None
,
)
],
finished
=
True
,
)
mock_engine
.
generate
=
AsyncMock
(
side_effect
=
mock_generate
)
serving_chat
=
_build_serving_chat
(
mock_engine
)
# Test when data_parallel_rank is present in header
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}],
)
# Mock request with X-data-parallel-rank header
mock_raw_request
=
MagicMock
()
mock_raw_request
.
headers
=
{
"X-data-parallel-rank"
:
"2"
}
mock_raw_request
.
state
=
MagicMock
()
with
suppress
(
Exception
):
await
serving_chat
.
create_chat_completion
(
req
,
mock_raw_request
)
# Verify that data_parallel_rank was passed to engine.generate
assert
"data_parallel_rank"
in
mock_engine
.
generate
.
call_args
.
kwargs
assert
mock_engine
.
generate
.
call_args
.
kwargs
[
"data_parallel_rank"
]
==
2
# Test when data_parallel_rank is not present (defaults to None)
req_no_dp
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"what is 2+2?"
}],
)
# Mock request with no header
mock_raw_request_no_dp
=
MagicMock
()
mock_raw_request_no_dp
.
headers
=
{}
mock_raw_request_no_dp
.
state
=
MagicMock
()
with
suppress
(
Exception
):
await
serving_chat
.
create_chat_completion
(
req_no_dp
,
mock_raw_request_no_dp
)
# Verify that data_parallel_rank defaults to None
assert
"data_parallel_rank"
in
mock_engine
.
generate
.
call_args
.
kwargs
assert
mock_engine
.
generate
.
call_args
.
kwargs
[
"data_parallel_rank"
]
is
None
vllm/entrypoints/openai/serving_chat.py
View file @
a2981c42
...
@@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
if
raw_request
:
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
raw_request
.
state
.
request_metadata
=
request_metadata
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank
=
self
.
_get_data_parallel_rank
(
raw_request
)
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
...
@@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
priority
=
request
.
priority
,
priority
=
request
.
priority
,
prompt_text
=
prompt_text
,
prompt_text
=
prompt_text
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
data_parallel_rank
=
data_parallel_rank
,
)
)
generators
.
append
(
generator
)
generators
.
append
(
generator
)
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
a2981c42
...
@@ -141,6 +141,9 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -141,6 +141,9 @@ class OpenAIServingCompletion(OpenAIServing):
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank
=
self
.
_get_data_parallel_rank
(
raw_request
)
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
...
@@ -224,6 +227,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -224,6 +227,7 @@ class OpenAIServingCompletion(OpenAIServing):
priority
=
request
.
priority
,
priority
=
request
.
priority
,
prompt_text
=
prompt_text
,
prompt_text
=
prompt_text
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
data_parallel_rank
=
data_parallel_rank
,
)
)
generators
.
append
(
generator
)
generators
.
append
(
generator
)
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
a2981c42
...
@@ -1298,6 +1298,21 @@ class OpenAIServing:
...
@@ -1298,6 +1298,21 @@ class OpenAIServing:
return
raw_request
.
headers
.
get
(
"X-Request-Id"
,
default
)
return
raw_request
.
headers
.
get
(
"X-Request-Id"
,
default
)
@
staticmethod
def
_get_data_parallel_rank
(
raw_request
:
Request
|
None
)
->
int
|
None
:
"""Pulls the data parallel rank from a header, if provided"""
if
raw_request
is
None
:
return
None
rank_str
=
raw_request
.
headers
.
get
(
"X-data-parallel-rank"
)
if
rank_str
is
None
:
return
None
try
:
return
int
(
rank_str
)
except
ValueError
:
return
None
@
staticmethod
@
staticmethod
def
_get_decoded_token
(
def
_get_decoded_token
(
logprob
:
Logprob
,
logprob
:
Logprob
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment