Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c721b814
Commit
c721b814
authored
Feb 05, 2026
by
zhuwenwen
Browse files
sync v0.15.1
parent
d53fe7e5
Changes
328
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
380 additions
and
733 deletions
+380
-733
vllm/config/compilation.py
vllm/config/compilation.py
+1
-2
vllm/config/speculative.py
vllm/config/speculative.py
+0
-12
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+4
-98
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+4
-43
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+7
-42
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+6
-38
vllm/distributed/device_communicators/mnnvl_compat.py
vllm/distributed/device_communicators/mnnvl_compat.py
+0
-13
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+7
-43
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+0
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-24
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+0
-33
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+45
-126
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+0
-11
vllm/entrypoints/openai/completion/serving.py
vllm/entrypoints/openai/completion/serving.py
+15
-2
vllm/entrypoints/openai/engine/protocol.py
vllm/entrypoints/openai/engine/protocol.py
+0
-4
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+64
-35
vllm/entrypoints/openai/responses/serving.py
vllm/entrypoints/openai/responses/serving.py
+30
-60
vllm/entrypoints/pooling/classify/serving.py
vllm/entrypoints/pooling/classify/serving.py
+90
-58
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+102
-64
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+3
-24
No files found.
vllm/config/compilation.py
View file @
c721b814
...
@@ -280,10 +280,9 @@ class DynamicShapesConfig:
...
@@ -280,10 +280,9 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
"""
assume_32_bit_indexing
:
bool
=
Fals
e
assume_32_bit_indexing
:
bool
=
Tru
e
"""
"""
whether all tensor sizes can use 32 bit indexing.
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
"""
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
...
...
vllm/config/speculative.py
View file @
c721b814
...
@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
...
@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
"mimo_mtp"
,
"mimo_mtp"
,
"glm4_moe_mtp"
,
"glm4_moe_mtp"
,
"glm4_moe_lite_mtp"
,
"glm4_moe_lite_mtp"
,
"glm_ocr_mtp"
,
"ernie_mtp"
,
"ernie_mtp"
,
"exaone_moe_mtp"
,
"exaone_moe_mtp"
,
"qwen3_next_mtp"
,
"qwen3_next_mtp"
,
...
@@ -223,17 +222,6 @@ class SpeculativeConfig:
...
@@ -223,17 +222,6 @@ class SpeculativeConfig:
}
}
)
)
if
hf_config
.
architectures
[
0
]
==
"GlmOcrForConditionalGeneration"
:
hf_config
.
model_type
=
"glm_ocr_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
(
{
"num_hidden_layers"
:
0
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"GlmOcrMTPModel"
],
}
)
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
if
hf_config
.
model_type
==
"ernie_mtp"
:
...
...
vllm/distributed/device_communicators/all2all.py
View file @
c721b814
...
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
buffer
return
buffer
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
extra_tensors
is
not
None
:
raise
NotImplementedError
(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
cu_tokens_across_sp_cpu
=
dp_metadata
.
cu_tokens_across_sp
(
sp_size
)
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_weights
=
self
.
naive_multicast
(
topk_weights
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_ids
=
self
.
naive_multicast
(
topk_ids
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def
__init__
(
self
,
cpu_group
):
def
__init__
(
self
,
cpu_group
):
super
().
__init__
(
cpu_group
)
super
().
__init__
(
cpu_group
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
sizes
=
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
assert
sizes
is
not
None
dist_group
=
get_ep_group
()
if
is_sequence_parallel
else
get_dp_group
()
assert
sizes
[
dist_group
.
rank_in_group
]
==
hidden_states
.
shape
[
0
]
tensors_to_gather
=
[
hidden_states
,
topk_weights
,
topk_ids
]
if
extra_tensors
is
not
None
:
tensors_to_gather
.
extend
(
extra_tensors
)
gathered_tensors
=
dist_group
.
all_gatherv
(
tensors_to_gather
,
dim
=
0
,
sizes
=
sizes
,
)
hidden_states
=
gathered_tensors
[
0
]
topk_weights
=
gathered_tensors
[
1
]
topk_ids
=
gathered_tensors
[
2
]
if
extra_tensors
is
None
:
return
hidden_states
,
topk_weights
,
topk_ids
return
hidden_states
,
topk_weights
,
topk_ids
,
gathered_tensors
[
3
:]
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
...
@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx
.
AllToAll
.
internode
if
self
.
internode
else
pplx
.
AllToAll
.
intranode
,
pplx
.
AllToAll
.
internode
if
self
.
internode
else
pplx
.
AllToAll
.
intranode
,
)
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
...
@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def
get_handle
(
self
,
kwargs
):
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
c721b814
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
threading
import
threading
from
typing
import
Any
from
weakref
import
WeakValueDictionary
from
weakref
import
WeakValueDictionary
import
torch
import
torch
...
@@ -63,32 +64,13 @@ class All2AllManagerBase:
...
@@ -63,32 +64,13 @@ class All2AllManagerBase:
# and reuse it for the same config.
# and reuse it for the same config.
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
)
->
Any
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
# - raise a clear error if extra_tensors is not supported.
...
@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
...
@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
for
module
in
moe_modules
:
for
module
in
moe_modules
:
module
.
maybe_init_modular_kernel
()
module
.
maybe_init_modular_kernel
()
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
...
@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
This is a no-op in the base class.
"""
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
router_logits
,
extra_tensors
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
topk_weights
,
topk_ids
,
extra_tensors
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/cpu_communicator.py
View file @
c721b814
...
@@ -130,65 +130,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
...
@@ -130,65 +130,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
hidden_states
,
topk_weights
,
router_logits
,
topk_ids
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
hidden_states
,
is_sequence_parallel
is_sequence_parallel
,
)
)
return
hidden_states
class
_CPUSHMDistributed
:
class
_CPUSHMDistributed
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
c721b814
...
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
return
output_list
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -332,52 +332,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -332,52 +332,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
hidden_states
,
topk_weights
,
router_logits
,
topk_ids
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
hidden_states
,
is_sequence_parallel
is_sequence_parallel
,
)
)
return
hidden_states
vllm/distributed/device_communicators/mnnvl_compat.py
View file @
c721b814
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
...
@@ -24,15 +23,3 @@ class CustomCommunicator(CommBackend):
...
@@ -24,15 +23,3 @@ class CustomCommunicator(CommBackend):
gathered
=
[
None
]
*
self
.
Get_size
()
gathered
=
[
None
]
*
self
.
Get_size
()
dist
.
all_gather_object
(
gathered
,
data
,
group
=
self
.
_group
)
dist
.
all_gather_object
(
gathered
,
data
,
group
=
self
.
_group
)
return
gathered
return
gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def
bcast
(
self
,
data
:
Any
,
root
:
int
)
->
Any
:
raise
NotImplementedError
def
barrier
(
self
)
->
None
:
raise
NotImplementedError
def
Split
(
self
,
color
:
int
,
key
:
int
)
->
"CustomCommunicator"
:
return
self
vllm/distributed/device_communicators/xpu_communicator.py
View file @
c721b814
...
@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
...
@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
def
dispatch_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
hidden_states
,
topk_weights
,
router_logits
,
topk_ids
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
hidden_states
,
is_sequence_parallel
is_sequence_parallel
,
)
)
return
hidden_states
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
c721b814
...
@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class
NixlConnector
(
KVConnectorBase_V1
):
class
NixlConnector
(
KVConnectorBase_V1
):
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
...
...
vllm/distributed/parallel_state.py
View file @
c721b814
...
@@ -1000,7 +1000,7 @@ class GroupCoordinator:
...
@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -1011,7 +1011,7 @@ class GroupCoordinator:
...
@@ -1011,7 +1011,7 @@ class GroupCoordinator:
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
):
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
_router_logits
(
return
self
.
device_communicator
.
dispatch
(
# type: ignore[call-arg]
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
is_sequence_parallel
,
is_sequence_parallel
,
...
@@ -1020,28 +1020,6 @@ class GroupCoordinator:
...
@@ -1020,28 +1020,6 @@ class GroupCoordinator:
else
:
else
:
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
,
)
else
:
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/entrypoints/openai/api_server.py
View file @
c721b814
...
@@ -264,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
...
@@ -264,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return
None
return
None
def
get_uvicorn_log_config
(
args
:
Namespace
)
->
dict
|
None
:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config
=
load_log_config
(
args
.
log_config_file
)
if
log_config
is
not
None
:
return
log_config
# If endpoints to filter are specified, create a config with the filter
if
args
.
disable_access_log_for_endpoints
:
from
vllm.logging_utils
import
create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths
=
[
p
.
strip
()
for
p
in
args
.
disable_access_log_for_endpoints
.
split
(
","
)
if
p
.
strip
()
]
return
create_uvicorn_log_config
(
excluded_paths
=
excluded_paths
,
log_level
=
args
.
uvicorn_log_level
,
)
return
None
class
AuthenticationMiddleware
:
class
AuthenticationMiddleware
:
"""
"""
Pure ASGI middleware that authenticates each request by checking
Pure ASGI middleware that authenticates each request by checking
...
...
vllm/entrypoints/openai/chat_completion/serving.py
View file @
c721b814
...
@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
...
@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage
,
DeltaMessage
,
DeltaToolCall
,
DeltaToolCall
,
ErrorResponse
,
ErrorResponse
,
FunctionCall
,
PromptTokenUsageInfo
,
PromptTokenUsageInfo
,
RequestResponseMetadata
,
RequestResponseMetadata
,
ToolCall
,
ToolCall
,
...
@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
...
@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.parse
import
get_prompt_components
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
...
@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_force_include_usage
=
enable_force_include_usage
self
.
enable_force_include_usage
=
enable_force_include_usage
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
if
self
.
model_config
.
hf_config
.
model_type
==
"kimi_k2"
:
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
self
.
use_harmony
=
self
.
model_config
.
hf_config
.
model_type
==
"gpt_oss"
self
.
use_harmony
=
self
.
model_config
.
hf_config
.
model_type
==
"gpt_oss"
if
self
.
use_harmony
:
if
self
.
use_harmony
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
...
@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions
()
get_stop_tokens_for_assistant_actions
()
)
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides
=
getattr
(
self
.
model_config
,
"hf_overrides"
,
None
)
if
self
.
model_config
.
hf_text_config
.
model_type
==
"kimi_k2"
or
(
isinstance
(
hf_overrides
,
dict
)
and
hf_overrides
.
get
(
"model_type"
)
==
"kimi_k2"
):
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the
# for some models, currently vLLM doesn't support it. Please use the
# Responses API instead.
# Responses API instead.
...
@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
# type: ignore[arg-type]
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
# type: ignore[arg-type]
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
validate_request_params
(
request
)
# Check if tool parsing is unavailable (common condition)
# Check if tool parsing is unavailable (common condition)
...
@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_get_prompt_components
(
engine_prompt
)
# If we are creating sub requests for multiple prompts, ensure that they
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
# have unique request ids.
sub_request_id
=
(
sub_request_id
=
(
request_id
if
len
(
engine_prompts
)
==
1
else
f
"
{
request_id
}
_
{
i
}
"
request_id
if
len
(
engine_prompts
)
==
1
else
f
"
{
request_id
}
_
{
i
}
"
)
)
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
request
=
request
,
request
=
request
,
prompt
=
engine_prompt
,
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
])
,
default_sampling_params
=
self
.
default_sampling_params
,
default_sampling_params
=
self
.
default_sampling_params
,
)
)
...
@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
# Streaming response
tokenizer
=
self
.
renderer
.
tokenizer
tokenizer
=
self
.
renderer
.
tokenizer
assert
tokenizer
is
not
None
if
request
.
stream
:
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
...
@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
model_name
:
str
,
model_name
:
str
,
conversation
:
list
[
ConversationMessage
],
conversation
:
list
[
ConversationMessage
],
tokenizer
:
TokenizerLike
,
tokenizer
:
TokenizerLike
|
None
,
request_metadata
:
RequestResponseMetadata
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
from
vllm.tokenizers.mistral
import
MistralTokenizer
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
chunk_object_type
:
Final
=
"chat.completion.chunk"
chunk_object_type
:
Final
=
"chat.completion.chunk"
first_iteration
=
True
first_iteration
=
True
...
@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
)
)
reasoning_parser
=
self
.
reasoning_parser
(
reasoning_parser
=
self
.
reasoning_parser
(
tokenizer
,
tokenizer
,
chat_template_kwargs
=
chat_template_kwargs
or
{}
,
# type: ignore[call-arg]
chat_template_kwargs
=
chat_template_kwargs
,
# type: ignore[call-arg]
)
)
except
RuntimeError
as
e
:
except
RuntimeError
as
e
:
logger
.
exception
(
"Error in reasoning parser creation."
)
logger
.
exception
(
"Error in reasoning parser creation."
)
...
@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
index
=
i
,
index
=
i
,
)
)
else
:
else
:
# Generate ID based on tokenizer type
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_id
=
MistralToolCall
.
generate_random_id
()
else
:
tool_call_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_choice_function_name
,
idx
=
history_tool_call_cnt
,
)
delta_tool_call
=
DeltaToolCall
(
delta_tool_call
=
DeltaToolCall
(
id
=
tool_call_id
,
id
=
make_
tool_call_id
()
,
type
=
"function"
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
function
=
DeltaFunctionCall
(
name
=
tool_choice_function_name
,
name
=
tool_choice_function_name
,
...
@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
model_name
:
str
,
model_name
:
str
,
conversation
:
list
[
ConversationMessage
],
conversation
:
list
[
ConversationMessage
],
tokenizer
:
TokenizerLike
,
tokenizer
:
TokenizerLike
|
None
,
request_metadata
:
RequestResponseMetadata
,
request_metadata
:
RequestResponseMetadata
,
)
->
ErrorResponse
|
ChatCompletionResponse
:
)
->
ErrorResponse
|
ChatCompletionResponse
:
from
vllm.tokenizers.mistral
import
MistralTokenizer
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
final_res
:
RequestOutput
|
None
=
None
final_res
:
RequestOutput
|
None
=
None
...
@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class
=
(
tool_call_class
=
(
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
)
)
if
self
.
use_harmony
:
if
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
,
tool_calls
=
tool_calls
if
tool_calls
else
[],
)
elif
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
and
request
.
tool_choice
!=
"required"
and
request
.
tool_choice
!=
"required"
):
):
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
)
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
)
# if the request uses tools and specified a tool choice
elif
(
elif
(
request
.
tool_choice
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
):
):
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
tool_call_class_items
=
[]
for
idx
,
tc
in
enumerate
(
tool_calls
):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if
tc
.
id
:
tool_call_class_items
.
append
(
tool_call_class
(
id
=
tc
.
id
,
function
=
tc
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_class_items
.
append
(
tool_call_class
(
function
=
tc
))
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tc
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_class_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tc
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
reasoning
=
reasoning
,
reasoning
=
reasoning
,
content
=
""
,
content
=
""
,
tool_calls
=
tool_call_class
_items
,
tool_calls
=
[
tool_call_class
(
function
=
tc
)
for
tc
in
tool_calls
]
,
)
)
elif
request
.
tool_choice
and
request
.
tool_choice
==
"required"
:
elif
request
.
tool_choice
and
request
.
tool_choice
==
"required"
:
tool_call_class_items
=
[]
tool_call_class_items
=
[]
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
for
idx
,
tool_call
in
enumerate
(
tool_calls
):
for
tool_call
in
tool_calls
:
# Use native ID if available,
tool_call_class_items
.
append
(
# otherwise generate ID with correct id_type
tool_call_class
(
if
tool_call
.
id
:
id
=
make_tool_call_id
(
tool_call_class_items
.
append
(
tool_call_class
(
id
=
tool_call
.
id
,
function
=
tool_call
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_class_items
.
append
(
tool_call_class
(
function
=
tool_call
)
)
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_call
.
name
,
func_name
=
tool_call
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
idx
=
history_tool_call_cnt
,
)
)
,
tool_call_class_items
.
append
(
function
=
tool_call
,
tool_call_class
(
id
=
generated_id
,
function
=
tool_call
)
)
)
)
history_tool_call_cnt
+=
1
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
...
@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls
# call. The same is not true for named function calls
auto_tools_called
=
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
auto_tools_called
=
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
if
tool_calls
:
if
tool_calls
:
tool_call_items
=
[]
for
idx
,
tc
in
enumerate
(
tool_calls
):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if
tc
.
id
:
tool_call_items
.
append
(
tool_call_class
(
id
=
tc
.
id
,
function
=
tc
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_items
.
append
(
tool_call_class
(
function
=
tc
))
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tc
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tc
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
reasoning
=
reasoning
,
reasoning
=
reasoning
,
content
=
content
,
content
=
content
,
tool_calls
=
tool_call_items
,
tool_calls
=
[
ToolCall
(
function
=
tc
,
type
=
"function"
,
)
for
tc
in
tool_calls
],
)
)
else
:
else
:
...
@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
elif
choice
.
message
.
tool_calls
:
elif
choice
.
message
.
tool_calls
:
# For tool calls, log the function name and arguments
# For tool calls, log the function name and arguments
tool_call_descriptions
=
[]
tool_call_descriptions
=
[]
for
tc
in
choice
.
message
.
tool_calls
:
# type: ignore
for
tc
in
choice
.
message
.
tool_calls
:
function_call
:
FunctionCall
=
tc
.
function
# type: ignore
if
hasattr
(
tc
.
function
,
"name"
)
and
hasattr
(
tool_call_descriptions
.
append
(
tc
.
function
,
"arguments"
f
"
{
function_call
.
name
}
(
{
function_call
.
arguments
}
)"
):
)
tool_call_descriptions
.
append
(
f
"
{
tc
.
function
.
name
}
(
{
tc
.
function
.
arguments
}
)"
)
tool_calls_str
=
", "
.
join
(
tool_call_descriptions
)
tool_calls_str
=
", "
.
join
(
tool_call_descriptions
)
output_text
=
f
"[tool_calls:
{
tool_calls_str
}
]"
output_text
=
f
"[tool_calls:
{
tool_calls_str
}
]"
...
@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
# type: ignore[arg-type]
maybe_serialize_tool_calls
(
request
)
# Add system message.
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
# NOTE: In Chat Completion API, browsing is enabled by default
...
@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message.
# Add developer message.
if
request
.
tools
:
if
request
.
tools
:
dev_msg
=
get_developer_message
(
dev_msg
=
get_developer_message
(
tools
=
request
.
tools
if
should_include_tools
else
None
# type: ignore[arg-type]
tools
=
request
.
tools
if
should_include_tools
else
None
)
)
messages
.
append
(
dev_msg
)
messages
.
append
(
dev_msg
)
...
@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
cache_salt
is
not
None
:
if
request
.
cache_salt
is
not
None
:
engine_prompt
[
"cache_salt"
]
=
request
.
cache_salt
engine_prompt
[
"cache_salt"
]
=
request
.
cache_salt
return
messages
,
[
engine_prompt
]
return
messages
,
[
engine_prompt
]
\ No newline at end of file
vllm/entrypoints/openai/cli_args.py
View file @
c721b814
...
@@ -85,12 +85,6 @@ class FrontendArgs:
...
@@ -85,12 +85,6 @@ class FrontendArgs:
"""Log level for uvicorn."""
"""Log level for uvicorn."""
disable_uvicorn_access_log
:
bool
=
False
disable_uvicorn_access_log
:
bool
=
False
"""Disable uvicorn access log."""
"""Disable uvicorn access log."""
disable_access_log_for_endpoints
:
str
|
None
=
None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials
:
bool
=
False
allow_credentials
:
bool
=
False
"""Allow credentials."""
"""Allow credentials."""
allowed_origins
:
list
[
str
]
=
field
(
default_factory
=
lambda
:
[
"*"
])
allowed_origins
:
list
[
str
]
=
field
(
default_factory
=
lambda
:
[
"*"
])
...
@@ -250,11 +244,6 @@ class FrontendArgs:
...
@@ -250,11 +244,6 @@ class FrontendArgs:
del
frontend_kwargs
[
"middleware"
][
"nargs"
]
del
frontend_kwargs
[
"middleware"
][
"nargs"
]
frontend_kwargs
[
"middleware"
][
"default"
]
=
[]
frontend_kwargs
[
"middleware"
][
"default"
]
=
[]
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if
"nargs"
in
frontend_kwargs
[
"disable_access_log_for_endpoints"
]:
del
frontend_kwargs
[
"disable_access_log_for_endpoints"
][
"nargs"
]
# Special case: Tool call parser shows built-in options.
# Special case: Tool call parser shows built-in options.
valid_tool_parsers
=
list
(
ToolParserManager
.
list_registered
())
valid_tool_parsers
=
list
(
ToolParserManager
.
list_registered
())
parsers_str
=
","
.
join
(
valid_tool_parsers
)
parsers_str
=
","
.
join
(
valid_tool_parsers
)
...
...
vllm/entrypoints/openai/completion/serving.py
View file @
c721b814
...
@@ -163,12 +163,25 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -163,12 +163,25 @@ class OpenAIServingCompletion(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
prompt_token_ids
,
prompt_embeds
=
(
self
.
_get_prompt_components
(
engine_prompt
)
)
input_length
=
None
if
prompt_token_ids
is
not
None
:
input_length
=
len
(
prompt_token_ids
)
elif
prompt_embeds
is
not
None
:
input_length
=
len
(
prompt_embeds
)
else
:
raise
NotImplementedError
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
request
=
request
,
request
=
request
,
prompt
=
engine_prompt
,
input_length
=
input_length
,
default_sampling_params
=
self
.
default_sampling_params
,
default_sampling_params
=
self
.
default_sampling_params
,
)
)
...
...
vllm/entrypoints/openai/engine/protocol.py
View file @
c721b814
...
@@ -218,10 +218,6 @@ def get_logits_processors(
...
@@ -218,10 +218,6 @@ def get_logits_processors(
class
FunctionCall
(
OpenAIBaseModel
):
class
FunctionCall
(
OpenAIBaseModel
):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id
:
str
|
None
=
Field
(
default
=
None
,
exclude
=
True
)
name
:
str
name
:
str
arguments
:
str
arguments
:
str
...
...
vllm/entrypoints/openai/engine/serving.py
View file @
c721b814
...
@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
...
@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
from
vllm.entrypoints.pooling.classify.protocol
import
(
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
ClassificationChatRequest
,
ClassificationCompletionRequest
,
ClassificationCompletionRequest
,
ClassificationRequest
,
ClassificationResponse
,
ClassificationResponse
,
)
)
from
vllm.entrypoints.pooling.embed.protocol
import
(
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
EmbeddingChatRequest
,
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
EmbeddingCompletionRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponse
,
)
)
from
vllm.entrypoints.pooling.pooling.protocol
import
(
from
vllm.entrypoints.pooling.pooling.protocol
import
(
...
@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
...
@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest
,
TokenizeCompletionRequest
,
TokenizeResponse
,
TokenizeResponse
,
)
)
from
vllm.entrypoints.utils
import
(
from
vllm.entrypoints.utils
import
_validate_truncation_size
,
sanitize_message
_validate_truncation_size
,
get_max_tokens
,
sanitize_message
,
)
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.parse
import
(
from
vllm.inputs.parse
import
(
PromptComponents
,
get_prompt_components
,
get_prompt_components
,
is_explicit_encoder_decoder_prompt
,
is_explicit_encoder_decoder_prompt
,
)
)
...
@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
...
@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
CompletionResponse
CompletionResponse
|
ChatCompletionResponse
|
ChatCompletionResponse
|
EmbeddingResponse
|
EmbeddingResponse
|
EmbeddingBytesResponse
|
TranscriptionResponse
|
TranscriptionResponse
|
TokenizeResponse
|
TokenizeResponse
|
PoolingResponse
|
PoolingResponse
...
@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
...
@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@
dataclass
(
kw_only
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
Generic
[
RequestT
]):
class
RequestProcessingMixin
:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
field
(
default_factory
=
list
)
@
dataclass
(
kw_only
=
True
)
class
ResponseGenerationMixin
:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator
:
(
AsyncGenerator
[
tuple
[
int
,
RequestOutput
|
PoolingRequestOutput
],
None
]
|
None
)
=
None
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
RequestProcessingMixin
,
ResponseGenerationMixin
,
Generic
[
RequestT
]):
request
:
RequestT
request
:
RequestT
raw_request
:
Request
|
None
=
None
raw_request
:
Request
|
None
=
None
model_name
:
str
model_name
:
str
request_id
:
str
request_id
:
str
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
lora_request
:
LoRARequest
|
None
=
None
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
None
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
None
)
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ClassificationServeContext
(
ServeContext
[
ClassificationRequest
]):
pass
@
dataclass
(
kw_only
=
True
)
class
EmbeddingServeContext
(
ServeContext
[
EmbeddingRequest
]):
chat_template
:
str
|
None
=
None
chat_template_content_format
:
ChatTemplateContentFormatOption
class
OpenAIServing
:
class
OpenAIServing
:
...
@@ -578,7 +605,10 @@ class OpenAIServing:
...
@@ -578,7 +605,10 @@ class OpenAIServing:
self
,
self
,
ctx
:
ServeContext
,
ctx
:
ServeContext
,
)
->
AnyResponse
|
ErrorResponse
:
)
->
AnyResponse
|
ErrorResponse
:
async
for
response
in
self
.
_pipeline
(
ctx
):
generation
:
AsyncGenerator
[
AnyResponse
|
ErrorResponse
,
None
]
generation
=
self
.
_pipeline
(
ctx
)
async
for
response
in
generation
:
return
response
return
response
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
...
@@ -637,7 +667,9 @@ class OpenAIServing:
...
@@ -637,7 +667,9 @@ class OpenAIServing:
ctx
:
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
)
->
ErrorResponse
|
None
:
"""Schedule the request and get the result generator."""
"""Schedule the request and get the result generator."""
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
try
:
try
:
trace_headers
=
(
trace_headers
=
(
...
@@ -691,7 +723,7 @@ class OpenAIServing:
...
@@ -691,7 +723,7 @@ class OpenAIServing:
return
self
.
create_error_response
(
"Engine prompts not available"
)
return
self
.
create_error_response
(
"Engine prompts not available"
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
final_res_batch
:
list
[
PoolingRequestOutput
|
None
]
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
|
None
]
final_res_batch
=
[
None
]
*
num_prompts
final_res_batch
=
[
None
]
*
num_prompts
if
ctx
.
result_generator
is
None
:
if
ctx
.
result_generator
is
None
:
...
@@ -979,7 +1011,7 @@ class OpenAIServing:
...
@@ -979,7 +1011,7 @@ class OpenAIServing:
def
_validate_input
(
def
_validate_input
(
self
,
self
,
request
:
objec
t
,
request
:
AnyReques
t
,
input_ids
:
list
[
int
],
input_ids
:
list
[
int
],
input_text
:
str
,
input_text
:
str
,
)
->
TokensPrompt
:
)
->
TokensPrompt
:
...
@@ -1290,7 +1322,7 @@ class OpenAIServing:
...
@@ -1290,7 +1322,7 @@ class OpenAIServing:
priority
:
int
=
0
,
priority
:
int
=
0
,
**
kwargs
,
**
kwargs
,
):
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_
get_prompt_components
(
engine_prompt
)
orig_priority
=
priority
orig_priority
=
priority
sub_request
=
0
sub_request
=
0
...
@@ -1341,12 +1373,10 @@ class OpenAIServing:
...
@@ -1341,12 +1373,10 @@ class OpenAIServing:
# yield context
# yield context
# Create inputs for the next turn.
# Create inputs for the next turn.
# Render the next prompt token ids
and update sampling_params
.
# Render the next prompt token ids.
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
token_ids
=
context
.
render_for_completion
()
prompt_token_ids
=
context
.
render_for_completion
()
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
token_ids
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
)
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
token_ids
)
elif
isinstance
(
context
,
ParsableContext
):
elif
isinstance
(
context
,
ParsableContext
):
engine_prompts
=
await
self
.
_render_next_turn
(
engine_prompts
=
await
self
.
_render_next_turn
(
context
.
request
,
context
.
request
,
...
@@ -1358,19 +1388,19 @@ class OpenAIServing:
...
@@ -1358,19 +1388,19 @@ class OpenAIServing:
context
.
chat_template_content_format
,
context
.
chat_template_content_format
,
)
)
engine_prompt
=
engine_prompts
[
0
]
engine_prompt
=
engine_prompts
[
0
]
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_get_prompt_components
(
engine_prompt
)
sampling_params
.
max_tokens
=
get_max_tokens
(
self
.
max_model_len
,
context
.
request
,
engine_prompt
,
self
.
default_sampling_params
,
# type: ignore
)
# Update the sampling params.
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
]
)
# OPTIMIZATION
# OPTIMIZATION
priority
=
orig_priority
-
1
priority
=
orig_priority
-
1
sub_request
+=
1
sub_request
+=
1
def
_get_prompt_components
(
self
,
prompt
:
PromptType
)
->
PromptComponents
:
return
get_prompt_components
(
prompt
)
def
_log_inputs
(
def
_log_inputs
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -1381,7 +1411,7 @@ class OpenAIServing:
...
@@ -1381,7 +1411,7 @@ class OpenAIServing:
if
self
.
request_logger
is
None
:
if
self
.
request_logger
is
None
:
return
return
prompt
,
prompt_token_ids
,
prompt_embeds
=
get_prompt_components
(
inputs
)
prompt
,
prompt_token_ids
,
prompt_embeds
=
self
.
_
get_prompt_components
(
inputs
)
self
.
request_logger
.
log_inputs
(
self
.
request_logger
.
log_inputs
(
request_id
,
request_id
,
...
@@ -1495,7 +1525,6 @@ class OpenAIServing:
...
@@ -1495,7 +1525,6 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls.
# extract_tool_calls() returns a list of tool calls.
function_calls
.
extend
(
function_calls
.
extend
(
FunctionCall
(
FunctionCall
(
id
=
tool_call
.
id
,
name
=
tool_call
.
function
.
name
,
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
arguments
=
tool_call
.
function
.
arguments
,
)
)
...
@@ -1548,4 +1577,4 @@ def clamp_prompt_logprobs(
...
@@ -1548,4 +1577,4 @@ def clamp_prompt_logprobs(
for
logprob_values
in
logprob_dict
.
values
():
for
logprob_values
in
logprob_dict
.
values
():
if
logprob_values
.
logprob
==
float
(
"-inf"
):
if
logprob_values
.
logprob
==
float
(
"-inf"
):
logprob_values
.
logprob
=
-
9999.0
logprob_values
.
logprob
=
-
9999.0
return
prompt_logprobs
return
prompt_logprobs
\ No newline at end of file
vllm/entrypoints/openai/responses/serving.py
View file @
c721b814
...
@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
...
@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
ChatCompletionMessageParam
,
ChatTemplateContentFormatOption
,
ChatTemplateContentFormatOption
,
make_tool_call_id
,
)
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.mcp.tool_server
import
ToolServer
from
vllm.entrypoints.mcp.tool_server
import
ToolServer
...
@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
...
@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types
,
extract_tool_types
,
should_continue_final_message
,
should_continue_final_message
,
)
)
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
self
.
default_sampling_params
[
"stop_token_ids"
].
extend
(
self
.
default_sampling_params
[
"stop_token_ids"
].
extend
(
get_stop_tokens_for_assistant_actions
()
get_stop_tokens_for_assistant_actions
()
)
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides
=
getattr
(
self
.
model_config
,
"hf_overrides"
,
None
)
if
self
.
model_config
.
hf_text_config
.
model_type
==
"kimi_k2"
or
(
isinstance
(
hf_overrides
,
dict
)
and
hf_overrides
.
get
(
"model_type"
)
==
"kimi_k2"
):
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
self
.
enable_auto_tools
=
enable_auto_tools
self
.
enable_auto_tools
=
enable_auto_tools
# set up tool use
# set up tool use
self
.
tool_parser
=
self
.
_get_tool_parser
(
self
.
tool_parser
=
self
.
_get_tool_parser
(
...
@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
if
maybe_error
is
not
None
:
if
maybe_error
is
not
None
:
return
maybe_error
return
maybe_error
default_max_tokens
=
get_max_tokens
(
default_max_tokens
=
self
.
max_model_len
-
len
(
self
.
max_model_len
,
engine_prompt
[
"prompt_token_ids"
]
request
,
engine_prompt
,
self
.
default_sampling_params
,
)
)
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
...
@@ -970,28 +954,25 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -970,28 +954,25 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools
=
self
.
enable_auto_tools
,
enable_auto_tools
=
self
.
enable_auto_tools
,
tool_parser_cls
=
self
.
tool_parser
,
tool_parser_cls
=
self
.
tool_parser
,
)
)
if
content
:
if
content
or
(
self
.
use_harmony
and
tool_calls
):
output_text
=
ResponseOutputText
(
res_text_part
=
None
text
=
content
,
if
content
:
annotations
=
[],
# TODO
res_text_part
=
ResponseOutputText
(
type
=
"output_text"
,
text
=
content
,
logprobs
=
(
annotations
=
[],
# TODO
self
.
_create_response_logprobs
(
type
=
"output_text"
,
token_ids
=
final_output
.
token_ids
,
logprobs
=
(
logprobs
=
final_output
.
logprobs
,
self
.
_create_response_logprobs
(
tokenizer
=
tokenizer
,
token_ids
=
final_output
.
token_ids
,
top_logprobs
=
request
.
top_logprobs
,
logprobs
=
final_output
.
logprobs
,
)
tokenizer
=
tokenizer
,
if
request
.
is_include_output_logprobs
()
top_logprobs
=
request
.
top_logprobs
,
else
None
)
),
if
request
.
is_include_output_logprobs
()
)
else
None
),
)
message_item
=
ResponseOutputMessage
(
message_item
=
ResponseOutputMessage
(
id
=
f
"msg_
{
random_uuid
()
}
"
,
id
=
f
"msg_
{
random_uuid
()
}
"
,
content
=
[
res_text_part
]
if
res_text_part
else
[
],
content
=
[
output_text
],
role
=
"assistant"
,
role
=
"assistant"
,
status
=
"completed"
,
status
=
"completed"
,
type
=
"message"
,
type
=
"message"
,
...
@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
if
message_item
:
if
message_item
:
outputs
.
append
(
message_item
)
outputs
.
append
(
message_item
)
if
tool_calls
:
if
tool_calls
:
# We use a simple counter for history_tool_call_count because
tool_call_items
=
[
# we don't track the history of tool calls in the Responses API yet.
ResponseFunctionToolCall
(
# This means that the tool call index will start from 0 for each
id
=
f
"fc_
{
random_uuid
()
}
"
,
# request.
call_id
=
f
"call_
{
random_uuid
()
}
"
,
tool_call_items
=
[]
type
=
"function_call"
,
for
history_tool_call_cnt
,
tool_call
in
enumerate
(
tool_calls
):
status
=
"completed"
,
tool_call_items
.
append
(
name
=
tool_call
.
name
,
ResponseFunctionToolCall
(
arguments
=
tool_call
.
arguments
,
id
=
f
"fc_
{
random_uuid
()
}
"
,
call_id
=
tool_call
.
id
if
tool_call
.
id
else
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_call
.
name
,
idx
=
history_tool_call_cnt
,
),
type
=
"function_call"
,
status
=
"completed"
,
name
=
tool_call
.
name
,
arguments
=
tool_call
.
arguments
,
)
)
)
for
tool_call
in
tool_calls
]
outputs
.
extend
(
tool_call_items
)
outputs
.
extend
(
tool_call_items
)
return
outputs
return
outputs
...
@@ -2589,4 +2559,4 @@ class OpenAIServingResponses(OpenAIServing):
...
@@ -2589,4 +2559,4 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number
=-
1
,
sequence_number
=-
1
,
response
=
final_response
,
response
=
final_response
,
)
)
)
)
\ No newline at end of file
vllm/entrypoints/pooling/classify/serving.py
View file @
c721b814
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Final
,
cast
from
typing
import
cast
import
jinja2
import
jinja2
import
numpy
as
np
import
numpy
as
np
...
@@ -11,8 +11,18 @@ from fastapi import Request
...
@@ -11,8 +11,18 @@ from fastapi import Request
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
ClassificationServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.classify.protocol
import
(
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
ClassificationChatRequest
,
...
@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
...
@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ClassificationServeContext
=
ServeContext
[
ClassificationRequest
]
class
ClassificationMixin
(
OpenAIServing
):
chat_template
:
str
|
None
chat_template_content_format
:
ChatTemplateContentFormatOption
class
ServingClassification
(
OpenAIServing
):
trust_request_chat_template
:
bool
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
_preprocess
(
async
def
_preprocess
(
self
,
self
,
ctx
:
Classification
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
)
->
ErrorResponse
|
None
:
"""
"""
Process classification inputs: tokenize text, resolve adapters,
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
and prepare model-specific inputs.
"""
"""
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
try
:
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
request_obj
=
ctx
.
request
if
isinstance
(
ctx
.
request
,
ClassificationChatRequest
):
if
isinstance
(
request_obj
,
ClassificationChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
chat_request
=
request_obj
request_chat_template
=
ctx
.
request
.
chat_template
,
messages
=
chat_request
.
messages
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
getattr
(
trust_request_chat_template
=
self
.
trust_request_chat_template
,
self
,
"trust_request_chat_template"
,
False
,
)
)
if
error_check_ret
:
ret
=
self
.
_validate_chat_template
(
return
error_check_ret
request_chat_template
=
chat_request
.
chat_template
,
chat_template_kwargs
=
chat_request
.
chat_template_kwargs
,
trust_request_chat_template
=
trust_request_chat_template
,
)
if
ret
:
return
ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
c
tx
.
request
,
c
ast
(
ChatCompletionRequest
,
chat_
request
)
,
self
.
renderer
,
self
.
renderer
,
ctx
.
request
.
messages
,
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template
=
(
chat_template_content_format
=
self
.
chat_template_content_format
,
chat_request
.
chat_template
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
or
getattr
(
self
,
"chat_template"
,
None
)
continue_final_message
=
ctx
.
request
.
continue_final_message
,
),
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
chat_template_content_format
=
cast
(
ChatTemplateContentFormatOption
,
getattr
(
self
,
"chat_template_content_format"
,
"auto"
),
),
add_generation_prompt
=
chat_request
.
add_generation_prompt
,
continue_final_message
=
chat_request
.
continue_final_message
,
add_special_tokens
=
chat_request
.
add_special_tokens
,
)
)
ctx
.
engine_prompts
=
engine_prompts
ctx
.
engine_prompts
=
engine_prompts
elif
isinstance
(
ctx
.
request
,
ClassificationCompletionRequest
):
elif
isinstance
(
request_obj
,
ClassificationCompletionRequest
):
input_data
=
ctx
.
request
.
input
completion_request
=
request_obj
input_data
=
completion_request
.
input
if
input_data
in
(
None
,
""
):
if
input_data
in
(
None
,
""
):
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"Input or messages must be provided"
,
"Input or messages must be provided"
,
...
@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
...
@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
prompt_input
=
cast
(
str
|
list
[
str
],
input_data
)
prompt_input
=
cast
(
str
|
list
[
str
],
input_data
)
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
prompt_input
,
prompt_or_prompts
=
prompt_input
,
config
=
self
.
_build_render_config
(
c
tx
.
request
),
config
=
self
.
_build_render_config
(
c
ompletion_
request
),
)
)
else
:
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
self
.
create_error_response
(
"Invalid classification request type"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
,
)
return
None
return
None
...
@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
...
@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
def
_build_response
(
def
_build_response
(
self
,
self
,
ctx
:
Classification
ServeContext
,
ctx
:
ServeContext
,
)
->
ClassificationResponse
|
ErrorResponse
:
)
->
ClassificationResponse
|
ErrorResponse
:
"""
"""
Convert model outputs to a formatted classification response
Convert model outputs to a formatted classification response
with probabilities and labels.
with probabilities and labels.
"""
"""
id2label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{})
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
items
:
list
[
ClassificationData
]
=
[]
items
:
list
[
ClassificationData
]
=
[]
num_prompt_tokens
=
0
num_prompt_tokens
=
0
...
@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
...
@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
probs
=
classify_res
.
probs
probs
=
classify_res
.
probs
predicted_index
=
int
(
np
.
argmax
(
probs
))
predicted_index
=
int
(
np
.
argmax
(
probs
))
label
=
id2label
.
get
(
predicted_index
)
label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{}).
get
(
predicted_index
)
item
=
ClassificationData
(
item
=
ClassificationData
(
index
=
idx
,
index
=
idx
,
...
@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
...
@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
)
class
ServingClassification
(
ClassificationMixin
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_classify
(
async
def
create_classify
(
self
,
self
,
request
:
ClassificationRequest
,
request
:
ClassificationRequest
,
...
@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
...
@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
request_id
=
request_id
,
request_id
=
request_id
,
)
)
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
def
_create_pooling_params
(
def
_create_pooling_params
(
self
,
self
,
ctx
:
Classification
ServeContext
,
ctx
:
ServeContext
[
Classification
Request
]
,
)
->
PoolingParams
|
ErrorResponse
:
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
@@ -198,4 +230,4 @@ class ServingClassification(OpenAIServing):
...
@@ -198,4 +230,4 @@ class ServingClassification(OpenAIServing):
except
ValueError
as
e
:
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
return
pooling_params
\ No newline at end of file
vllm/entrypoints/pooling/embed/serving.py
View file @
c721b814
...
@@ -6,13 +6,21 @@ from typing import Any, Final, cast
...
@@ -6,13 +6,21 @@ from typing import Any, Final, cast
import
torch
import
torch
from
fastapi
import
Request
from
fastapi
import
Request
from
typing_extensions
import
assert_never
from
fastapi.responses
import
Response
from
typing_extensions
import
assert_never
,
override
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.protocol
import
(
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
EmbeddingServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.embed.protocol
import
(
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
EmbeddingBytesResponse
,
...
@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
...
@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
from
vllm.entrypoints.renderer
import
RenderConfig
from
vllm.entrypoints.renderer
import
RenderConfig
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
PoolingOutput
,
PoolingRequestOutput
,
RequestOutput
,
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.serial_utils
import
(
from
vllm.utils.serial_utils
import
(
EmbedDType
,
EncodingFormat
,
Endianness
,
encode_pooling_bytes
,
encode_pooling_bytes
,
encode_pooling_output
,
encode_pooling_output
,
)
)
...
@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
...
@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
EmbeddingServeContext
=
ServeContext
[
EmbeddingRequest
]
class
EmbeddingMixin
(
OpenAIServing
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
class
OpenAIServingEmbedding
(
OpenAIServing
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
pooler_config
=
self
.
model_config
.
pooler_config
pooler_config
=
self
.
model_config
.
pooler_config
...
@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
else
None
else
None
)
)
@
override
async
def
_preprocess
(
async
def
_preprocess
(
self
,
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
)
->
ErrorResponse
|
None
:
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
ctx
.
request
,
self
.
renderer
,
self
.
renderer
,
ctx
.
request
.
messages
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template
=
ctx
.
request
.
chat_template
or
ctx
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
chat_template_content_format
=
ctx
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
)
el
if
isinstance
(
ctx
.
request
,
EmbeddingCompletionRequest
)
:
el
se
:
renderer
=
self
.
_get_completion_renderer
()
renderer
=
self
.
_get_completion_renderer
()
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
ctx
.
request
.
input
,
prompt_or_prompts
=
ctx
.
request
.
input
,
config
=
self
.
_build_render_config
(
ctx
.
request
),
config
=
self
.
_build_render_config
(
ctx
.
request
),
)
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
return
None
except
(
ValueError
,
TypeError
)
as
e
:
except
(
ValueError
,
TypeError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
...
@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
)
@
override
def
_build_response
(
def
_build_response
(
self
,
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
,
)
->
EmbeddingResponse
|
EmbeddingBytes
Response
|
ErrorResponse
:
)
->
EmbeddingResponse
|
Response
|
ErrorResponse
:
final_res_batch_checked
=
ctx
.
final_res_batch
final_res_batch_checked
=
cast
(
list
[
PoolingRequestOutput
],
ctx
.
final_res_batch
)
encoding_format
=
ctx
.
request
.
encoding_format
encoding_format
:
EncodingFormat
=
ctx
.
request
.
encoding_format
embed_dtype
=
ctx
.
request
.
embed_dtype
embed_dtype
:
EmbedDType
=
ctx
.
request
.
embed_dtype
endianness
=
ctx
.
request
.
endianness
endianness
:
Endianness
=
ctx
.
request
.
endianness
def
encode_float_base64
():
def
encode_float_base64
():
items
:
list
[
EmbeddingResponseData
]
=
[]
items
:
list
[
EmbeddingResponseData
]
=
[]
...
@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self
,
self
,
ctx
:
EmbeddingServeContext
,
ctx
:
EmbeddingServeContext
,
token_ids
:
list
[
int
],
token_ids
:
list
[
int
],
pooling_params
:
PoolingParams
,
pooling_params
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
trace_headers
,
prompt_idx
:
int
,
prompt_idx
:
int
,
)
->
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]:
)
->
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]:
"""Process a single prompt using chunked processing."""
"""Process a single prompt using chunked processing."""
...
@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
def
_validate_input
(
def
_validate_input
(
self
,
self
,
request
:
object
,
request
,
input_ids
:
list
[
int
],
input_ids
:
list
[
int
],
input_text
:
str
,
input_text
:
str
,
)
->
TokensPrompt
:
)
->
TokensPrompt
:
...
@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params
:
PoolingParams
,
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_index
:
int
,
prompt_index
:
int
,
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]:
"""Create a generator for a single prompt using standard processing."""
"""Create a generator for a single prompt using standard processing."""
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
prompt_index
}
"
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
prompt_index
}
"
...
@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
priority
=
getattr
(
ctx
.
request
,
"priority"
,
0
),
priority
=
getattr
(
ctx
.
request
,
"priority"
,
0
),
)
)
@
override
async
def
_prepare_generators
(
async
def
_prepare_generators
(
self
,
self
,
ctx
:
ServeContext
,
ctx
:
ServeContext
,
...
@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
return
await
super
().
_prepare_generators
(
ctx
)
return
await
super
().
_prepare_generators
(
ctx
)
# Custom logic for chunked processing
# Custom logic for chunked processing
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
try
:
try
:
trace_headers
=
(
trace_headers
=
(
...
@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
@
override
async
def
_collect_batch
(
async
def
_collect_batch
(
self
,
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
)
->
ErrorResponse
|
None
:
"""Collect and aggregate batch results
"""Collect and aggregate batch results
with support for chunked processing.
with support for chunked processing.
...
@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage.
minimize memory usage.
For regular requests, collects results normally.
For regular requests, collects results normally.
"""
"""
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
try
:
if
ctx
.
engine_prompts
is
None
:
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
return
self
.
create_error_response
(
"Engine prompts not available"
)
...
@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
except
(
ValueError
,
IndexError
):
except
(
ValueError
,
IndexError
):
prompt_idx
=
result_idx
# Fallback to result_idx
prompt_idx
=
result_idx
# Fallback to result_idx
short_prompts_results
[
prompt_idx
]
=
result
short_prompts_results
[
prompt_idx
]
=
cast
(
PoolingRequestOutput
,
result
)
# Finalize aggregated results
# Finalize aggregated results
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
final_res_batch
:
list
[
PoolingRequestOutput
|
EmbeddingRequestOutput
]
=
[]
num_prompts
=
len
(
ctx
.
engine_prompts
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
for
prompt_idx
in
range
(
num_prompts
):
for
prompt_idx
in
range
(
num_prompts
):
...
@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
f
"Failed to aggregate chunks for prompt
{
prompt_idx
}
"
f
"Failed to aggregate chunks for prompt
{
prompt_idx
}
"
)
)
elif
prompt_idx
in
short_prompts_results
:
elif
prompt_idx
in
short_prompts_results
:
final_res_batch
.
append
(
short_prompts_results
[
prompt_idx
])
final_res_batch
.
append
(
cast
(
PoolingRequestOutput
,
short_prompts_results
[
prompt_idx
])
)
else
:
else
:
return
self
.
create_error_response
(
return
self
.
create_error_response
(
f
"Result not found for prompt
{
prompt_idx
}
"
f
"Result not found for prompt
{
prompt_idx
}
"
)
)
ctx
.
final_res_batch
=
final_res_batch
ctx
.
final_res_batch
=
cast
(
list
[
RequestOutput
|
PoolingRequestOutput
],
final_res_batch
)
return
None
return
None
except
Exception
as
e
:
except
Exception
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
class
OpenAIServingEmbedding
(
EmbeddingMixin
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_embedding
(
async
def
create_embedding
(
self
,
self
,
request
:
EmbeddingRequest
,
request
:
EmbeddingRequest
,
...
@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
raw_request
=
raw_request
,
raw_request
=
raw_request
,
model_name
=
model_name
,
model_name
=
model_name
,
request_id
=
request_id
,
request_id
=
request_id
,
chat_template
=
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
)
)
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
@
override
def
_create_pooling_params
(
def
_create_pooling_params
(
self
,
self
,
ctx
:
Embedding
ServeContext
,
ctx
:
ServeContext
[
EmbeddingRequest
]
,
)
->
PoolingParams
|
ErrorResponse
:
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
return
pooling_params
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
return
await
super
().
_preprocess
(
ctx
)
\ No newline at end of file
vllm/entrypoints/utils.py
View file @
c721b814
...
@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
...
@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
from
vllm
import
envs
from
vllm
import
envs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.inputs
import
EmbedsPrompt
,
TokensPrompt
from
vllm.logger
import
current_formatter_type
,
init_logger
from
vllm.logger
import
current_formatter_type
,
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -34,15 +32,11 @@ if TYPE_CHECKING:
...
@@ -34,15 +32,11 @@ if TYPE_CHECKING:
StreamOptions
,
StreamOptions
,
)
)
from
vllm.entrypoints.openai.models.protocol
import
LoRAModulePath
from
vllm.entrypoints.openai.models.protocol
import
LoRAModulePath
from
vllm.entrypoints.openai.responses.protocol
import
(
ResponsesRequest
,
)
else
:
else
:
ChatCompletionRequest
=
object
ChatCompletionRequest
=
object
CompletionRequest
=
object
CompletionRequest
=
object
StreamOptions
=
object
StreamOptions
=
object
LoRAModulePath
=
object
LoRAModulePath
=
object
ResponsesRequest
=
object
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -217,26 +211,11 @@ def _validate_truncation_size(
...
@@ -217,26 +211,11 @@ def _validate_truncation_size(
def
get_max_tokens
(
def
get_max_tokens
(
max_model_len
:
int
,
max_model_len
:
int
,
request
:
"CompletionRequest |
Chat
CompletionRequest
| ResponsesRequest
"
,
request
:
"
Chat
CompletionRequest | CompletionRequest"
,
prompt
:
TokensPrompt
|
EmbedsPromp
t
,
input_length
:
in
t
,
default_sampling_params
:
dict
,
default_sampling_params
:
dict
,
)
->
int
:
)
->
int
:
# NOTE: Avoid isinstance() for better efficiency
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
or
request
.
max_tokens
max_tokens
:
int
|
None
=
None
if
max_tokens
is
None
:
# ChatCompletionRequest
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
if
max_tokens
is
None
:
# ResponsesRequest
max_tokens
=
getattr
(
request
,
"max_output_tokens"
,
None
)
if
max_tokens
is
None
:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens
=
getattr
(
request
,
"max_tokens"
,
None
)
input_length
=
length_from_prompt_token_ids_or_embeds
(
prompt
.
get
(
"prompt_token_ids"
),
# type: ignore[arg-type]
prompt
.
get
(
"prompt_embeds"
),
# type: ignore[arg-type]
)
default_max_tokens
=
max_model_len
-
input_length
default_max_tokens
=
max_model_len
-
input_length
max_output_tokens
=
current_platform
.
get_max_output_tokens
(
input_length
)
max_output_tokens
=
current_platform
.
get_max_output_tokens
(
input_length
)
...
...
Prev
1
2
3
4
5
6
7
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment