Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500f26e6
Unverified
Commit
500f26e6
authored
Dec 19, 2025
by
inkcherry
Committed by
GitHub
Dec 18, 2025
Browse files
[Bugfix] fix DP-aware routing in OpenAI API requests (#29002)
Signed-off-by:
inkcherry
<
mingzhi.liu@amd.com
>
parent
686cbaac
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
0 deletions
+68
-0
tests/entrypoints/openai/test_chat_error.py
tests/entrypoints/openai/test_chat_error.py
+1
-0
tests/entrypoints/openai/test_completion_error.py
tests/entrypoints/openai/test_completion_error.py
+1
-0
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+1
-0
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+61
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+1
-0
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+2
-0
No files found.
tests/entrypoints/openai/test_chat_error.py
View file @
500f26e6
...
@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
...
@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
lora_request
,
lora_request
,
trace_headers
,
trace_headers
,
priority
,
priority
,
data_parallel_rank
,
):
):
return
dict
(
engine_prompt
),
{}
return
dict
(
engine_prompt
),
{}
...
...
tests/entrypoints/openai/test_completion_error.py
View file @
500f26e6
...
@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
...
@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
lora_request
,
lora_request
,
trace_headers
,
trace_headers
,
priority
,
priority
,
data_parallel_rank
,
):
):
return
dict
(
engine_prompt
),
{}
return
dict
(
engine_prompt
),
{}
...
...
tests/entrypoints/openai/test_serving_chat.py
View file @
500f26e6
...
@@ -396,6 +396,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
...
@@ -396,6 +396,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
lora_request
,
lora_request
,
trace_headers
,
trace_headers
,
priority
,
priority
,
data_parallel_rank
,
):
):
return
dict
(
engine_prompt
),
{}
return
dict
(
engine_prompt
),
{}
...
...
tests/v1/engine/test_async_llm.py
View file @
500f26e6
...
@@ -11,6 +11,13 @@ from vllm import SamplingParams
...
@@ -11,6 +11,13 @@ from vllm import SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
ErrorResponse
,
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_models
import
BaseModelPath
,
OpenAIServingModels
from
vllm.inputs
import
PromptType
from
vllm.inputs
import
PromptType
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -484,6 +491,60 @@ async def test_dp_rank_argument():
...
@@ -484,6 +491,60 @@ async def test_dp_rank_argument():
pass
pass
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_header_dp_rank_argument
():
with
ExitStack
()
as
after
:
with
set_default_torch_num_threads
(
1
):
engine
=
AsyncLLM
.
from_engine_args
(
TEXT_ENGINE_ARGS
)
after
.
callback
(
engine
.
shutdown
)
MODEL_NAME
=
"test-model"
BASE_MODEL_PATHS
=
[
BaseModelPath
(
name
=
MODEL_NAME
,
model_path
=
MODEL_NAME
)]
# Create models first
models
=
OpenAIServingModels
(
engine_client
=
engine
,
base_model_paths
=
BASE_MODEL_PATHS
,
)
# Create serving chat instance
serving_chat
=
OpenAIServingChat
(
engine_client
=
engine
,
models
=
models
,
response_role
=
"assistant"
,
chat_template
=
None
,
chat_template_content_format
=
"auto"
,
request_logger
=
None
,
)
# Create a chat completion request
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
TEXT_PROMPT
}],
max_tokens
=
100
,
temperature
=
1.0
,
seed
=
33
,
)
# Test 1: Valid DP rank (0)
mock_raw_request
=
MagicMock
()
mock_raw_request
.
headers
=
{
"X-data-parallel-rank"
:
"0"
}
mock_raw_request
.
state
=
MagicMock
()
# Should succeed with valid rank
response
=
await
serving_chat
.
create_chat_completion
(
req
,
mock_raw_request
)
assert
isinstance
(
response
,
ChatCompletionResponse
),
(
"Expected a ChatCompletionResponse for valid DP rank"
)
# Test 2: Out-of-range DP rank (1)
mock_raw_request
.
headers
=
{
"X-data-parallel-rank"
:
"1"
}
# should return ErrorResponse for out-of-range rank
response2
=
await
serving_chat
.
create_chat_completion
(
req
,
mock_raw_request
)
assert
isinstance
(
response2
,
ErrorResponse
),
(
"Expected an ErrorResponse for out-of-range DP rank"
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_check_health
():
async
def
test_check_health
():
"""Test that check_health returns normally for healthy engine
"""Test that check_health returns normally for healthy engine
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
500f26e6
...
@@ -381,6 +381,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -381,6 +381,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
)
generator
=
self
.
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
500f26e6
...
@@ -230,6 +230,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -230,6 +230,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
lora_request
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
)
generator
=
self
.
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
500f26e6
...
@@ -1231,6 +1231,7 @@ class OpenAIServing:
...
@@ -1231,6 +1231,7 @@ class OpenAIServing:
lora_request
:
LoRARequest
|
None
,
lora_request
:
LoRARequest
|
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
priority
:
int
,
priority
:
int
,
data_parallel_rank
:
int
|
None
=
None
,
)
->
tuple
[
EngineCoreRequest
,
dict
[
str
,
Any
]]:
)
->
tuple
[
EngineCoreRequest
,
dict
[
str
,
Any
]]:
"""Use the Processor to process inputs for AsyncLLM."""
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs
:
dict
[
str
,
Any
]
=
{}
tokenization_kwargs
:
dict
[
str
,
Any
]
=
{}
...
@@ -1246,6 +1247,7 @@ class OpenAIServing:
...
@@ -1246,6 +1247,7 @@ class OpenAIServing:
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
priority
=
priority
,
priority
=
priority
,
data_parallel_rank
=
data_parallel_rank
,
)
)
return
engine_request
,
tokenization_kwargs
return
engine_request
,
tokenization_kwargs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment