Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c721b814
Commit
c721b814
authored
Feb 05, 2026
by
zhuwenwen
Browse files
sync v0.15.1
parent
d53fe7e5
Changes
328
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
380 additions
and
733 deletions
+380
-733
vllm/config/compilation.py
vllm/config/compilation.py
+1
-2
vllm/config/speculative.py
vllm/config/speculative.py
+0
-12
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+4
-98
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+4
-43
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+7
-42
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+6
-38
vllm/distributed/device_communicators/mnnvl_compat.py
vllm/distributed/device_communicators/mnnvl_compat.py
+0
-13
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+7
-43
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+0
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-24
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+0
-33
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+45
-126
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+0
-11
vllm/entrypoints/openai/completion/serving.py
vllm/entrypoints/openai/completion/serving.py
+15
-2
vllm/entrypoints/openai/engine/protocol.py
vllm/entrypoints/openai/engine/protocol.py
+0
-4
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+64
-35
vllm/entrypoints/openai/responses/serving.py
vllm/entrypoints/openai/responses/serving.py
+30
-60
vllm/entrypoints/pooling/classify/serving.py
vllm/entrypoints/pooling/classify/serving.py
+90
-58
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+102
-64
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+3
-24
No files found.
vllm/config/compilation.py
View file @
c721b814
...
...
@@ -280,10 +280,9 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
assume_32_bit_indexing
:
bool
=
Fals
e
assume_32_bit_indexing
:
bool
=
Tru
e
"""
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
def
compute_hash
(
self
)
->
str
:
...
...
vllm/config/speculative.py
View file @
c721b814
...
...
@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
"mimo_mtp"
,
"glm4_moe_mtp"
,
"glm4_moe_lite_mtp"
,
"glm_ocr_mtp"
,
"ernie_mtp"
,
"exaone_moe_mtp"
,
"qwen3_next_mtp"
,
...
...
@@ -223,17 +222,6 @@ class SpeculativeConfig:
}
)
if
hf_config
.
architectures
[
0
]
==
"GlmOcrForConditionalGeneration"
:
hf_config
.
model_type
=
"glm_ocr_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
(
{
"num_hidden_layers"
:
0
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"GlmOcrMTPModel"
],
}
)
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
...
...
vllm/distributed/device_communicators/all2all.py
View file @
c721b814
...
...
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
buffer
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
extra_tensors
is
not
None
:
raise
NotImplementedError
(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
cu_tokens_across_sp_cpu
=
dp_metadata
.
cu_tokens_across_sp
(
sp_size
)
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_weights
=
self
.
naive_multicast
(
topk_weights
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_ids
=
self
.
naive_multicast
(
topk_ids
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def
__init__
(
self
,
cpu_group
):
super
().
__init__
(
cpu_group
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
sizes
=
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
assert
sizes
is
not
None
dist_group
=
get_ep_group
()
if
is_sequence_parallel
else
get_dp_group
()
assert
sizes
[
dist_group
.
rank_in_group
]
==
hidden_states
.
shape
[
0
]
tensors_to_gather
=
[
hidden_states
,
topk_weights
,
topk_ids
]
if
extra_tensors
is
not
None
:
tensors_to_gather
.
extend
(
extra_tensors
)
gathered_tensors
=
dist_group
.
all_gatherv
(
tensors_to_gather
,
dim
=
0
,
sizes
=
sizes
,
)
hidden_states
=
gathered_tensors
[
0
]
topk_weights
=
gathered_tensors
[
1
]
topk_ids
=
gathered_tensors
[
2
]
if
extra_tensors
is
None
:
return
hidden_states
,
topk_weights
,
topk_ids
return
hidden_states
,
topk_weights
,
topk_ids
,
gathered_tensors
[
3
:]
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx
.
AllToAll
.
internode
if
self
.
internode
else
pplx
.
AllToAll
.
intranode
,
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
c721b814
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
threading
from
typing
import
Any
from
weakref
import
WeakValueDictionary
import
torch
...
...
@@ -63,32 +64,13 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise
NotImplementedError
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
)
->
Any
:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
...
...
@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
for
module
in
moe_modules
:
module
.
maybe_init_modular_kernel
()
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
router_logits
,
extra_tensors
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
topk_weights
,
topk_ids
,
extra_tensors
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/cpu_communicator.py
View file @
c721b814
...
...
@@ -130,65 +130,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
router_logits
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
,
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
)
return
hidden_states
class
_CPUSHMDistributed
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
c721b814
...
...
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -332,52 +332,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
router_logits
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
,
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
)
return
hidden_states
vllm/distributed/device_communicators/mnnvl_compat.py
View file @
c721b814
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch.distributed
as
dist
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
...
...
@@ -24,15 +23,3 @@ class CustomCommunicator(CommBackend):
gathered
=
[
None
]
*
self
.
Get_size
()
dist
.
all_gather_object
(
gathered
,
data
,
group
=
self
.
_group
)
return
gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def
bcast
(
self
,
data
:
Any
,
root
:
int
)
->
Any
:
raise
NotImplementedError
def
barrier
(
self
)
->
None
:
raise
NotImplementedError
def
Split
(
self
,
color
:
int
,
key
:
int
)
->
"CustomCommunicator"
:
return
self
vllm/distributed/device_communicators/xpu_communicator.py
View file @
c721b814
...
...
@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
def
dispatch_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
router_logits
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
,
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
)
return
hidden_states
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
c721b814
...
...
@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class
NixlConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
...
...
vllm/distributed/parallel_state.py
View file @
c721b814
...
...
@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -1011,7 +1011,7 @@ class GroupCoordinator:
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
_router_logits
(
return
self
.
device_communicator
.
dispatch
(
# type: ignore[call-arg]
hidden_states
,
router_logits
,
is_sequence_parallel
,
...
...
@@ -1020,28 +1020,6 @@ class GroupCoordinator:
else
:
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
,
)
else
:
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
vllm/entrypoints/openai/api_server.py
View file @
c721b814
...
...
@@ -264,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return
None
def
get_uvicorn_log_config
(
args
:
Namespace
)
->
dict
|
None
:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config
=
load_log_config
(
args
.
log_config_file
)
if
log_config
is
not
None
:
return
log_config
# If endpoints to filter are specified, create a config with the filter
if
args
.
disable_access_log_for_endpoints
:
from
vllm.logging_utils
import
create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths
=
[
p
.
strip
()
for
p
in
args
.
disable_access_log_for_endpoints
.
split
(
","
)
if
p
.
strip
()
]
return
create_uvicorn_log_config
(
excluded_paths
=
excluded_paths
,
log_level
=
args
.
uvicorn_log_level
,
)
return
None
class
AuthenticationMiddleware
:
"""
Pure ASGI middleware that authenticates each request by checking
...
...
vllm/entrypoints/openai/chat_completion/serving.py
View file @
c721b814
...
...
@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
PromptTokenUsageInfo
,
RequestResponseMetadata
,
ToolCall
,
...
...
@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.parse
import
get_prompt_components
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
...
...
@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_force_include_usage
=
enable_force_include_usage
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
if
self
.
model_config
.
hf_config
.
model_type
==
"kimi_k2"
:
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
self
.
use_harmony
=
self
.
model_config
.
hf_config
.
model_type
==
"gpt_oss"
if
self
.
use_harmony
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
...
...
@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions
()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides
=
getattr
(
self
.
model_config
,
"hf_overrides"
,
None
)
if
self
.
model_config
.
hf_text_config
.
model_type
==
"kimi_k2"
or
(
isinstance
(
hf_overrides
,
dict
)
and
hf_overrides
.
get
(
"model_type"
)
==
"kimi_k2"
):
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the
# Responses API instead.
...
...
@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
# type: ignore[arg-type]
truncate_tool_call_ids
(
request
)
# type: ignore[arg-type]
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
# Check if tool parsing is unavailable (common condition)
...
...
@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_get_prompt_components
(
engine_prompt
)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id
=
(
request_id
if
len
(
engine_prompts
)
==
1
else
f
"
{
request_id
}
_
{
i
}
"
)
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
request
=
request
,
prompt
=
engine_prompt
,
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
])
,
default_sampling_params
=
self
.
default_sampling_params
,
)
...
...
@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
tokenizer
=
self
.
renderer
.
tokenizer
assert
tokenizer
is
not
None
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
...
...
@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
model_name
:
str
,
conversation
:
list
[
ConversationMessage
],
tokenizer
:
TokenizerLike
,
tokenizer
:
TokenizerLike
|
None
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
from
vllm.tokenizers.mistral
import
MistralTokenizer
created_time
=
int
(
time
.
time
())
chunk_object_type
:
Final
=
"chat.completion.chunk"
first_iteration
=
True
...
...
@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
)
reasoning_parser
=
self
.
reasoning_parser
(
tokenizer
,
chat_template_kwargs
=
chat_template_kwargs
or
{}
,
# type: ignore[call-arg]
chat_template_kwargs
=
chat_template_kwargs
,
# type: ignore[call-arg]
)
except
RuntimeError
as
e
:
logger
.
exception
(
"Error in reasoning parser creation."
)
...
...
@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
index
=
i
,
)
else
:
# Generate ID based on tokenizer type
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_id
=
MistralToolCall
.
generate_random_id
()
else
:
tool_call_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_choice_function_name
,
idx
=
history_tool_call_cnt
,
)
delta_tool_call
=
DeltaToolCall
(
id
=
tool_call_id
,
id
=
make_
tool_call_id
()
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
tool_choice_function_name
,
...
...
@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
model_name
:
str
,
conversation
:
list
[
ConversationMessage
],
tokenizer
:
TokenizerLike
,
tokenizer
:
TokenizerLike
|
None
,
request_metadata
:
RequestResponseMetadata
,
)
->
ErrorResponse
|
ChatCompletionResponse
:
from
vllm.tokenizers.mistral
import
MistralTokenizer
created_time
=
int
(
time
.
time
())
final_res
:
RequestOutput
|
None
=
None
...
...
@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class
=
(
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
)
if
self
.
use_harmony
:
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
,
tool_calls
=
tool_calls
if
tool_calls
else
[],
)
elif
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
if
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
and
request
.
tool_choice
!=
"required"
):
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
)
# if the request uses tools and specified a tool choice
elif
(
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
):
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
tool_call_class_items
=
[]
for
idx
,
tc
in
enumerate
(
tool_calls
):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if
tc
.
id
:
tool_call_class_items
.
append
(
tool_call_class
(
id
=
tc
.
id
,
function
=
tc
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_class_items
.
append
(
tool_call_class
(
function
=
tc
))
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tc
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_class_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tc
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
""
,
tool_calls
=
tool_call_class
_items
,
tool_calls
=
[
tool_call_class
(
function
=
tc
)
for
tc
in
tool_calls
]
,
)
elif
request
.
tool_choice
and
request
.
tool_choice
==
"required"
:
tool_call_class_items
=
[]
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
for
idx
,
tool_call
in
enumerate
(
tool_calls
):
# Use native ID if available,
# otherwise generate ID with correct id_type
if
tool_call
.
id
:
tool_call_class_items
.
append
(
tool_call_class
(
id
=
tool_call
.
id
,
function
=
tool_call
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_class_items
.
append
(
tool_call_class
(
function
=
tool_call
)
)
else
:
generated_id
=
make_tool_call_id
(
for
tool_call
in
tool_calls
:
tool_call_class_items
.
append
(
tool_call_class
(
id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_call
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_class_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tool_call
)
)
idx
=
history_tool_call_cnt
,
)
,
function
=
tool_call
,
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
role
=
role
,
...
...
@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls
auto_tools_called
=
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
if
tool_calls
:
tool_call_items
=
[]
for
idx
,
tc
in
enumerate
(
tool_calls
):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if
tc
.
id
:
tool_call_items
.
append
(
tool_call_class
(
id
=
tc
.
id
,
function
=
tc
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_items
.
append
(
tool_call_class
(
function
=
tc
))
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tc
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tc
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
,
tool_calls
=
tool_call_items
,
tool_calls
=
[
ToolCall
(
function
=
tc
,
type
=
"function"
,
)
for
tc
in
tool_calls
],
)
else
:
...
...
@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
elif
choice
.
message
.
tool_calls
:
# For tool calls, log the function name and arguments
tool_call_descriptions
=
[]
for
tc
in
choice
.
message
.
tool_calls
:
# type: ignore
function_call
:
FunctionCall
=
tc
.
function
# type: ignore
tool_call_descriptions
.
append
(
f
"
{
function_call
.
name
}
(
{
function_call
.
arguments
}
)"
)
for
tc
in
choice
.
message
.
tool_calls
:
if
hasattr
(
tc
.
function
,
"name"
)
and
hasattr
(
tc
.
function
,
"arguments"
):
tool_call_descriptions
.
append
(
f
"
{
tc
.
function
.
name
}
(
{
tc
.
function
.
arguments
}
)"
)
tool_calls_str
=
", "
.
join
(
tool_call_descriptions
)
output_text
=
f
"[tool_calls:
{
tool_calls_str
}
]"
...
...
@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
# type: ignore[arg-type]
maybe_serialize_tool_calls
(
request
)
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
...
...
@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message.
if
request
.
tools
:
dev_msg
=
get_developer_message
(
tools
=
request
.
tools
if
should_include_tools
else
None
# type: ignore[arg-type]
tools
=
request
.
tools
if
should_include_tools
else
None
)
messages
.
append
(
dev_msg
)
...
...
@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
cache_salt
is
not
None
:
engine_prompt
[
"cache_salt"
]
=
request
.
cache_salt
return
messages
,
[
engine_prompt
]
return
messages
,
[
engine_prompt
]
\ No newline at end of file
vllm/entrypoints/openai/cli_args.py
View file @
c721b814
...
...
@@ -85,12 +85,6 @@ class FrontendArgs:
"""Log level for uvicorn."""
disable_uvicorn_access_log
:
bool
=
False
"""Disable uvicorn access log."""
disable_access_log_for_endpoints
:
str
|
None
=
None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials
:
bool
=
False
"""Allow credentials."""
allowed_origins
:
list
[
str
]
=
field
(
default_factory
=
lambda
:
[
"*"
])
...
...
@@ -250,11 +244,6 @@ class FrontendArgs:
del
frontend_kwargs
[
"middleware"
][
"nargs"
]
frontend_kwargs
[
"middleware"
][
"default"
]
=
[]
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if
"nargs"
in
frontend_kwargs
[
"disable_access_log_for_endpoints"
]:
del
frontend_kwargs
[
"disable_access_log_for_endpoints"
][
"nargs"
]
# Special case: Tool call parser shows built-in options.
valid_tool_parsers
=
list
(
ToolParserManager
.
list_registered
())
parsers_str
=
","
.
join
(
valid_tool_parsers
)
...
...
vllm/entrypoints/openai/completion/serving.py
View file @
c721b814
...
...
@@ -163,12 +163,25 @@ class OpenAIServingCompletion(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
prompt_token_ids
,
prompt_embeds
=
(
self
.
_get_prompt_components
(
engine_prompt
)
)
input_length
=
None
if
prompt_token_ids
is
not
None
:
input_length
=
len
(
prompt_token_ids
)
elif
prompt_embeds
is
not
None
:
input_length
=
len
(
prompt_embeds
)
else
:
raise
NotImplementedError
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
request
=
request
,
prompt
=
engine_prompt
,
input_length
=
input_length
,
default_sampling_params
=
self
.
default_sampling_params
,
)
...
...
vllm/entrypoints/openai/engine/protocol.py
View file @
c721b814
...
...
@@ -218,10 +218,6 @@ def get_logits_processors(
class
FunctionCall
(
OpenAIBaseModel
):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id
:
str
|
None
=
Field
(
default
=
None
,
exclude
=
True
)
name
:
str
arguments
:
str
...
...
vllm/entrypoints/openai/engine/serving.py
View file @
c721b814
...
...
@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
ClassificationCompletionRequest
,
ClassificationRequest
,
ClassificationResponse
,
)
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
)
from
vllm.entrypoints.pooling.pooling.protocol
import
(
...
...
@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest
,
TokenizeResponse
,
)
from
vllm.entrypoints.utils
import
(
_validate_truncation_size
,
get_max_tokens
,
sanitize_message
,
)
from
vllm.entrypoints.utils
import
_validate_truncation_size
,
sanitize_message
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.parse
import
(
PromptComponents
,
get_prompt_components
,
is_explicit_encoder_decoder_prompt
,
)
...
...
@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
CompletionResponse
|
ChatCompletionResponse
|
EmbeddingResponse
|
EmbeddingBytesResponse
|
TranscriptionResponse
|
TokenizeResponse
|
PoolingResponse
...
...
@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
Generic
[
RequestT
]):
class
RequestProcessingMixin
:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
field
(
default_factory
=
list
)
@
dataclass
(
kw_only
=
True
)
class
ResponseGenerationMixin
:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator
:
(
AsyncGenerator
[
tuple
[
int
,
RequestOutput
|
PoolingRequestOutput
],
None
]
|
None
)
=
None
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
RequestProcessingMixin
,
ResponseGenerationMixin
,
Generic
[
RequestT
]):
request
:
RequestT
raw_request
:
Request
|
None
=
None
model_name
:
str
request_id
:
str
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
None
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
None
)
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ClassificationServeContext
(
ServeContext
[
ClassificationRequest
]):
pass
@
dataclass
(
kw_only
=
True
)
class
EmbeddingServeContext
(
ServeContext
[
EmbeddingRequest
]):
chat_template
:
str
|
None
=
None
chat_template_content_format
:
ChatTemplateContentFormatOption
class
OpenAIServing
:
...
...
@@ -578,7 +605,10 @@ class OpenAIServing:
self
,
ctx
:
ServeContext
,
)
->
AnyResponse
|
ErrorResponse
:
async
for
response
in
self
.
_pipeline
(
ctx
):
generation
:
AsyncGenerator
[
AnyResponse
|
ErrorResponse
,
None
]
generation
=
self
.
_pipeline
(
ctx
)
async
for
response
in
generation
:
return
response
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
...
...
@@ -637,7 +667,9 @@ class OpenAIServing:
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Schedule the request and get the result generator."""
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
try
:
trace_headers
=
(
...
...
@@ -691,7 +723,7 @@ class OpenAIServing:
return
self
.
create_error_response
(
"Engine prompts not available"
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
final_res_batch
:
list
[
PoolingRequestOutput
|
None
]
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
|
None
]
final_res_batch
=
[
None
]
*
num_prompts
if
ctx
.
result_generator
is
None
:
...
...
@@ -979,7 +1011,7 @@ class OpenAIServing:
def
_validate_input
(
self
,
request
:
objec
t
,
request
:
AnyReques
t
,
input_ids
:
list
[
int
],
input_text
:
str
,
)
->
TokensPrompt
:
...
...
@@ -1290,7 +1322,7 @@ class OpenAIServing:
priority
:
int
=
0
,
**
kwargs
,
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_
get_prompt_components
(
engine_prompt
)
orig_priority
=
priority
sub_request
=
0
...
...
@@ -1341,12 +1373,10 @@ class OpenAIServing:
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids
and update sampling_params
.
# Render the next prompt token ids.
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
token_ids
=
context
.
render_for_completion
()
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
token_ids
)
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
token_ids
)
prompt_token_ids
=
context
.
render_for_completion
()
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
)
elif
isinstance
(
context
,
ParsableContext
):
engine_prompts
=
await
self
.
_render_next_turn
(
context
.
request
,
...
...
@@ -1358,19 +1388,19 @@ class OpenAIServing:
context
.
chat_template_content_format
,
)
engine_prompt
=
engine_prompts
[
0
]
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
sampling_params
.
max_tokens
=
get_max_tokens
(
self
.
max_model_len
,
context
.
request
,
engine_prompt
,
self
.
default_sampling_params
,
# type: ignore
)
prompt_text
,
_
,
_
=
self
.
_get_prompt_components
(
engine_prompt
)
# Update the sampling params.
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
]
)
# OPTIMIZATION
priority
=
orig_priority
-
1
sub_request
+=
1
def
_get_prompt_components
(
self
,
prompt
:
PromptType
)
->
PromptComponents
:
return
get_prompt_components
(
prompt
)
def
_log_inputs
(
self
,
request_id
:
str
,
...
...
@@ -1381,7 +1411,7 @@ class OpenAIServing:
if
self
.
request_logger
is
None
:
return
prompt
,
prompt_token_ids
,
prompt_embeds
=
get_prompt_components
(
inputs
)
prompt
,
prompt_token_ids
,
prompt_embeds
=
self
.
_
get_prompt_components
(
inputs
)
self
.
request_logger
.
log_inputs
(
request_id
,
...
...
@@ -1495,7 +1525,6 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls.
function_calls
.
extend
(
FunctionCall
(
id
=
tool_call
.
id
,
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
)
...
...
@@ -1548,4 +1577,4 @@ def clamp_prompt_logprobs(
for
logprob_values
in
logprob_dict
.
values
():
if
logprob_values
.
logprob
==
float
(
"-inf"
):
logprob_values
.
logprob
=
-
9999.0
return
prompt_logprobs
return
prompt_logprobs
\ No newline at end of file
vllm/entrypoints/openai/responses/serving.py
View file @
c721b814
...
...
@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
ChatTemplateContentFormatOption
,
make_tool_call_id
,
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.mcp.tool_server
import
ToolServer
...
...
@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types
,
should_continue_final_message
,
)
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
...
...
@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
self
.
default_sampling_params
[
"stop_token_ids"
].
extend
(
get_stop_tokens_for_assistant_actions
()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides
=
getattr
(
self
.
model_config
,
"hf_overrides"
,
None
)
if
self
.
model_config
.
hf_text_config
.
model_type
==
"kimi_k2"
or
(
isinstance
(
hf_overrides
,
dict
)
and
hf_overrides
.
get
(
"model_type"
)
==
"kimi_k2"
):
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
self
.
enable_auto_tools
=
enable_auto_tools
# set up tool use
self
.
tool_parser
=
self
.
_get_tool_parser
(
...
...
@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
if
maybe_error
is
not
None
:
return
maybe_error
default_max_tokens
=
get_max_tokens
(
self
.
max_model_len
,
request
,
engine_prompt
,
self
.
default_sampling_params
,
default_max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
]
)
sampling_params
=
request
.
to_sampling_params
(
...
...
@@ -970,28 +954,25 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools
=
self
.
enable_auto_tools
,
tool_parser_cls
=
self
.
tool_parser
,
)
if
content
or
(
self
.
use_harmony
and
tool_calls
):
res_text_part
=
None
if
content
:
res_text_part
=
ResponseOutputText
(
text
=
content
,
annotations
=
[],
# TODO
type
=
"output_text"
,
logprobs
=
(
self
.
_create_response_logprobs
(
token_ids
=
final_output
.
token_ids
,
logprobs
=
final_output
.
logprobs
,
tokenizer
=
tokenizer
,
top_logprobs
=
request
.
top_logprobs
,
)
if
request
.
is_include_output_logprobs
()
else
None
),
)
if
content
:
output_text
=
ResponseOutputText
(
text
=
content
,
annotations
=
[],
# TODO
type
=
"output_text"
,
logprobs
=
(
self
.
_create_response_logprobs
(
token_ids
=
final_output
.
token_ids
,
logprobs
=
final_output
.
logprobs
,
tokenizer
=
tokenizer
,
top_logprobs
=
request
.
top_logprobs
,
)
if
request
.
is_include_output_logprobs
()
else
None
),
)
message_item
=
ResponseOutputMessage
(
id
=
f
"msg_
{
random_uuid
()
}
"
,
content
=
[
res_text_part
]
if
res_text_part
else
[
],
content
=
[
output_text
],
role
=
"assistant"
,
status
=
"completed"
,
type
=
"message"
,
...
...
@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
if
message_item
:
outputs
.
append
(
message_item
)
if
tool_calls
:
# We use a simple counter for history_tool_call_count because
# we don't track the history of tool calls in the Responses API yet.
# This means that the tool call index will start from 0 for each
# request.
tool_call_items
=
[]
for
history_tool_call_cnt
,
tool_call
in
enumerate
(
tool_calls
):
tool_call_items
.
append
(
ResponseFunctionToolCall
(
id
=
f
"fc_
{
random_uuid
()
}
"
,
call_id
=
tool_call
.
id
if
tool_call
.
id
else
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_call
.
name
,
idx
=
history_tool_call_cnt
,
),
type
=
"function_call"
,
status
=
"completed"
,
name
=
tool_call
.
name
,
arguments
=
tool_call
.
arguments
,
)
tool_call_items
=
[
ResponseFunctionToolCall
(
id
=
f
"fc_
{
random_uuid
()
}
"
,
call_id
=
f
"call_
{
random_uuid
()
}
"
,
type
=
"function_call"
,
status
=
"completed"
,
name
=
tool_call
.
name
,
arguments
=
tool_call
.
arguments
,
)
for
tool_call
in
tool_calls
]
outputs
.
extend
(
tool_call_items
)
return
outputs
...
...
@@ -2589,4 +2559,4 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number
=-
1
,
response
=
final_response
,
)
)
)
\ No newline at end of file
vllm/entrypoints/pooling/classify/serving.py
View file @
c721b814
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
http
import
HTTPStatus
from
typing
import
Final
,
cast
from
typing
import
cast
import
jinja2
import
numpy
as
np
...
...
@@ -11,8 +11,18 @@ from fastapi import Request
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
ClassificationServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
...
...
@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
logger
=
init_logger
(
__name__
)
ClassificationServeContext
=
ServeContext
[
ClassificationRequest
]
class
ServingClassification
(
OpenAIServing
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
class
ClassificationMixin
(
OpenAIServing
):
chat_template
:
str
|
None
chat_template_content_format
:
ChatTemplateContentFormatOption
trust_request_chat_template
:
bool
async
def
_preprocess
(
self
,
ctx
:
Classification
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
ClassificationChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
request_obj
=
ctx
.
request
if
isinstance
(
request_obj
,
ClassificationChatRequest
):
chat_request
=
request_obj
messages
=
chat_request
.
messages
trust_request_chat_template
=
getattr
(
self
,
"trust_request_chat_template"
,
False
,
)
if
error_check_ret
:
return
error_check_ret
ret
=
self
.
_validate_chat_template
(
request_chat_template
=
chat_request
.
chat_template
,
chat_template_kwargs
=
chat_request
.
chat_template_kwargs
,
trust_request_chat_template
=
trust_request_chat_template
,
)
if
ret
:
return
ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
c
tx
.
request
,
c
ast
(
ChatCompletionRequest
,
chat_
request
)
,
self
.
renderer
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
messages
,
chat_template
=
(
chat_request
.
chat_template
or
getattr
(
self
,
"chat_template"
,
None
)
),
chat_template_content_format
=
cast
(
ChatTemplateContentFormatOption
,
getattr
(
self
,
"chat_template_content_format"
,
"auto"
),
),
add_generation_prompt
=
chat_request
.
add_generation_prompt
,
continue_final_message
=
chat_request
.
continue_final_message
,
add_special_tokens
=
chat_request
.
add_special_tokens
,
)
ctx
.
engine_prompts
=
engine_prompts
elif
isinstance
(
ctx
.
request
,
ClassificationCompletionRequest
):
input_data
=
ctx
.
request
.
input
elif
isinstance
(
request_obj
,
ClassificationCompletionRequest
):
completion_request
=
request_obj
input_data
=
completion_request
.
input
if
input_data
in
(
None
,
""
):
return
self
.
create_error_response
(
"Input or messages must be provided"
,
...
...
@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
prompt_input
=
cast
(
str
|
list
[
str
],
input_data
)
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
prompt_input
,
config
=
self
.
_build_render_config
(
c
tx
.
request
),
config
=
self
.
_build_render_config
(
c
ompletion_
request
),
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
self
.
create_error_response
(
"Invalid classification request type"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
,
)
return
None
...
...
@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
def
_build_response
(
self
,
ctx
:
Classification
ServeContext
,
ctx
:
ServeContext
,
)
->
ClassificationResponse
|
ErrorResponse
:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{})
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
items
:
list
[
ClassificationData
]
=
[]
num_prompt_tokens
=
0
...
...
@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
probs
=
classify_res
.
probs
predicted_index
=
int
(
np
.
argmax
(
probs
))
label
=
id2label
.
get
(
predicted_index
)
label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{}).
get
(
predicted_index
)
item
=
ClassificationData
(
index
=
idx
,
...
...
@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
)
class
ServingClassification
(
ClassificationMixin
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_classify
(
self
,
request
:
ClassificationRequest
,
...
...
@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
request_id
=
request_id
,
)
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
def
_create_pooling_params
(
self
,
ctx
:
Classification
ServeContext
,
ctx
:
ServeContext
[
Classification
Request
]
,
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
...
@@ -198,4 +230,4 @@ class ServingClassification(OpenAIServing):
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
return
pooling_params
\ No newline at end of file
vllm/entrypoints/pooling/embed/serving.py
View file @
c721b814
...
...
@@ -6,13 +6,21 @@ from typing import Any, Final, cast
import
torch
from
fastapi
import
Request
from
typing_extensions
import
assert_never
from
fastapi.responses
import
Response
from
typing_extensions
import
assert_never
,
override
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
EmbeddingServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
...
...
@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
from
vllm.entrypoints.renderer
import
RenderConfig
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
PoolingOutput
,
PoolingRequestOutput
,
RequestOutput
,
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.serial_utils
import
(
EmbedDType
,
EncodingFormat
,
Endianness
,
encode_pooling_bytes
,
encode_pooling_output
,
)
...
...
@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
logger
=
init_logger
(
__name__
)
EmbeddingServeContext
=
ServeContext
[
EmbeddingRequest
]
class
OpenAIServingEmbedding
(
OpenAIServing
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
class
EmbeddingMixin
(
OpenAIServing
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
pooler_config
=
self
.
model_config
.
pooler_config
...
...
@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
else
None
)
@
override
async
def
_preprocess
(
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
self
.
renderer
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
chat_template
=
ctx
.
request
.
chat_template
or
ctx
.
chat_template
,
chat_template_content_format
=
ctx
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
el
if
isinstance
(
ctx
.
request
,
EmbeddingCompletionRequest
)
:
el
se
:
renderer
=
self
.
_get_completion_renderer
()
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
ctx
.
request
.
input
,
config
=
self
.
_build_render_config
(
ctx
.
request
),
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
except
(
ValueError
,
TypeError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
...
...
@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
)
@
override
def
_build_response
(
self
,
ctx
:
Embedding
ServeContext
,
)
->
EmbeddingResponse
|
EmbeddingBytes
Response
|
ErrorResponse
:
final_res_batch_checked
=
ctx
.
final_res_batch
ctx
:
ServeContext
,
)
->
EmbeddingResponse
|
Response
|
ErrorResponse
:
final_res_batch_checked
=
cast
(
list
[
PoolingRequestOutput
],
ctx
.
final_res_batch
)
encoding_format
=
ctx
.
request
.
encoding_format
embed_dtype
=
ctx
.
request
.
embed_dtype
endianness
=
ctx
.
request
.
endianness
encoding_format
:
EncodingFormat
=
ctx
.
request
.
encoding_format
embed_dtype
:
EmbedDType
=
ctx
.
request
.
embed_dtype
endianness
:
Endianness
=
ctx
.
request
.
endianness
def
encode_float_base64
():
items
:
list
[
EmbeddingResponseData
]
=
[]
...
...
@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self
,
ctx
:
EmbeddingServeContext
,
token_ids
:
list
[
int
],
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
pooling_params
,
trace_headers
,
prompt_idx
:
int
,
)
->
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]:
"""Process a single prompt using chunked processing."""
...
...
@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
def
_validate_input
(
self
,
request
:
object
,
request
,
input_ids
:
list
[
int
],
input_text
:
str
,
)
->
TokensPrompt
:
...
...
@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_index
:
int
,
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]:
"""Create a generator for a single prompt using standard processing."""
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
prompt_index
}
"
...
...
@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
priority
=
getattr
(
ctx
.
request
,
"priority"
,
0
),
)
@
override
async
def
_prepare_generators
(
self
,
ctx
:
ServeContext
,
...
...
@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
return
await
super
().
_prepare_generators
(
ctx
)
# Custom logic for chunked processing
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
try
:
trace_headers
=
(
...
...
@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
@
override
async
def
_collect_batch
(
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Collect and aggregate batch results
with support for chunked processing.
...
...
@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
...
...
@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
except
(
ValueError
,
IndexError
):
prompt_idx
=
result_idx
# Fallback to result_idx
short_prompts_results
[
prompt_idx
]
=
result
short_prompts_results
[
prompt_idx
]
=
cast
(
PoolingRequestOutput
,
result
)
# Finalize aggregated results
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
final_res_batch
:
list
[
PoolingRequestOutput
|
EmbeddingRequestOutput
]
=
[]
num_prompts
=
len
(
ctx
.
engine_prompts
)
for
prompt_idx
in
range
(
num_prompts
):
...
...
@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
f
"Failed to aggregate chunks for prompt
{
prompt_idx
}
"
)
elif
prompt_idx
in
short_prompts_results
:
final_res_batch
.
append
(
short_prompts_results
[
prompt_idx
])
final_res_batch
.
append
(
cast
(
PoolingRequestOutput
,
short_prompts_results
[
prompt_idx
])
)
else
:
return
self
.
create_error_response
(
f
"Result not found for prompt
{
prompt_idx
}
"
)
ctx
.
final_res_batch
=
final_res_batch
ctx
.
final_res_batch
=
cast
(
list
[
RequestOutput
|
PoolingRequestOutput
],
final_res_batch
)
return
None
except
Exception
as
e
:
return
self
.
create_error_response
(
str
(
e
))
class
OpenAIServingEmbedding
(
EmbeddingMixin
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
...
...
@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
raw_request
=
raw_request
,
model_name
=
model_name
,
request_id
=
request_id
,
chat_template
=
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
)
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
@
override
def
_create_pooling_params
(
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
[
EmbeddingRequest
]
,
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
...
@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
return
await
super
().
_preprocess
(
ctx
)
\ No newline at end of file
vllm/entrypoints/utils.py
View file @
c721b814
...
...
@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
from
vllm
import
envs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.inputs
import
EmbedsPrompt
,
TokensPrompt
from
vllm.logger
import
current_formatter_type
,
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
if
TYPE_CHECKING
:
...
...
@@ -34,15 +32,11 @@ if TYPE_CHECKING:
StreamOptions
,
)
from
vllm.entrypoints.openai.models.protocol
import
LoRAModulePath
from
vllm.entrypoints.openai.responses.protocol
import
(
ResponsesRequest
,
)
else
:
ChatCompletionRequest
=
object
CompletionRequest
=
object
StreamOptions
=
object
LoRAModulePath
=
object
ResponsesRequest
=
object
logger
=
init_logger
(
__name__
)
...
...
@@ -217,26 +211,11 @@ def _validate_truncation_size(
def
get_max_tokens
(
max_model_len
:
int
,
request
:
"CompletionRequest |
Chat
CompletionRequest
| ResponsesRequest
"
,
prompt
:
TokensPrompt
|
EmbedsPromp
t
,
request
:
"
Chat
CompletionRequest | CompletionRequest"
,
input_length
:
in
t
,
default_sampling_params
:
dict
,
)
->
int
:
# NOTE: Avoid isinstance() for better efficiency
max_tokens
:
int
|
None
=
None
if
max_tokens
is
None
:
# ChatCompletionRequest
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
if
max_tokens
is
None
:
# ResponsesRequest
max_tokens
=
getattr
(
request
,
"max_output_tokens"
,
None
)
if
max_tokens
is
None
:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens
=
getattr
(
request
,
"max_tokens"
,
None
)
input_length
=
length_from_prompt_token_ids_or_embeds
(
prompt
.
get
(
"prompt_token_ids"
),
# type: ignore[arg-type]
prompt
.
get
(
"prompt_embeds"
),
# type: ignore[arg-type]
)
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
or
request
.
max_tokens
default_max_tokens
=
max_model_len
-
input_length
max_output_tokens
=
current_platform
.
get_max_output_tokens
(
input_length
)
...
...
Prev
1
2
3
4
5
6
7
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment