Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
899a2db4
Commit
899a2db4
authored
Feb 05, 2026
by
zhuwenwen
Browse files
sync v0.15.1(ex fused_moe&models)
parent
78c1f9e5
Changes
72
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
169 additions
and
567 deletions
+169
-567
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-9
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+1
-2
vllm/config/compilation.py
vllm/config/compilation.py
+1
-2
vllm/config/speculative.py
vllm/config/speculative.py
+0
-12
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+5
-99
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+5
-44
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+8
-44
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+6
-39
vllm/distributed/device_communicators/mnnvl_compat.py
vllm/distributed/device_communicators/mnnvl_compat.py
+1
-12
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+7
-43
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+0
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-24
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-14
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+0
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+3
-37
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+45
-126
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+1
-12
vllm/entrypoints/openai/completion/serving.py
vllm/entrypoints/openai/completion/serving.py
+16
-4
vllm/entrypoints/openai/engine/protocol.py
vllm/entrypoints/openai/engine/protocol.py
+1
-5
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+65
-37
No files found.
vllm/_custom_ops.py
View file @
899a2db4
...
@@ -3238,17 +3238,9 @@ def onednn_scaled_mm(
...
@@ -3238,17 +3238,9 @@ def onednn_scaled_mm(
bias
:
torch
.
Tensor
|
None
,
bias
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
torch
.
ops
.
_C
.
onednn_scaled_mm
(
torch
.
ops
.
_C
.
onednn_scaled_mm
(
output
,
output
,
x
,
input_scale
,
input_zp
,
input_zp_adj
,
bias
,
dnnl_handler
.
handler
x
,
input_scale
,
input_zp
,
input_zp_adj
,
bias
,
dnnl_handler
.
handler_tensor
,
)
)
return
output
def
cpu_attn_get_scheduler_metadata
(
def
cpu_attn_get_scheduler_metadata
(
num_reqs
:
int
,
num_reqs
:
int
,
...
...
vllm/compilation/decorators.py
View file @
899a2db4
...
@@ -32,7 +32,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
...
@@ -32,7 +32,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
.monitor
import
start_monitoring_torch_compile
from
.monitor
import
start_monitoring_torch_compile
from
vllm.forward_context
import
get_profilling
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
# Only added on nightly/2.10 so wrap
# Only added on nightly/2.10 so wrap
...
@@ -387,7 +386,7 @@ def _support_torch_compile(
...
@@ -387,7 +386,7 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
()
:
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
...
...
vllm/config/compilation.py
View file @
899a2db4
...
@@ -281,10 +281,9 @@ class DynamicShapesConfig:
...
@@ -281,10 +281,9 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
"""
assume_32_bit_indexing
:
bool
=
Fals
e
assume_32_bit_indexing
:
bool
=
Tru
e
"""
"""
whether all tensor sizes can use 32 bit indexing.
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
"""
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
...
...
vllm/config/speculative.py
View file @
899a2db4
...
@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
...
@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
"mimo_mtp"
,
"mimo_mtp"
,
"glm4_moe_mtp"
,
"glm4_moe_mtp"
,
"glm4_moe_lite_mtp"
,
"glm4_moe_lite_mtp"
,
"glm_ocr_mtp"
,
"ernie_mtp"
,
"ernie_mtp"
,
"exaone_moe_mtp"
,
"exaone_moe_mtp"
,
"qwen3_next_mtp"
,
"qwen3_next_mtp"
,
...
@@ -223,17 +222,6 @@ class SpeculativeConfig:
...
@@ -223,17 +222,6 @@ class SpeculativeConfig:
}
}
)
)
if
hf_config
.
architectures
[
0
]
==
"GlmOcrForConditionalGeneration"
:
hf_config
.
model_type
=
"glm_ocr_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
(
{
"num_hidden_layers"
:
0
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"GlmOcrMTPModel"
],
}
)
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
if
hf_config
.
model_type
==
"ernie_mtp"
:
...
...
vllm/distributed/device_communicators/all2all.py
View file @
899a2db4
...
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
buffer
return
buffer
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
extra_tensors
is
not
None
:
raise
NotImplementedError
(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
cu_tokens_across_sp_cpu
=
dp_metadata
.
cu_tokens_across_sp
(
sp_size
)
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_weights
=
self
.
naive_multicast
(
topk_weights
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_ids
=
self
.
naive_multicast
(
topk_ids
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def
__init__
(
self
,
cpu_group
):
def
__init__
(
self
,
cpu_group
):
super
().
__init__
(
cpu_group
)
super
().
__init__
(
cpu_group
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
sizes
=
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
assert
sizes
is
not
None
dist_group
=
get_ep_group
()
if
is_sequence_parallel
else
get_dp_group
()
assert
sizes
[
dist_group
.
rank_in_group
]
==
hidden_states
.
shape
[
0
]
tensors_to_gather
=
[
hidden_states
,
topk_weights
,
topk_ids
]
if
extra_tensors
is
not
None
:
tensors_to_gather
.
extend
(
extra_tensors
)
gathered_tensors
=
dist_group
.
all_gatherv
(
tensors_to_gather
,
dim
=
0
,
sizes
=
sizes
,
)
hidden_states
=
gathered_tensors
[
0
]
topk_weights
=
gathered_tensors
[
1
]
topk_ids
=
gathered_tensors
[
2
]
if
extra_tensors
is
None
:
return
hidden_states
,
topk_weights
,
topk_ids
return
hidden_states
,
topk_weights
,
topk_ids
,
gathered_tensors
[
3
:]
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
...
@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx
.
AllToAll
.
internode
if
self
.
internode
else
pplx
.
AllToAll
.
intranode
,
pplx
.
AllToAll
.
internode
if
self
.
internode
else
pplx
.
AllToAll
.
intranode
,
)
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
...
@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def
get_handle
(
self
,
kwargs
):
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -693,4 +599,4 @@ class MoriAll2AllManager(All2AllManagerBase):
...
@@ -693,4 +599,4 @@ class MoriAll2AllManager(All2AllManagerBase):
handle
:
mori
.
ops
.
EpDispatchCombineOp
=
self
.
handle_cache
.
get_or_create
(
handle
:
mori
.
ops
.
EpDispatchCombineOp
=
self
.
handle_cache
.
get_or_create
(
mori_kwargs
,
self
.
_make_handle
mori_kwargs
,
self
.
_make_handle
)
)
return
handle
return
handle
\ No newline at end of file
vllm/distributed/device_communicators/base_device_communicator.py
View file @
899a2db4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
threading
import
threading
from
typing
import
Any
from
weakref
import
WeakValueDictionary
from
weakref
import
WeakValueDictionary
import
torch
import
torch
...
@@ -63,32 +64,13 @@ class All2AllManagerBase:
...
@@ -63,32 +64,13 @@ class All2AllManagerBase:
# and reuse it for the same config.
# and reuse it for the same config.
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
)
->
Any
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
# - raise a clear error if extra_tensors is not supported.
...
@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
...
@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
for
module
in
moe_modules
:
for
module
in
moe_modules
:
module
.
maybe_init_modular_kernel
()
module
.
maybe_init_modular_kernel
()
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
...
@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
This is a no-op in the base class.
"""
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
router_logits
,
extra_tensors
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
topk_weights
,
topk_ids
,
extra_tensors
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -342,4 +303,4 @@ class DeviceCommunicatorBase:
...
@@ -342,4 +303,4 @@ class DeviceCommunicatorBase:
Combine the hidden states and router logits from the appropriate device.
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
This is a no-op in the base class.
"""
"""
return
hidden_states
return
hidden_states
\ No newline at end of file
vllm/distributed/device_communicators/cpu_communicator.py
View file @
899a2db4
...
@@ -130,65 +130,29 @@ class CpuCommunicator(DeviceCommunicatorBase):
...
@@ -130,65 +130,29 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
_router_logits
(
def
dispatch
(
# type: ignore[override]
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
hidden_states
,
topk_weights
,
router_logits
,
topk_ids
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
hidden_states
,
is_sequence_parallel
is_sequence_parallel
,
)
)
return
hidden_states
class
_CPUSHMDistributed
:
class
_CPUSHMDistributed
:
...
@@ -286,4 +250,4 @@ class _CPUSHMDistributed:
...
@@ -286,4 +250,4 @@ class _CPUSHMDistributed:
tensor_dict
:
dict
[
str
,
torch
.
Tensor
]
=
{}
tensor_dict
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
size
,
t
in
zip
(
key_list
,
size_list
,
value_list
):
for
key
,
size
,
t
in
zip
(
key_list
,
size_list
,
value_list
):
tensor_dict
[
key
]
=
t
.
view
(
size
)
tensor_dict
[
key
]
=
t
.
view
(
size
)
return
tensor_dict
return
tensor_dict
\ No newline at end of file
vllm/distributed/device_communicators/cuda_communicator.py
View file @
899a2db4
...
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
return
output_list
def
dispatch
_router_logits
(
def
dispatch
(
# type: ignore[override]
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -332,52 +332,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -332,52 +332,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
hidden_states
,
topk_weights
,
router_logits
,
topk_ids
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
hidden_states
,
is_sequence_parallel
is_sequence_parallel
,
)
)
return
hidden_states
vllm/distributed/device_communicators/mnnvl_compat.py
View file @
899a2db4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
...
@@ -25,14 +23,5 @@ class CustomCommunicator(CommBackend):
...
@@ -25,14 +23,5 @@ class CustomCommunicator(CommBackend):
dist
.
all_gather_object
(
gathered
,
data
,
group
=
self
.
_group
)
dist
.
all_gather_object
(
gathered
,
data
,
group
=
self
.
_group
)
return
gathered
return
gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def
bcast
(
self
,
data
:
Any
,
root
:
int
)
->
Any
:
raise
NotImplementedError
def
barrier
(
self
)
->
None
:
raise
NotImplementedError
def
Split
(
self
,
color
:
int
,
key
:
int
)
->
"CustomCommunicator"
:
def
Split
(
self
,
color
:
int
,
key
:
int
)
->
"CustomCommunicator"
:
return
self
return
self
\ No newline at end of file
vllm/distributed/device_communicators/xpu_communicator.py
View file @
899a2db4
...
@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
...
@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
def
dispatch_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
def
dispatch
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
hidden_states
,
topk_weights
,
router_logits
,
topk_ids
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
extra_tensors
,
# type: ignore[call-arg]
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
combine
(
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
hidden_states
,
is_sequence_parallel
is_sequence_parallel
,
)
)
return
hidden_states
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
899a2db4
...
@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class
NixlConnector
(
KVConnectorBase_V1
):
class
NixlConnector
(
KVConnectorBase_V1
):
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
...
...
vllm/distributed/parallel_state.py
View file @
899a2db4
...
@@ -1003,7 +1003,7 @@ class GroupCoordinator:
...
@@ -1003,7 +1003,7 @@ class GroupCoordinator:
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
def
dispatch
_router_logits
(
def
dispatch
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -1014,7 +1014,7 @@ class GroupCoordinator:
...
@@ -1014,7 +1014,7 @@ class GroupCoordinator:
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
):
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
_router_logits
(
return
self
.
device_communicator
.
dispatch
(
# type: ignore[call-arg]
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
is_sequence_parallel
,
is_sequence_parallel
,
...
@@ -1023,28 +1023,6 @@ class GroupCoordinator:
...
@@ -1023,28 +1023,6 @@ class GroupCoordinator:
else
:
else
:
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
,
)
else
:
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/engine/arg_utils.py
View file @
899a2db4
...
@@ -346,16 +346,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
...
@@ -346,16 +346,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
cached version.
cached version.
"""
"""
return
copy
.
deepcopy
(
_compute_kwargs
(
cls
))
return
copy
.
deepcopy
(
_compute_kwargs
(
cls
))
class
EnvironmentConfigError
(
Exception
):
pass
# def check_incompatible_config(env1: bool, env2: bool):
# if env1 is True and env2 is True:
# _s = "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!\n\n"
# raise EnvironmentConfigError(_s)
@
dataclass
@
dataclass
class
EngineArgs
:
class
EngineArgs
:
...
@@ -1038,7 +1029,6 @@ class EngineArgs:
...
@@ -1038,7 +1029,6 @@ class EngineArgs:
)
)
lora_group
.
add_argument
(
"--default-mm-loras"
,
**
lora_kwargs
[
"default_mm_loras"
])
lora_group
.
add_argument
(
"--default-mm-loras"
,
**
lora_kwargs
[
"default_mm_loras"
])
# Observability arguments
# Observability arguments
observability_kwargs
=
get_kwargs
(
ObservabilityConfig
)
observability_kwargs
=
get_kwargs
(
ObservabilityConfig
)
observability_group
=
parser
.
add_argument_group
(
observability_group
=
parser
.
add_argument_group
(
...
@@ -1646,8 +1636,6 @@ class EngineArgs:
...
@@ -1646,8 +1636,6 @@ class EngineArgs:
target_model_config
=
model_config
,
target_model_config
=
model_config
,
target_parallel_config
=
parallel_config
,
target_parallel_config
=
parallel_config
,
)
)
# check_incompatible_config(envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, envs.USE_FUSED_RMS_QUANT)
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
runner_type
=
model_config
.
runner_type
,
runner_type
=
model_config
.
runner_type
,
...
@@ -1789,7 +1777,6 @@ class EngineArgs:
...
@@ -1789,7 +1777,6 @@ class EngineArgs:
return
config
return
config
def
_check_feature_supported
(
self
,
model_config
:
ModelConfig
):
def
_check_feature_supported
(
self
,
model_config
:
ModelConfig
):
"""Raise an error if the feature is not supported."""
"""Raise an error if the feature is not supported."""
if
self
.
logits_processor_pattern
!=
EngineArgs
.
logits_processor_pattern
:
if
self
.
logits_processor_pattern
!=
EngineArgs
.
logits_processor_pattern
:
...
...
vllm/entrypoints/llm.py
View file @
899a2db4
...
@@ -78,7 +78,6 @@ from vllm.v1.engine import EngineCoreRequest
...
@@ -78,7 +78,6 @@ from vllm.v1.engine import EngineCoreRequest
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.engine.llm_engine
import
LLMEngine
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.metrics.reader
import
Metric
from
vllm.v1.metrics.reader
import
Metric
...
...
vllm/entrypoints/openai/api_server.py
View file @
899a2db4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
asyncio
import
hashlib
import
hashlib
import
importlib
import
importlib
import
inspect
import
inspect
...
@@ -265,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
...
@@ -265,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return
None
return
None
def
get_uvicorn_log_config
(
args
:
Namespace
)
->
dict
|
None
:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config
=
load_log_config
(
args
.
log_config_file
)
if
log_config
is
not
None
:
return
log_config
# If endpoints to filter are specified, create a config with the filter
if
args
.
disable_access_log_for_endpoints
:
from
vllm.logging_utils
import
create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths
=
[
p
.
strip
()
for
p
in
args
.
disable_access_log_for_endpoints
.
split
(
","
)
if
p
.
strip
()
]
return
create_uvicorn_log_config
(
excluded_paths
=
excluded_paths
,
log_level
=
args
.
uvicorn_log_level
,
)
return
None
class
AuthenticationMiddleware
:
class
AuthenticationMiddleware
:
"""
"""
Pure ASGI middleware that authenticates each request by checking
Pure ASGI middleware that authenticates each request by checking
...
@@ -964,8 +930,8 @@ async def run_server_worker(
...
@@ -964,8 +930,8 @@ async def run_server_worker(
if
args
.
reasoning_parser_plugin
and
len
(
args
.
reasoning_parser_plugin
)
>
3
:
if
args
.
reasoning_parser_plugin
and
len
(
args
.
reasoning_parser_plugin
)
>
3
:
ReasoningParserManager
.
import_reasoning_parser
(
args
.
reasoning_parser_plugin
)
ReasoningParserManager
.
import_reasoning_parser
(
args
.
reasoning_parser_plugin
)
#
Get uvicorn log config (from file or with endpoint filter)
#
Load logging config for uvicorn if specified
log_config
=
get_uvicorn
_log_config
(
args
)
log_config
=
load
_log_config
(
args
.
log_config_file
)
if
log_config
is
not
None
:
if
log_config
is
not
None
:
uvicorn_kwargs
[
"log_config"
]
=
log_config
uvicorn_kwargs
[
"log_config"
]
=
log_config
...
@@ -1022,4 +988,4 @@ if __name__ == "__main__":
...
@@ -1022,4 +988,4 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
validate_parsed_serve_args
(
args
)
validate_parsed_serve_args
(
args
)
uvloop
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
\ No newline at end of file
vllm/entrypoints/openai/chat_completion/serving.py
View file @
899a2db4
...
@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
...
@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage
,
DeltaMessage
,
DeltaToolCall
,
DeltaToolCall
,
ErrorResponse
,
ErrorResponse
,
FunctionCall
,
PromptTokenUsageInfo
,
PromptTokenUsageInfo
,
RequestResponseMetadata
,
RequestResponseMetadata
,
ToolCall
,
ToolCall
,
...
@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
...
@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.openai.utils
import
maybe_filter_parallel_tool_calls
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.parse
import
get_prompt_components
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
...
@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_force_include_usage
=
enable_force_include_usage
self
.
enable_force_include_usage
=
enable_force_include_usage
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
if
self
.
model_config
.
hf_config
.
model_type
==
"kimi_k2"
:
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
self
.
use_harmony
=
self
.
model_config
.
hf_config
.
model_type
==
"gpt_oss"
self
.
use_harmony
=
self
.
model_config
.
hf_config
.
model_type
==
"gpt_oss"
if
self
.
use_harmony
:
if
self
.
use_harmony
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
if
"stop_token_ids"
not
in
self
.
default_sampling_params
:
...
@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions
()
get_stop_tokens_for_assistant_actions
()
)
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides
=
getattr
(
self
.
model_config
,
"hf_overrides"
,
None
)
if
self
.
model_config
.
hf_text_config
.
model_type
==
"kimi_k2"
or
(
isinstance
(
hf_overrides
,
dict
)
and
hf_overrides
.
get
(
"model_type"
)
==
"kimi_k2"
):
self
.
tool_call_id_type
=
"kimi_k2"
else
:
self
.
tool_call_id_type
=
"random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the
# for some models, currently vLLM doesn't support it. Please use the
# Responses API instead.
# Responses API instead.
...
@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
# type: ignore[arg-type]
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
# type: ignore[arg-type]
truncate_tool_call_ids
(
request
)
validate_request_params
(
request
)
validate_request_params
(
request
)
# Check if tool parsing is unavailable (common condition)
# Check if tool parsing is unavailable (common condition)
...
@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_get_prompt_components
(
engine_prompt
)
# If we are creating sub requests for multiple prompts, ensure that they
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
# have unique request ids.
sub_request_id
=
(
sub_request_id
=
(
request_id
if
len
(
engine_prompts
)
==
1
else
f
"
{
request_id
}
_
{
i
}
"
request_id
if
len
(
engine_prompts
)
==
1
else
f
"
{
request_id
}
_
{
i
}
"
)
)
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
request
=
request
,
request
=
request
,
prompt
=
engine_prompt
,
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
])
,
default_sampling_params
=
self
.
default_sampling_params
,
default_sampling_params
=
self
.
default_sampling_params
,
)
)
...
@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
# Streaming response
tokenizer
=
self
.
renderer
.
tokenizer
tokenizer
=
self
.
renderer
.
tokenizer
assert
tokenizer
is
not
None
if
request
.
stream
:
if
request
.
stream
:
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
...
@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
model_name
:
str
,
model_name
:
str
,
conversation
:
list
[
ConversationMessage
],
conversation
:
list
[
ConversationMessage
],
tokenizer
:
TokenizerLike
,
tokenizer
:
TokenizerLike
|
None
,
request_metadata
:
RequestResponseMetadata
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
from
vllm.tokenizers.mistral
import
MistralTokenizer
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
chunk_object_type
:
Final
=
"chat.completion.chunk"
chunk_object_type
:
Final
=
"chat.completion.chunk"
first_iteration
=
True
first_iteration
=
True
...
@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
)
)
reasoning_parser
=
self
.
reasoning_parser
(
reasoning_parser
=
self
.
reasoning_parser
(
tokenizer
,
tokenizer
,
chat_template_kwargs
=
chat_template_kwargs
or
{}
,
# type: ignore[call-arg]
chat_template_kwargs
=
chat_template_kwargs
,
# type: ignore[call-arg]
)
)
except
RuntimeError
as
e
:
except
RuntimeError
as
e
:
logger
.
exception
(
"Error in reasoning parser creation."
)
logger
.
exception
(
"Error in reasoning parser creation."
)
...
@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
index
=
i
,
index
=
i
,
)
)
else
:
else
:
# Generate ID based on tokenizer type
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_id
=
MistralToolCall
.
generate_random_id
()
else
:
tool_call_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_choice_function_name
,
idx
=
history_tool_call_cnt
,
)
delta_tool_call
=
DeltaToolCall
(
delta_tool_call
=
DeltaToolCall
(
id
=
tool_call_id
,
id
=
make_
tool_call_id
()
,
type
=
"function"
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
function
=
DeltaFunctionCall
(
name
=
tool_choice_function_name
,
name
=
tool_choice_function_name
,
...
@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
request_id
:
str
,
request_id
:
str
,
model_name
:
str
,
model_name
:
str
,
conversation
:
list
[
ConversationMessage
],
conversation
:
list
[
ConversationMessage
],
tokenizer
:
TokenizerLike
,
tokenizer
:
TokenizerLike
|
None
,
request_metadata
:
RequestResponseMetadata
,
request_metadata
:
RequestResponseMetadata
,
)
->
ErrorResponse
|
ChatCompletionResponse
:
)
->
ErrorResponse
|
ChatCompletionResponse
:
from
vllm.tokenizers.mistral
import
MistralTokenizer
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
final_res
:
RequestOutput
|
None
=
None
final_res
:
RequestOutput
|
None
=
None
...
@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class
=
(
tool_call_class
=
(
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
)
)
if
self
.
use_harmony
:
if
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
,
tool_calls
=
tool_calls
if
tool_calls
else
[],
)
elif
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
and
request
.
tool_choice
!=
"required"
and
request
.
tool_choice
!=
"required"
):
):
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
)
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
)
# if the request uses tools and specified a tool choice
elif
(
elif
(
request
.
tool_choice
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
):
):
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
tool_call_class_items
=
[]
for
idx
,
tc
in
enumerate
(
tool_calls
):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if
tc
.
id
:
tool_call_class_items
.
append
(
tool_call_class
(
id
=
tc
.
id
,
function
=
tc
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_class_items
.
append
(
tool_call_class
(
function
=
tc
))
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tc
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_class_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tc
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
reasoning
=
reasoning
,
reasoning
=
reasoning
,
content
=
""
,
content
=
""
,
tool_calls
=
tool_call_class
_items
,
tool_calls
=
[
tool_call_class
(
function
=
tc
)
for
tc
in
tool_calls
]
,
)
)
elif
request
.
tool_choice
and
request
.
tool_choice
==
"required"
:
elif
request
.
tool_choice
and
request
.
tool_choice
==
"required"
:
tool_call_class_items
=
[]
tool_call_class_items
=
[]
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
for
idx
,
tool_call
in
enumerate
(
tool_calls
):
for
tool_call
in
tool_calls
:
# Use native ID if available,
tool_call_class_items
.
append
(
# otherwise generate ID with correct id_type
tool_call_class
(
if
tool_call
.
id
:
id
=
make_tool_call_id
(
tool_call_class_items
.
append
(
tool_call_class
(
id
=
tool_call
.
id
,
function
=
tool_call
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_class_items
.
append
(
tool_call_class
(
function
=
tool_call
)
)
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
id_type
=
self
.
tool_call_id_type
,
func_name
=
tool_call
.
name
,
func_name
=
tool_call
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
idx
=
history_tool_call_cnt
,
)
)
,
tool_call_class_items
.
append
(
function
=
tool_call
,
tool_call_class
(
id
=
generated_id
,
function
=
tool_call
)
)
)
)
history_tool_call_cnt
+=
1
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
...
@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls
# call. The same is not true for named function calls
auto_tools_called
=
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
auto_tools_called
=
tool_calls
is
not
None
and
len
(
tool_calls
)
>
0
if
tool_calls
:
if
tool_calls
:
tool_call_items
=
[]
for
idx
,
tc
in
enumerate
(
tool_calls
):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if
tc
.
id
:
tool_call_items
.
append
(
tool_call_class
(
id
=
tc
.
id
,
function
=
tc
)
)
else
:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tool_call_items
.
append
(
tool_call_class
(
function
=
tc
))
else
:
generated_id
=
make_tool_call_id
(
id_type
=
self
.
tool_call_id_type
,
func_name
=
tc
.
name
,
idx
=
history_tool_call_cnt
+
idx
,
)
tool_call_items
.
append
(
tool_call_class
(
id
=
generated_id
,
function
=
tc
)
)
history_tool_call_cnt
+=
1
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
reasoning
=
reasoning
,
reasoning
=
reasoning
,
content
=
content
,
content
=
content
,
tool_calls
=
tool_call_items
,
tool_calls
=
[
ToolCall
(
function
=
tc
,
type
=
"function"
,
)
for
tc
in
tool_calls
],
)
)
else
:
else
:
...
@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
elif
choice
.
message
.
tool_calls
:
elif
choice
.
message
.
tool_calls
:
# For tool calls, log the function name and arguments
# For tool calls, log the function name and arguments
tool_call_descriptions
=
[]
tool_call_descriptions
=
[]
for
tc
in
choice
.
message
.
tool_calls
:
# type: ignore
for
tc
in
choice
.
message
.
tool_calls
:
function_call
:
FunctionCall
=
tc
.
function
# type: ignore
if
hasattr
(
tc
.
function
,
"name"
)
and
hasattr
(
tool_call_descriptions
.
append
(
tc
.
function
,
"arguments"
f
"
{
function_call
.
name
}
(
{
function_call
.
arguments
}
)"
):
)
tool_call_descriptions
.
append
(
f
"
{
tc
.
function
.
name
}
(
{
tc
.
function
.
arguments
}
)"
)
tool_calls_str
=
", "
.
join
(
tool_call_descriptions
)
tool_calls_str
=
", "
.
join
(
tool_call_descriptions
)
output_text
=
f
"[tool_calls:
{
tool_calls_str
}
]"
output_text
=
f
"[tool_calls:
{
tool_calls_str
}
]"
...
@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
# type: ignore[arg-type]
maybe_serialize_tool_calls
(
request
)
# Add system message.
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
# NOTE: In Chat Completion API, browsing is enabled by default
...
@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message.
# Add developer message.
if
request
.
tools
:
if
request
.
tools
:
dev_msg
=
get_developer_message
(
dev_msg
=
get_developer_message
(
tools
=
request
.
tools
if
should_include_tools
else
None
# type: ignore[arg-type]
tools
=
request
.
tools
if
should_include_tools
else
None
)
)
messages
.
append
(
dev_msg
)
messages
.
append
(
dev_msg
)
...
@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
cache_salt
is
not
None
:
if
request
.
cache_salt
is
not
None
:
engine_prompt
[
"cache_salt"
]
=
request
.
cache_salt
engine_prompt
[
"cache_salt"
]
=
request
.
cache_salt
return
messages
,
[
engine_prompt
]
return
messages
,
[
engine_prompt
]
\ No newline at end of file
vllm/entrypoints/openai/cli_args.py
View file @
899a2db4
...
@@ -85,12 +85,6 @@ class FrontendArgs:
...
@@ -85,12 +85,6 @@ class FrontendArgs:
"""Log level for uvicorn."""
"""Log level for uvicorn."""
disable_uvicorn_access_log
:
bool
=
False
disable_uvicorn_access_log
:
bool
=
False
"""Disable uvicorn access log."""
"""Disable uvicorn access log."""
disable_access_log_for_endpoints
:
str
|
None
=
None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials
:
bool
=
False
allow_credentials
:
bool
=
False
"""Allow credentials."""
"""Allow credentials."""
allowed_origins
:
list
[
str
]
=
field
(
default_factory
=
lambda
:
[
"*"
])
allowed_origins
:
list
[
str
]
=
field
(
default_factory
=
lambda
:
[
"*"
])
...
@@ -250,11 +244,6 @@ class FrontendArgs:
...
@@ -250,11 +244,6 @@ class FrontendArgs:
del
frontend_kwargs
[
"middleware"
][
"nargs"
]
del
frontend_kwargs
[
"middleware"
][
"nargs"
]
frontend_kwargs
[
"middleware"
][
"default"
]
=
[]
frontend_kwargs
[
"middleware"
][
"default"
]
=
[]
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if
"nargs"
in
frontend_kwargs
[
"disable_access_log_for_endpoints"
]:
del
frontend_kwargs
[
"disable_access_log_for_endpoints"
][
"nargs"
]
# Special case: Tool call parser shows built-in options.
# Special case: Tool call parser shows built-in options.
valid_tool_parsers
=
list
(
ToolParserManager
.
list_registered
())
valid_tool_parsers
=
list
(
ToolParserManager
.
list_registered
())
parsers_str
=
","
.
join
(
valid_tool_parsers
)
parsers_str
=
","
.
join
(
valid_tool_parsers
)
...
@@ -332,4 +321,4 @@ def create_parser_for_docs() -> FlexibleArgumentParser:
...
@@ -332,4 +321,4 @@ def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs
=
FlexibleArgumentParser
(
parser_for_docs
=
FlexibleArgumentParser
(
prog
=
"-m vllm.entrypoints.openai.api_server"
prog
=
"-m vllm.entrypoints.openai.api_server"
)
)
return
make_arg_parser
(
parser_for_docs
)
return
make_arg_parser
(
parser_for_docs
)
\ No newline at end of file
vllm/entrypoints/openai/completion/serving.py
View file @
899a2db4
...
@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig
...
@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
EmbedsPrompt
,
TokensPrompt
,
is_embeds_prompt
from
vllm.inputs.data
import
EmbedsPrompt
,
TokensPrompt
,
is_embeds_prompt
from
vllm.inputs.parse
import
get_prompt_components
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
...
@@ -163,12 +162,25 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -163,12 +162,25 @@ class OpenAIServingCompletion(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
prompt_token_ids
,
prompt_embeds
=
(
self
.
_get_prompt_components
(
engine_prompt
)
)
input_length
=
None
if
prompt_token_ids
is
not
None
:
input_length
=
len
(
prompt_token_ids
)
elif
prompt_embeds
is
not
None
:
input_length
=
len
(
prompt_embeds
)
else
:
raise
NotImplementedError
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
request
=
request
,
request
=
request
,
prompt
=
engine_prompt
,
input_length
=
input_length
,
default_sampling_params
=
self
.
default_sampling_params
,
default_sampling_params
=
self
.
default_sampling_params
,
)
)
...
@@ -731,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -731,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
cache_salt
=
request
.
cache_salt
,
cache_salt
=
request
.
cache_salt
,
needs_detokenization
=
bool
(
request
.
echo
and
not
request
.
return_token_ids
),
needs_detokenization
=
bool
(
request
.
echo
and
not
request
.
return_token_ids
),
)
)
\ No newline at end of file
vllm/entrypoints/openai/engine/protocol.py
View file @
899a2db4
...
@@ -218,10 +218,6 @@ def get_logits_processors(
...
@@ -218,10 +218,6 @@ def get_logits_processors(
class
FunctionCall
(
OpenAIBaseModel
):
class
FunctionCall
(
OpenAIBaseModel
):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id
:
str
|
None
=
Field
(
default
=
None
,
exclude
=
True
)
name
:
str
name
:
str
arguments
:
str
arguments
:
str
...
@@ -319,4 +315,4 @@ class GenerateRequest(BaseModel):
...
@@ -319,4 +315,4 @@ class GenerateRequest(BaseModel):
kv_transfer_params
:
dict
[
str
,
Any
]
|
None
=
Field
(
kv_transfer_params
:
dict
[
str
,
Any
]
|
None
=
Field
(
default
=
None
,
default
=
None
,
description
=
"KVTransfer parameters used for disaggregated serving."
,
description
=
"KVTransfer parameters used for disaggregated serving."
,
)
)
\ No newline at end of file
vllm/entrypoints/openai/engine/serving.py
View file @
899a2db4
...
@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
...
@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
from
vllm.entrypoints.pooling.classify.protocol
import
(
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
ClassificationChatRequest
,
ClassificationCompletionRequest
,
ClassificationCompletionRequest
,
ClassificationRequest
,
ClassificationResponse
,
ClassificationResponse
,
)
)
from
vllm.entrypoints.pooling.embed.protocol
import
(
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
EmbeddingChatRequest
,
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
EmbeddingCompletionRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponse
,
)
)
from
vllm.entrypoints.pooling.pooling.protocol
import
(
from
vllm.entrypoints.pooling.pooling.protocol
import
(
...
@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
...
@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest
,
TokenizeCompletionRequest
,
TokenizeResponse
,
TokenizeResponse
,
)
)
from
vllm.entrypoints.utils
import
(
from
vllm.entrypoints.utils
import
_validate_truncation_size
,
sanitize_message
_validate_truncation_size
,
get_max_tokens
,
sanitize_message
,
)
from
vllm.exceptions
import
VLLMValidationError
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.parse
import
(
from
vllm.inputs.parse
import
(
PromptComponents
,
get_prompt_components
,
get_prompt_components
,
is_explicit_encoder_decoder_prompt
,
is_explicit_encoder_decoder_prompt
,
)
)
...
@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
...
@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
CompletionResponse
CompletionResponse
|
ChatCompletionResponse
|
ChatCompletionResponse
|
EmbeddingResponse
|
EmbeddingResponse
|
EmbeddingBytesResponse
|
TranscriptionResponse
|
TranscriptionResponse
|
TokenizeResponse
|
TokenizeResponse
|
PoolingResponse
|
PoolingResponse
...
@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
...
@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@
dataclass
(
kw_only
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
Generic
[
RequestT
]):
class
RequestProcessingMixin
:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
field
(
default_factory
=
list
)
@
dataclass
(
kw_only
=
True
)
class
ResponseGenerationMixin
:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator
:
(
AsyncGenerator
[
tuple
[
int
,
RequestOutput
|
PoolingRequestOutput
],
None
]
|
None
)
=
None
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
RequestProcessingMixin
,
ResponseGenerationMixin
,
Generic
[
RequestT
]):
request
:
RequestT
request
:
RequestT
raw_request
:
Request
|
None
=
None
raw_request
:
Request
|
None
=
None
model_name
:
str
model_name
:
str
request_id
:
str
request_id
:
str
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
lora_request
:
LoRARequest
|
None
=
None
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
None
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
None
)
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ClassificationServeContext
(
ServeContext
[
ClassificationRequest
]):
pass
@
dataclass
(
kw_only
=
True
)
class
EmbeddingServeContext
(
ServeContext
[
EmbeddingRequest
]):
chat_template
:
str
|
None
=
None
chat_template_content_format
:
ChatTemplateContentFormatOption
class
OpenAIServing
:
class
OpenAIServing
:
...
@@ -578,7 +605,10 @@ class OpenAIServing:
...
@@ -578,7 +605,10 @@ class OpenAIServing:
self
,
self
,
ctx
:
ServeContext
,
ctx
:
ServeContext
,
)
->
AnyResponse
|
ErrorResponse
:
)
->
AnyResponse
|
ErrorResponse
:
async
for
response
in
self
.
_pipeline
(
ctx
):
generation
:
AsyncGenerator
[
AnyResponse
|
ErrorResponse
,
None
]
generation
=
self
.
_pipeline
(
ctx
)
async
for
response
in
generation
:
return
response
return
response
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
...
@@ -637,7 +667,9 @@ class OpenAIServing:
...
@@ -637,7 +667,9 @@ class OpenAIServing:
ctx
:
ServeContext
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
)
->
ErrorResponse
|
None
:
"""Schedule the request and get the result generator."""
"""Schedule the request and get the result generator."""
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
try
:
try
:
trace_headers
=
(
trace_headers
=
(
...
@@ -691,7 +723,7 @@ class OpenAIServing:
...
@@ -691,7 +723,7 @@ class OpenAIServing:
return
self
.
create_error_response
(
"Engine prompts not available"
)
return
self
.
create_error_response
(
"Engine prompts not available"
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
final_res_batch
:
list
[
PoolingRequestOutput
|
None
]
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
|
None
]
final_res_batch
=
[
None
]
*
num_prompts
final_res_batch
=
[
None
]
*
num_prompts
if
ctx
.
result_generator
is
None
:
if
ctx
.
result_generator
is
None
:
...
@@ -949,7 +981,6 @@ class OpenAIServing:
...
@@ -949,7 +981,6 @@ class OpenAIServing:
max_length
=
truncate_prompt_tokens
,
max_length
=
truncate_prompt_tokens
,
)
)
input_ids
=
encoded
.
input_ids
input_ids
=
encoded
.
input_ids
input_text
=
prompt
input_text
=
prompt
...
@@ -973,14 +1004,14 @@ class OpenAIServing:
...
@@ -973,14 +1004,14 @@ class OpenAIServing:
if
tokenizer
is
None
:
if
tokenizer
is
None
:
input_text
=
""
input_text
=
""
else
:
else
:
async_tokenizer
=
self
.
_get_async_tokenizer
(
tokenizer
)
async_tokenizer
=
self
.
_get_async_tokenizer
(
tokenizer
)
input_text
=
await
async_tokenizer
.
decode
(
input_ids
)
input_text
=
await
async_tokenizer
.
decode
(
input_ids
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
def
_validate_input
(
def
_validate_input
(
self
,
self
,
request
:
objec
t
,
request
:
AnyReques
t
,
input_ids
:
list
[
int
],
input_ids
:
list
[
int
],
input_text
:
str
,
input_text
:
str
,
)
->
TokensPrompt
:
)
->
TokensPrompt
:
...
@@ -1291,7 +1322,7 @@ class OpenAIServing:
...
@@ -1291,7 +1322,7 @@ class OpenAIServing:
priority
:
int
=
0
,
priority
:
int
=
0
,
**
kwargs
,
**
kwargs
,
):
):
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_
get_prompt_components
(
engine_prompt
)
orig_priority
=
priority
orig_priority
=
priority
sub_request
=
0
sub_request
=
0
...
@@ -1342,12 +1373,10 @@ class OpenAIServing:
...
@@ -1342,12 +1373,10 @@ class OpenAIServing:
# yield context
# yield context
# Create inputs for the next turn.
# Create inputs for the next turn.
# Render the next prompt token ids
and update sampling_params
.
# Render the next prompt token ids.
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
if
isinstance
(
context
,
(
HarmonyContext
,
StreamingHarmonyContext
)):
token_ids
=
context
.
render_for_completion
()
prompt_token_ids
=
context
.
render_for_completion
()
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
token_ids
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
)
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
token_ids
)
elif
isinstance
(
context
,
ParsableContext
):
elif
isinstance
(
context
,
ParsableContext
):
engine_prompts
=
await
self
.
_render_next_turn
(
engine_prompts
=
await
self
.
_render_next_turn
(
context
.
request
,
context
.
request
,
...
@@ -1359,19 +1388,19 @@ class OpenAIServing:
...
@@ -1359,19 +1388,19 @@ class OpenAIServing:
context
.
chat_template_content_format
,
context
.
chat_template_content_format
,
)
)
engine_prompt
=
engine_prompts
[
0
]
engine_prompt
=
engine_prompts
[
0
]
prompt_text
,
_
,
_
=
get_prompt_components
(
engine_prompt
)
prompt_text
,
_
,
_
=
self
.
_get_prompt_components
(
engine_prompt
)
sampling_params
.
max_tokens
=
get_max_tokens
(
self
.
max_model_len
,
context
.
request
,
engine_prompt
,
self
.
default_sampling_params
,
# type: ignore
)
# Update the sampling params.
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
]
)
# OPTIMIZATION
# OPTIMIZATION
priority
=
orig_priority
-
1
priority
=
orig_priority
-
1
sub_request
+=
1
sub_request
+=
1
def
_get_prompt_components
(
self
,
prompt
:
PromptType
)
->
PromptComponents
:
return
get_prompt_components
(
prompt
)
def
_log_inputs
(
def
_log_inputs
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -1382,7 +1411,7 @@ class OpenAIServing:
...
@@ -1382,7 +1411,7 @@ class OpenAIServing:
if
self
.
request_logger
is
None
:
if
self
.
request_logger
is
None
:
return
return
prompt
,
prompt_token_ids
,
prompt_embeds
=
get_prompt_components
(
inputs
)
prompt
,
prompt_token_ids
,
prompt_embeds
=
self
.
_
get_prompt_components
(
inputs
)
self
.
request_logger
.
log_inputs
(
self
.
request_logger
.
log_inputs
(
request_id
,
request_id
,
...
@@ -1496,7 +1525,6 @@ class OpenAIServing:
...
@@ -1496,7 +1525,6 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls.
# extract_tool_calls() returns a list of tool calls.
function_calls
.
extend
(
function_calls
.
extend
(
FunctionCall
(
FunctionCall
(
id
=
tool_call
.
id
,
name
=
tool_call
.
function
.
name
,
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
arguments
=
tool_call
.
function
.
arguments
,
)
)
...
@@ -1549,4 +1577,4 @@ def clamp_prompt_logprobs(
...
@@ -1549,4 +1577,4 @@ def clamp_prompt_logprobs(
for
logprob_values
in
logprob_dict
.
values
():
for
logprob_values
in
logprob_dict
.
values
():
if
logprob_values
.
logprob
==
float
(
"-inf"
):
if
logprob_values
.
logprob
==
float
(
"-inf"
):
logprob_values
.
logprob
=
-
9999.0
logprob_values
.
logprob
=
-
9999.0
return
prompt_logprobs
return
prompt_logprobs
\ No newline at end of file
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment