Commit 899a2db4 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1(ex fused_moe&models)

parent 78c1f9e5
...@@ -3238,17 +3238,9 @@ def onednn_scaled_mm( ...@@ -3238,17 +3238,9 @@ def onednn_scaled_mm(
bias: torch.Tensor | None, bias: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm( torch.ops._C.onednn_scaled_mm(
output, output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler
x,
input_scale,
input_zp,
input_zp_adj,
bias,
dnnl_handler.handler_tensor,
) )
return output
def cpu_attn_get_scheduler_metadata( def cpu_attn_get_scheduler_metadata(
num_reqs: int, num_reqs: int,
......
...@@ -32,7 +32,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -32,7 +32,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from .monitor import start_monitoring_torch_compile from .monitor import start_monitoring_torch_compile
from vllm.forward_context import get_profilling
if TYPE_CHECKING: if TYPE_CHECKING:
# Only added on nightly/2.10 so wrap # Only added on nightly/2.10 so wrap
...@@ -387,7 +386,7 @@ def _support_torch_compile( ...@@ -387,7 +386,7 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't # e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside. # need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling() or get_profilling(): if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
# If skip_compiled is set, bypass compiled model call. This is used e.g. for # If skip_compiled is set, bypass compiled model call. This is used e.g. for
......
...@@ -281,10 +281,9 @@ class DynamicShapesConfig: ...@@ -281,10 +281,9 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239. until this change picked up https://github.com/pytorch/pytorch/pull/169239.
""" """
assume_32_bit_indexing: bool = False assume_32_bit_indexing: bool = True
""" """
whether all tensor sizes can use 32 bit indexing. whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
...@@ -34,7 +34,6 @@ MTPModelTypes = Literal[ ...@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
"mimo_mtp", "mimo_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"glm4_moe_lite_mtp", "glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp", "ernie_mtp",
"exaone_moe_mtp", "exaone_moe_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
...@@ -223,17 +222,6 @@ class SpeculativeConfig: ...@@ -223,17 +222,6 @@ class SpeculativeConfig:
} }
) )
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe": if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp" hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp": if hf_config.model_type == "ernie_mtp":
......
...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer return buffer
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group): def __init__(self, cpu_group):
super().__init__(cpu_group) super().__init__(cpu_group)
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1] return gathered_tensors[0], gathered_tensors[1]
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]
if extra_tensors is None:
return hidden_states, topk_weights, topk_ids
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
) )
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -693,4 +599,4 @@ class MoriAll2AllManager(All2AllManagerBase): ...@@ -693,4 +599,4 @@ class MoriAll2AllManager(All2AllManagerBase):
handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create( handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
mori_kwargs, self._make_handle mori_kwargs, self._make_handle
) )
return handle return handle
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import torch import torch
...@@ -63,32 +64,13 @@ class All2AllManagerBase: ...@@ -63,32 +64,13 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> ( ) -> Any:
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either: # Subclasses should either:
# - implement handling for extra_tensors, or # - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported. # - raise a clear error if extra_tensors is not supported.
...@@ -298,7 +280,7 @@ class DeviceCommunicatorBase: ...@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
for module in moe_modules: for module in moe_modules:
module.maybe_init_modular_kernel() module.maybe_init_modular_kernel()
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -312,29 +294,8 @@ class DeviceCommunicatorBase: ...@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -342,4 +303,4 @@ class DeviceCommunicatorBase: ...@@ -342,4 +303,4 @@ class DeviceCommunicatorBase:
Combine the hidden states and router logits from the appropriate device. Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
return hidden_states return hidden_states
\ No newline at end of file
...@@ -130,65 +130,29 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -130,65 +130,29 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch_router_logits( def dispatch( # type: ignore[override]
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> ( ) -> tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
topk_weights, router_logits,
topk_ids,
is_sequence_parallel, is_sequence_parallel,
extra_tensors=extra_tensors, extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, hidden_states, is_sequence_parallel
is_sequence_parallel,
) )
return hidden_states
class _CPUSHMDistributed: class _CPUSHMDistributed:
...@@ -286,4 +250,4 @@ class _CPUSHMDistributed: ...@@ -286,4 +250,4 @@ class _CPUSHMDistributed:
tensor_dict: dict[str, torch.Tensor] = {} tensor_dict: dict[str, torch.Tensor] = {}
for key, size, t in zip(key_list, size_list, value_list): for key, size, t in zip(key_list, size_list, value_list):
tensor_dict[key] = t.view(size) tensor_dict[key] = t.view(size)
return tensor_dict return tensor_dict
\ No newline at end of file
...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch_router_logits( def dispatch( # type: ignore[override]
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -332,52 +332,19 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -332,52 +332,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor] tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
topk_weights, router_logits,
topk_ids,
is_sequence_parallel, is_sequence_parallel,
extra_tensors=extra_tensors, extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, hidden_states, is_sequence_parallel
is_sequence_parallel,
) )
return hidden_states
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend from flashinfer.comm.mnnvl import CommBackend as CommBackend
...@@ -25,14 +23,5 @@ class CustomCommunicator(CommBackend): ...@@ -25,14 +23,5 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group) dist.all_gather_object(gathered, data, group=self._group)
return gathered return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
def barrier(self) -> None:
raise NotImplementedError
def Split(self, color: int, key: int) -> "CustomCommunicator": def Split(self, color: int, key: int) -> "CustomCommunicator":
return self return self
\ No newline at end of file
...@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, router_logits: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> ( ) -> tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
topk_weights, router_logits,
topk_ids,
is_sequence_parallel, is_sequence_parallel,
extra_tensors=extra_tensors, extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, hidden_states, is_sequence_parallel
is_sequence_parallel,
) )
return hidden_states
...@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1): class NixlConnector(KVConnectorBase_V1):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
......
...@@ -1003,7 +1003,7 @@ class GroupCoordinator: ...@@ -1003,7 +1003,7 @@ class GroupCoordinator:
if self.device_communicator is not None: if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model) self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1014,7 +1014,7 @@ class GroupCoordinator: ...@@ -1014,7 +1014,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits( return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
...@@ -1023,28 +1023,6 @@ class GroupCoordinator: ...@@ -1023,28 +1023,6 @@ class GroupCoordinator:
else: else:
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states, is_sequence_parallel: bool = False self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -346,16 +346,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: ...@@ -346,16 +346,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
cached version. cached version.
""" """
return copy.deepcopy(_compute_kwargs(cls)) return copy.deepcopy(_compute_kwargs(cls))
class EnvironmentConfigError(Exception):
pass
# def check_incompatible_config(env1: bool, env2: bool):
# if env1 is True and env2 is True:
# _s = "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!\n\n"
# raise EnvironmentConfigError(_s)
@dataclass @dataclass
class EngineArgs: class EngineArgs:
...@@ -1038,7 +1029,6 @@ class EngineArgs: ...@@ -1038,7 +1029,6 @@ class EngineArgs:
) )
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
# Observability arguments # Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig) observability_kwargs = get_kwargs(ObservabilityConfig)
observability_group = parser.add_argument_group( observability_group = parser.add_argument_group(
...@@ -1646,8 +1636,6 @@ class EngineArgs: ...@@ -1646,8 +1636,6 @@ class EngineArgs:
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
) )
# check_incompatible_config(envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, envs.USE_FUSED_RMS_QUANT)
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
runner_type=model_config.runner_type, runner_type=model_config.runner_type,
...@@ -1789,7 +1777,6 @@ class EngineArgs: ...@@ -1789,7 +1777,6 @@ class EngineArgs:
return config return config
def _check_feature_supported(self, model_config: ModelConfig): def _check_feature_supported(self, model_config: ModelConfig):
"""Raise an error if the feature is not supported.""" """Raise an error if the feature is not supported."""
if self.logits_processor_pattern != EngineArgs.logits_processor_pattern: if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
......
...@@ -78,7 +78,6 @@ from vllm.v1.engine import EngineCoreRequest ...@@ -78,7 +78,6 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.metrics.reader import Metric from vllm.v1.metrics.reader import Metric
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import hashlib import hashlib
import importlib import importlib
import inspect import inspect
...@@ -265,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None: ...@@ -265,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return None return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
return log_config
# If endpoints to filter are specified, create a config with the filter
if args.disable_access_log_for_endpoints:
from vllm.logging_utils import create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths = [
p.strip()
for p in args.disable_access_log_for_endpoints.split(",")
if p.strip()
]
return create_uvicorn_log_config(
excluded_paths=excluded_paths,
log_level=args.uvicorn_log_level,
)
return None
class AuthenticationMiddleware: class AuthenticationMiddleware:
""" """
Pure ASGI middleware that authenticates each request by checking Pure ASGI middleware that authenticates each request by checking
...@@ -964,8 +930,8 @@ async def run_server_worker( ...@@ -964,8 +930,8 @@ async def run_server_worker(
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
# Get uvicorn log config (from file or with endpoint filter) # Load logging config for uvicorn if specified
log_config = get_uvicorn_log_config(args) log_config = load_log_config(args.log_config_file)
if log_config is not None: if log_config is not None:
uvicorn_kwargs["log_config"] = log_config uvicorn_kwargs["log_config"] = log_config
...@@ -1022,4 +988,4 @@ if __name__ == "__main__": ...@@ -1022,4 +988,4 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
validate_parsed_serve_args(args) validate_parsed_serve_args(args)
uvloop.run(run_server(args)) uvloop.run(run_server(args))
\ No newline at end of file
...@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ErrorResponse, ErrorResponse,
FunctionCall,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
ToolCall, ToolCall,
...@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
if self.model_config.hf_config.model_type == "kimi_k2":
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params: if "stop_token_ids" not in self.default_sampling_params:
...@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing # NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the # for some models, currently vLLM doesn't support it. Please use the
# Responses API instead. # Responses API instead.
...@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type] maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request) # type: ignore[arg-type] truncate_tool_call_ids(request)
validate_request_params(request) validate_request_params(request)
# Check if tool parsing is unavailable (common condition) # Check if tool parsing is unavailable (common condition)
...@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing): ...@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
sub_request_id = ( sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
) )
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
request=request, request=request,
prompt=engine_prompt, input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params, default_sampling_params=self.default_sampling_params,
) )
...@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
...@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
...@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
) )
reasoning_parser = self.reasoning_parser( reasoning_parser = self.reasoning_parser(
tokenizer, tokenizer,
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg] chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
) )
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in reasoning parser creation.") logger.exception("Error in reasoning parser creation.")
...@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
index=i, index=i,
) )
else: else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_choice_function_name,
idx=history_tool_call_cnt,
)
delta_tool_call = DeltaToolCall( delta_tool_call = DeltaToolCall(
id=tool_call_id, id=make_tool_call_id(),
type="function", type="function",
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=tool_choice_function_name, name=tool_choice_function_name,
...@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput | None = None final_res: RequestOutput | None = None
...@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = ( tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
) )
if self.use_harmony: if (not self.enable_auto_tools or not self.tool_parser) and (
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_calls if tool_calls else [],
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required" and request.tool_choice != "required"
): ):
message = ChatMessage(role=role, reasoning=reasoning, content=content) message = ChatMessage(role=role, reasoning=reasoning, content=content)
# if the request uses tools and specified a tool choice
elif ( elif (
request.tool_choice request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
): ):
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
tool_call_class_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_class_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content="", content="",
tool_calls=tool_call_class_items, tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
) )
elif request.tool_choice and request.tool_choice == "required": elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = [] tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
for idx, tool_call in enumerate(tool_calls): for tool_call in tool_calls:
# Use native ID if available, tool_call_class_items.append(
# otherwise generate ID with correct id_type tool_call_class(
if tool_call.id: id=make_tool_call_id(
tool_call_class_items.append(
tool_call_class(id=tool_call.id, function=tool_call)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(
tool_call_class(function=tool_call)
)
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type, id_type=self.tool_call_id_type,
func_name=tool_call.name, func_name=tool_call.name,
idx=history_tool_call_cnt + idx, idx=history_tool_call_cnt,
) ),
tool_call_class_items.append( function=tool_call,
tool_call_class(id=generated_id, function=tool_call) )
) )
history_tool_call_cnt += 1 history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
...@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls # call. The same is not true for named function calls
auto_tools_called = tool_calls is not None and len(tool_calls) > 0 auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls: if tool_calls:
tool_call_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content=content, content=content,
tool_calls=tool_call_items, tool_calls=[
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
) )
else: else:
...@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls: elif choice.message.tool_calls:
# For tool calls, log the function name and arguments # For tool calls, log the function name and arguments
tool_call_descriptions = [] tool_call_descriptions = []
for tc in choice.message.tool_calls: # type: ignore for tc in choice.message.tool_calls:
function_call: FunctionCall = tc.function # type: ignore if hasattr(tc.function, "name") and hasattr(
tool_call_descriptions.append( tc.function, "arguments"
f"{function_call.name}({function_call.arguments})" ):
) tool_call_descriptions.append(
f"{tc.function.name}({tc.function.arguments})"
)
tool_calls_str = ", ".join(tool_call_descriptions) tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]" output_text = f"[tool_calls: {tool_calls_str}]"
...@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type] maybe_serialize_tool_calls(request)
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
...@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message. # Add developer message.
if request.tools: if request.tools:
dev_msg = get_developer_message( dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None # type: ignore[arg-type] tools=request.tools if should_include_tools else None
) )
messages.append(dev_msg) messages.append(dev_msg)
...@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing): ...@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
if request.cache_salt is not None: if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt engine_prompt["cache_salt"] = request.cache_salt
return messages, [engine_prompt] return messages, [engine_prompt]
\ No newline at end of file
...@@ -85,12 +85,6 @@ class FrontendArgs: ...@@ -85,12 +85,6 @@ class FrontendArgs:
"""Log level for uvicorn.""" """Log level for uvicorn."""
disable_uvicorn_access_log: bool = False disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log.""" """Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False allow_credentials: bool = False
"""Allow credentials.""" """Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"]) allowed_origins: list[str] = field(default_factory=lambda: ["*"])
...@@ -250,11 +244,6 @@ class FrontendArgs: ...@@ -250,11 +244,6 @@ class FrontendArgs:
del frontend_kwargs["middleware"]["nargs"] del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = [] frontend_kwargs["middleware"]["default"] = []
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
# Special case: Tool call parser shows built-in options. # Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered()) valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers) parsers_str = ",".join(valid_tool_parsers)
...@@ -332,4 +321,4 @@ def create_parser_for_docs() -> FlexibleArgumentParser: ...@@ -332,4 +321,4 @@ def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser( parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server" prog="-m vllm.entrypoints.openai.api_server"
) )
return make_arg_parser(parser_for_docs) return make_arg_parser(parser_for_docs)
\ No newline at end of file
...@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig ...@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -163,12 +162,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -163,12 +162,25 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
request=request, request=request,
prompt=engine_prompt, input_length=input_length,
default_sampling_params=self.default_sampling_params, default_sampling_params=self.default_sampling_params,
) )
...@@ -731,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -731,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids), needs_detokenization=bool(request.echo and not request.return_token_ids),
) )
\ No newline at end of file
...@@ -218,10 +218,6 @@ def get_logits_processors( ...@@ -218,10 +218,6 @@ def get_logits_processors(
class FunctionCall(OpenAIBaseModel): class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str name: str
arguments: str arguments: str
...@@ -319,4 +315,4 @@ class GenerateRequest(BaseModel): ...@@ -319,4 +315,4 @@ class GenerateRequest(BaseModel):
kv_transfer_params: dict[str, Any] | None = Field( kv_transfer_params: dict[str, Any] | None = Field(
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.", description="KVTransfer parameters used for disaggregated serving.",
) )
\ No newline at end of file
...@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import ( ...@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
) )
from vllm.entrypoints.utils import ( from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import ( from vllm.inputs.parse import (
PromptComponents,
get_prompt_components, get_prompt_components,
is_explicit_encoder_decoder_prompt, is_explicit_encoder_decoder_prompt,
) )
...@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = ( ...@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
CompletionResponse CompletionResponse
| ChatCompletionResponse | ChatCompletionResponse
| EmbeddingResponse | EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse | PoolingResponse
...@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest) ...@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True) @dataclass(kw_only=True)
class ServeContext(Generic[RequestT]): class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
request: RequestT request: RequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) @dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
class OpenAIServing: class OpenAIServing:
...@@ -578,7 +605,10 @@ class OpenAIServing: ...@@ -578,7 +605,10 @@ class OpenAIServing:
self, self,
ctx: ServeContext, ctx: ServeContext,
) -> AnyResponse | ErrorResponse: ) -> AnyResponse | ErrorResponse:
async for response in self._pipeline(ctx): generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
generation = self._pipeline(ctx)
async for response in generation:
return response return response
return self.create_error_response("No response yielded from pipeline") return self.create_error_response("No response yielded from pipeline")
...@@ -637,7 +667,9 @@ class OpenAIServing: ...@@ -637,7 +667,9 @@ class OpenAIServing:
ctx: ServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Schedule the request and get the result generator.""" """Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -691,7 +723,7 @@ class OpenAIServing: ...@@ -691,7 +723,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None] final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
if ctx.result_generator is None: if ctx.result_generator is None:
...@@ -949,7 +981,6 @@ class OpenAIServing: ...@@ -949,7 +981,6 @@ class OpenAIServing:
max_length=truncate_prompt_tokens, max_length=truncate_prompt_tokens,
) )
input_ids = encoded.input_ids input_ids = encoded.input_ids
input_text = prompt input_text = prompt
...@@ -973,14 +1004,14 @@ class OpenAIServing: ...@@ -973,14 +1004,14 @@ class OpenAIServing:
if tokenizer is None: if tokenizer is None:
input_text = "" input_text = ""
else: else:
async_tokenizer = self._get_async_tokenizer(tokenizer) async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids) input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text) return self._validate_input(request, input_ids, input_text)
def _validate_input( def _validate_input(
self, self,
request: object, request: AnyRequest,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -1291,7 +1322,7 @@ class OpenAIServing: ...@@ -1291,7 +1322,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
**kwargs, **kwargs,
): ):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1342,12 +1373,10 @@ class OpenAIServing: ...@@ -1342,12 +1373,10 @@ class OpenAIServing:
# yield context # yield context
# Create inputs for the next turn. # Create inputs for the next turn.
# Render the next prompt token ids and update sampling_params. # Render the next prompt token ids.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion() prompt_token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids) engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids)
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn( engine_prompts = await self._render_next_turn(
context.request, context.request,
...@@ -1359,19 +1388,19 @@ class OpenAIServing: ...@@ -1359,19 +1388,19 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self.default_sampling_params, # type: ignore
)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
sub_request += 1 sub_request += 1
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
...@@ -1382,7 +1411,7 @@ class OpenAIServing: ...@@ -1382,7 +1411,7 @@ class OpenAIServing:
if self.request_logger is None: if self.request_logger is None:
return return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs) prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
self.request_logger.log_inputs( self.request_logger.log_inputs(
request_id, request_id,
...@@ -1496,7 +1525,6 @@ class OpenAIServing: ...@@ -1496,7 +1525,6 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls. # extract_tool_calls() returns a list of tool calls.
function_calls.extend( function_calls.extend(
FunctionCall( FunctionCall(
id=tool_call.id,
name=tool_call.function.name, name=tool_call.function.name,
arguments=tool_call.function.arguments, arguments=tool_call.function.arguments,
) )
...@@ -1549,4 +1577,4 @@ def clamp_prompt_logprobs( ...@@ -1549,4 +1577,4 @@ def clamp_prompt_logprobs(
for logprob_values in logprob_dict.values(): for logprob_values in logprob_dict.values():
if logprob_values.logprob == float("-inf"): if logprob_values.logprob == float("-inf"):
logprob_values.logprob = -9999.0 logprob_values.logprob = -9999.0
return prompt_logprobs return prompt_logprobs
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment