Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
self._lmcache_engine.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""
Start saving the a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
**kwargs)
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
self._lmcache_engine.wait_for_save()
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
self._lmcache_engine.update_state_after_alloc(request,
num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)
# SPDX-License-Identifier: Apache-2.0
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
@staticmethod
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
is_store: bool) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
transfer_config = vllm_config.kv_transfer_config
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(vllm_config.kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning(
"In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info("Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping))
for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache_layer = attn_layer.kv_cache[\
forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids)
kv_cache = safetensors.torch.load_file(
filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids)
kv_cache = extract_kv_from_layer(kv_layer,
request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_size=self._block_size,
is_store=False)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_size=self._block_size,
is_store=True)
for cached_req in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption:
break
if cached_req.req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[cached_req.req_id]
total_tokens = (len(cached_req.new_token_ids) +
cached_req.num_computed_tokens)
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request.
"""
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(torch.tensor(
request.prompt_token_ids)[:num_tokens_to_check],
create_folder=False)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
input_ids: torch.Tensor,
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
input_ids_bytes = input_ids.numpy().tobytes()
input_ids_hash = hashlib.md5(input_ids_bytes,
usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
input_ids: torch.Tensor,
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(input_ids,
create_folder=True)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size) -> int:
"""Align the number of tokens to the block size.
"""
return (num_tokens - 1) // block_size * block_size
...@@ -46,7 +46,7 @@ class KVTransferAgent: ...@@ -46,7 +46,7 @@ class KVTransferAgent:
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set." "TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector( self.connector = KVConnectorFactory.create_connector_v0(
rank, local_rank, config) rank, local_rank, config)
def send_kv_caches_and_hidden_states( def send_kv_caches_and_hidden_states(
......
...@@ -70,7 +70,7 @@ class MooncakeStore(KVStoreBufferBase): ...@@ -70,7 +70,7 @@ class MooncakeStore(KVStoreBufferBase):
): ):
try: try:
from mooncake_vllm_adaptor import MooncakeDistributedStore from mooncake.store import MooncakeDistributedStore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please install mooncake by following the instructions at " "Please install mooncake by following the instructions at "
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import json import json
import os import os
import struct
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
...@@ -57,14 +58,14 @@ class MooncakeTransferEngine: ...@@ -57,14 +58,14 @@ class MooncakeTransferEngine:
def __init__(self, kv_rank: int, local_rank: int): def __init__(self, kv_rank: int, local_rank: int):
try: try:
import mooncake_vllm_adaptor as mva from mooncake.engine import TransferEngine
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please install mooncake by following the instructions at " "Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e "to run vLLM with MooncakeConnector.") from e
self.engine = mva.mooncake_vllm_adaptor() self.engine = TransferEngine()
self.local_rank = local_rank self.local_rank = local_rank
try: try:
...@@ -115,14 +116,14 @@ class MooncakeTransferEngine: ...@@ -115,14 +116,14 @@ class MooncakeTransferEngine:
p_rank_offset = int(p_port) + 8 + self.local_rank * 2 p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2 d_rank_offset = int(d_port) + 8 + self.local_rank * 2
if kv_rank == 0: if kv_rank == 0:
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
else: else:
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
def initialize(self, local_hostname: str, metadata_server: str, def initialize(self, local_hostname: str, metadata_server: str,
...@@ -140,12 +141,12 @@ class MooncakeTransferEngine: ...@@ -140,12 +141,12 @@ class MooncakeTransferEngine:
"Mooncake Configuration error. `metadata_backend`" "Mooncake Configuration error. `metadata_backend`"
f" should be one of {supported_backend}.") f" should be one of {supported_backend}.")
self.engine.initializeExt(local_hostname, metadata_server, self.engine.initialize_ext(local_hostname, metadata_server,
protocol, device_name, metadata_backend) protocol, device_name, metadata_backend)
def allocate_managed_buffer(self, length: int) -> int: def allocate_managed_buffer(self, length: int) -> int:
"""Allocate a managed buffer of the specified length.""" """Allocate a managed buffer of the specified length."""
ret = self.engine.allocateManagedBuffer(length) ret = self.engine.allocate_managed_buffer(length)
if ret <= 0: if ret <= 0:
logger.error("Allocation Return Error") logger.error("Allocation Return Error")
raise Exception("Allocation Return Error") raise Exception("Allocation Return Error")
...@@ -153,13 +154,13 @@ class MooncakeTransferEngine: ...@@ -153,13 +154,13 @@ class MooncakeTransferEngine:
def free_managed_buffer(self, buffer: int, length: int) -> int: def free_managed_buffer(self, buffer: int, length: int) -> int:
"""Free a previously allocated managed buffer.""" """Free a previously allocated managed buffer."""
return self.engine.freeManagedBuffer(buffer, length) return self.engine.free_managed_buffer(buffer, length)
def transfer_sync(self, buffer: int, peer_buffer_address: int, def transfer_sync(self, buffer: int, peer_buffer_address: int,
length: int) -> int: length: int) -> int:
"""Synchronously transfer data to the specified address.""" """Synchronously transfer data to the specified address."""
ret = self.engine.transferSync(self.remote_url, buffer, ret = self.engine.transfer_sync_read(self.remote_url, buffer,
peer_buffer_address, length) peer_buffer_address, length)
if ret < 0: if ret < 0:
logger.error("Transfer Return Error") logger.error("Transfer Return Error")
raise Exception("Transfer Return Error") raise Exception("Transfer Return Error")
...@@ -168,15 +169,15 @@ class MooncakeTransferEngine: ...@@ -168,15 +169,15 @@ class MooncakeTransferEngine:
def write_bytes_to_buffer(self, buffer: int, user_data: bytes, def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
length: int) -> int: length: int) -> int:
"""Write bytes to the allocated buffer.""" """Write bytes to the allocated buffer."""
return self.engine.writeBytesToBuffer(buffer, user_data, length) return self.engine.write_bytes_to_buffer(buffer, user_data, length)
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
"""Read bytes from the allocated buffer.""" """Read bytes from the allocated buffer."""
return self.engine.readBytesFromBuffer(buffer, length) return self.engine.read_bytes_from_buffer(buffer, length)
def wait_for_ack(self, src_ptr: int, length: int) -> None: def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver.""" """Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv_pyobj() ack = self.sender_ack.recv()
if ack != b'ACK': if ack != b'ACK':
logger.error("Failed to receive ACK from the receiver") logger.error("Failed to receive ACK from the receiver")
...@@ -187,18 +188,22 @@ class MooncakeTransferEngine: ...@@ -187,18 +188,22 @@ class MooncakeTransferEngine:
length = len(user_data) length = len(user_data)
src_ptr = self.allocate_managed_buffer(length) src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length) self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_pyobj((src_ptr, length)) self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr),
struct.pack("!Q", length)])
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes: def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process.""" """Receive bytes from the remote process."""
src_ptr, length = self.receiver_socket.recv_pyobj() data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length) dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length) self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length) ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup # Buffer cleanup
self.receiver_ack.send_pyobj(b'ACK') self.receiver_ack.send(b'ACK')
self.free_managed_buffer(dst_ptr, length) self.free_managed_buffer(dst_ptr, length)
return ret return ret
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
if TYPE_CHECKING:
from vllm.config import VllmConfig
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
assert _KV_CONNECTOR_AGENT is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool:
return _KV_CONNECTOR_AGENT is not None
def is_v1_kv_transfer_group(
connector: Optional[KVConnectorBaseType] = None) -> bool:
"""Check if the KV connector is the v1 connector.
If the argument is None, it will check the global KV connector
Args:
connector: The KV connector to check. If None, it will check the
global KV connector.
Note:
This function will no-longer be needed after the v1 KV connector
becomes the default.
"""
if connector is None:
connector = _KV_CONNECTOR_AGENT
if connector is None:
return False
return isinstance(connector, KVConnectorBase_V1)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_CONNECTOR_AGENT
if vllm_config.kv_transfer_config is None:
return
if (vllm_config.kv_transfer_config.is_kv_transfer_instance
and _KV_CONNECTOR_AGENT is None):
if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
config=vllm_config, role=KVConnectorRole.WORKER)
else:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config,
)
...@@ -29,15 +29,13 @@ from collections import namedtuple ...@@ -29,15 +29,13 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, from typing import Any, Callable, Dict, List, Optional, Tuple, Union
Union)
from unittest.mock import patch from unittest.mock import patch
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import ( from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase) DeviceCommunicatorBase)
...@@ -46,9 +44,6 @@ from vllm.logger import init_logger ...@@ -46,9 +44,6 @@ from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op) supports_custom_op)
if TYPE_CHECKING:
from vllm.config import VllmConfig
@dataclass @dataclass
class GraphCaptureContext: class GraphCaptureContext:
...@@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: ...@@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor) return torch.empty_like(tensor)
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.reduce_scatter(tensor, dim)
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] // world_size
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.all_gather(tensor, dim)
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] * world_size
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)
if supports_custom_op(): if supports_custom_op():
from vllm.platforms import current_platform from vllm.platforms import current_platform
direct_register_custom_op( direct_register_custom_op(
...@@ -128,6 +155,20 @@ if supports_custom_op(): ...@@ -128,6 +155,20 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
)
direct_register_custom_op(
op_name="all_gather",
op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake,
)
class GroupCoordinator: class GroupCoordinator:
""" """
...@@ -327,6 +368,18 @@ class GroupCoordinator: ...@@ -327,6 +368,18 @@ class GroupCoordinator:
return self.device_communicator.all_gather(input_, dim) return self.device_communicator.all_gather(input_, dim)
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
return self.device_communicator.reduce_scatter(input_, dim)
def gather(self, def gather(self,
input_: torch.Tensor, input_: torch.Tensor,
dst: int = 0, dst: int = 0,
...@@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator: ...@@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator:
# kept for backward compatibility # kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group get_pipeline_model_parallel_group = get_pp_group
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
assert _KV_TRANSFER is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_TRANSFER
@contextmanager @contextmanager
def graph_capture(device: torch.device): def graph_capture(device: torch.device):
...@@ -962,26 +1007,6 @@ def initialize_model_parallel( ...@@ -962,26 +1007,6 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_TRANSFER
if vllm_config.kv_transfer_config is None:
return
if all([
vllm_config.kv_transfer_config.is_kv_transfer_instance,
_KV_TRANSFER is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import dataclasses import dataclasses
import datetime import datetime
import pickle import pickle
import socket
import time import time
from collections import deque from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple from typing import Any, Deque, Dict, Optional, Sequence, Tuple
...@@ -123,6 +124,10 @@ class StatelessProcessGroup: ...@@ -123,6 +124,10 @@ class StatelessProcessGroup:
rank: int rank: int
world_size: int world_size: int
store: torch._C._distributed_c10d.Store store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: Optional[socket.socket]
data_expiration_seconds: int = 3600 # 1 hour data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter # dst rank -> counter
...@@ -234,18 +239,33 @@ class StatelessProcessGroup: ...@@ -234,18 +239,33 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B, can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group. C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa """ # noqa
launch_server = rank == 0
if launch_server:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port))
listen_socket.listen()
listen_fd = listen_socket.fileno()
else:
listen_socket = None
listen_fd = None
store = TCPStore( store = TCPStore(
host_name=host, host_name=host,
port=port, port=port,
world_size=world_size, world_size=world_size,
is_master=(rank == 0), is_master=launch_server,
timeout=datetime.timedelta(seconds=store_timeout), timeout=datetime.timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
) )
return StatelessProcessGroup( return StatelessProcessGroup(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
store=store, store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds) data_expiration_seconds=data_expiration_seconds)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# yapf: disable
import argparse import argparse
import dataclasses import dataclasses
import json import json
import re import re
import threading import threading
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
Tuple, Type, Union, cast, get_args, get_origin) TypeVar, Union, cast, get_args, get_origin)
import torch import torch
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm import version from vllm import version
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DecodingConfig, DeviceConfig, HfOverrides, ConfigFormat, ConfigType, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackendV1, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, ObservabilityConfig, ModelConfig, ModelImpl, MultiModalConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig, ObservabilityConfig, ParallelConfig, PoolerConfig,
SchedulerConfig, SpeculativeConfig, TaskOption, PrefixCachingHashAlgo, PromptAdapterConfig,
TokenizerPoolConfig, VllmConfig, get_attr_docs) SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...@@ -28,33 +34,42 @@ from vllm.reasoning import ReasoningParserManager ...@@ -28,33 +34,42 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
if TYPE_CHECKING: # yapf: enable
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
logger = init_logger(__name__) logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
DEVICE_OPTIONS = [ # object is used to allow for special typing forms
"auto", T = TypeVar("T")
"cuda", TypeHint = Union[type[Any], object]
"neuron", TypeHintT = Union[type[T], object]
"cpu",
"tpu",
"xpu",
"hpu",
]
def nullable_str(val: str): def optional_type(
if not val or val == "None": return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
return None
return val def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _optional_type
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: @deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
"""Parses a string containing comma separate key [str] to value [int] """Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary. pairs into a dictionary.
...@@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: ...@@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
Returns: Returns:
Dictionary with parsed values. Dictionary with parsed values.
""" """
if len(val) == 0: out_dict: dict[str, int] = {}
return None
out_dict: Dict[str, int] = {}
for item in val.split(","): for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")] kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2: if len(kv_parts) != 2:
...@@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: ...@@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
return out_dict return out_dict
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
"""Check if the type hint is a specific type."""
return type_hint is type or get_origin(type_hint) is type
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
"""Check if the type hints contain a specific type."""
return any(is_type(type_hint, type) for type_hint in type_hints)
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
"""Get the specific type from the type hints."""
return next((th for th in type_hints if is_type(th, type)), None)
def is_not_builtin(type_hint: TypeHint) -> bool:
"""Check if the class is not a built-in type."""
return type_hint.__module__ != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
# Get the help text for the field
name = field.name
help = cls_docs[name]
# Escape % for argparse
help = help.replace("%", "%%")
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
if contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
# Creates choices from Literal arguments
type_hint = get_type(type_hints, Literal)
choices = sorted(get_args(type_hint))
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}")
kwargs[name]["type"] = choice_type
elif contains_type(type_hints, tuple):
type_hint = get_type(type_hints, tuple)
types = get_args(type_hint)
tuple_type = types[0]
assert all(t is tuple_type for t in types if t is not Ellipsis), (
"All non-Ellipsis tuple elements must be of the same "
f"type. Got {types}.")
kwargs[name]["type"] = tuple_type
kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
elif contains_type(type_hints, list):
type_hint = get_type(type_hints, list)
types = get_args(type_hint)
assert len(types) == 1, (
"List type must have exactly one type. Got "
f"{type_hint} with types {types}")
kwargs[name]["type"] = types[0]
kwargs[name]["nargs"] = "+"
elif contains_type(type_hints, int):
kwargs[name]["type"] = int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = str
else:
raise ValueError(
f"Unsupported type {type_hints} for argument {name}.")
# If None is in type_hints, make the argument optional.
# But not if it's a bool, argparse will handle this better.
if type(None) in type_hints and not contains_type(type_hints, bool):
kwargs[name]["type"] = optional_type(kwargs[name]["type"])
if kwargs[name].get("choices"):
kwargs[name]["choices"].append("None")
return kwargs
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
...@@ -105,14 +216,15 @@ class EngineArgs: ...@@ -105,14 +216,15 @@ class EngineArgs:
load_format: str = LoadConfig.load_format load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto' kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: Optional[int] = None seed: Optional[int] = None
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
# Note: Specifying a custom executor backend by passing a class # Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without # is intended for expert use only. The API may change without
# notice. # notice.
distributed_executor_backend: Optional[Union[ distributed_executor_backend: Optional[Union[
str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend DistributedExecutorBackend,
Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers # number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
...@@ -120,20 +232,23 @@ class EngineArgs: ...@@ -120,20 +232,23 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[ max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[int] = None block_size: Optional[BlockSize] = CacheConfig.block_size
enable_prefix_caching: Optional[bool] = None enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: str = "builtin" prefix_caching_hash_algo: PrefixCachingHashAlgo = \
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = False disable_sliding_window: bool = False
disable_cascade_attn: bool = False disable_cascade_attn: bool = False
use_v2_block_manager: bool = True use_v2_block_manager: bool = True
swap_space: float = 4 # GiB swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = 0 # GiB cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[
max_num_partial_prefills: Optional[int] = 1 int] = SchedulerConfig.max_num_batched_tokens
max_long_partial_prefills: Optional[int] = 1 max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
long_prefill_token_threshold: Optional[int] = 0 max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
max_num_seqs: Optional[int] = None long_prefill_token_threshold: int = \
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
...@@ -147,44 +262,52 @@ class EngineArgs: ...@@ -147,44 +262,52 @@ class EngineArgs:
enforce_eager: Optional[bool] = None enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192 max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
tokenizer_pool_size: int = 0 # The following three fields are deprecated and will be removed in a future
# Note: Specifying a tokenizer pool by passing a class # release. Setting them will have no effect. Please remove them from your
# is intended for expert use only. The API may change without # configurations.
# notice. tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None tokenizer_pool_extra_config: dict = \
limit_mm_per_prompt: Optional[Mapping[str, int]] = None get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False disable_mm_preprocessor_cache: bool = False
# LoRA fields
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = 1 max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = 16 max_lora_rank: int = LoRAConfig.max_lora_rank
enable_prompt_adapter: bool = False fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_prompt_adapters: int = 1 max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
merge_lora: bool = False merge_lora: bool = False
lora_target_modules: Optional[List[str]] = None lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules
device: str = 'auto' lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
num_scheduler_steps: int = 1 lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
multi_step_stream_outputs: bool = True long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[
num_lookahead_slots: int = 0 int] = CacheConfig.num_gpu_blocks_override
model_loader_extra_config: Optional[ num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
dict] = LoadConfig.model_loader_extra_config model_loader_extra_config: dict = \
get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: Optional[Union[str, ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = None preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
scheduler_delay_factor: float = 0.0 scheduler_delay_factor: float = SchedulerConfig.delay_factor
enable_chunked_prefill: Optional[bool] = None enable_chunked_prefill: Optional[
disable_chunked_mm_input: bool = False bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
guided_decoding_backend: str = DecodingConfig.guided_decoding_backend guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
...@@ -197,8 +320,8 @@ class EngineArgs: ...@@ -197,8 +320,8 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs" scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None override_pooler_config: Optional[PoolerConfig] = None
...@@ -213,11 +336,11 @@ class EngineArgs: ...@@ -213,11 +336,11 @@ class EngineArgs:
enable_sleep_mode: bool = False enable_sleep_mode: bool = False
model_impl: str = "auto" model_impl: str = "auto"
calculate_kv_scales: Optional[bool] = None calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
additional_config: Optional[Dict[str, Any]] = None additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
...@@ -240,38 +363,6 @@ class EngineArgs: ...@@ -240,38 +363,6 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
"""Check if the class is a type in a union type."""
return get_origin(cls) is Union and type in get_args(cls)
def is_optional(cls: type[Any]) -> bool:
"""Check if the class is an optional type."""
return is_type_in_union(cls, type(None))
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
name = field.name
# One of these will always be present
default = (field.default_factory
if field.default is MISSING else field.default)
kwargs[name] = {"default": default, "help": cls_docs[name]}
# When using action="store_true"
# add_argument doesn't accept type
if field.type is bool:
continue
# Handle optional fields
if is_optional(field.type):
kwargs[name]["type"] = nullable_str
continue
# Handle str in union fields
if is_type_in_union(field.type, str):
kwargs[name]["type"] = str
continue
kwargs[name]["type"] = field.type
return kwargs
# Model arguments # Model arguments
parser.add_argument( parser.add_argument(
'--model', '--model',
...@@ -289,13 +380,13 @@ class EngineArgs: ...@@ -289,13 +380,13 @@ class EngineArgs:
'which task to use.') 'which task to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=nullable_str, type=optional_type(str),
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. ' help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.') 'If unspecified, model name or path will be used.')
parser.add_argument( parser.add_argument(
"--hf-config-path", "--hf-config-path",
type=nullable_str, type=optional_type(str),
default=EngineArgs.hf_config_path, default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. ' help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.') 'If unspecified, model name or path will be used.')
...@@ -307,21 +398,21 @@ class EngineArgs: ...@@ -307,21 +398,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.') 'the input. The generated output will contain token ids.')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='The specific model version to use. It can be a branch ' help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument( parser.add_argument(
'--code-revision', '--code-revision',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='The specific revision to use for the model code on ' help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.') 'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='Revision of the huggingface tokenizer to use. ' help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. ' 'It can be a branch name, a tag name, or a commit id. '
...@@ -361,7 +452,6 @@ class EngineArgs: ...@@ -361,7 +452,6 @@ class EngineArgs:
load_group.add_argument('--model-loader-extra-config', load_group.add_argument('--model-loader-extra-config',
**load_kwargs["model_loader_extra_config"]) **load_kwargs["model_loader_extra_config"])
load_group.add_argument('--use-tqdm-on-load', load_group.add_argument('--use-tqdm-on-load',
action=argparse.BooleanOptionalAction,
**load_kwargs["use_tqdm_on_load"]) **load_kwargs["use_tqdm_on_load"])
parser.add_argument( parser.add_argument(
...@@ -386,14 +476,6 @@ class EngineArgs: ...@@ -386,14 +476,6 @@ class EngineArgs:
'* "bfloat16" for a balance between precision and range.\n' '* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n' '* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.') '* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=human_readable_int, type=human_readable_int,
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
...@@ -403,21 +485,25 @@ class EngineArgs: ...@@ -403,21 +485,25 @@ class EngineArgs:
'Examples:\n' 'Examples:\n'
'- 1k → 1000\n' '- 1k → 1000\n'
'- 1K → 1024\n') '- 1K → 1024\n')
parser.add_argument(
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
guided_decoding_group = parser.add_argument_group(
title="DecodingConfig",
description=DecodingConfig.__doc__,
)
guided_decoding_group.add_argument(
'--guided-decoding-backend', '--guided-decoding-backend',
type=str, **guided_decoding_kwargs["guided_decoding_backend"])
default=DecodingConfig.guided_decoding_backend, guided_decoding_group.add_argument(
help='Which engine will be used for guided decoding' "--reasoning-parser",
' (JSON schema / regex etc) by default. Currently support ' # This choices is a special case because it's not static
'https://github.com/mlc-ai/xgrammar and ' choices=list(ReasoningParserManager.reasoning_parsers),
'https://github.com/guidance-ai/llguidance.' **guided_decoding_kwargs["reasoning_backend"])
'Valid backend values are "xgrammar", "guidance", and "auto". '
'With "auto", we will make opinionated choices based on request '
'contents and what the backend libraries currently support, so '
'the behavior is subject to change in each release.')
parser.add_argument( parser.add_argument(
'--logits-processor-pattern', '--logits-processor-pattern',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='Optional regex pattern specifying valid logits processor ' help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` ' 'qualified names that can be passed with the `logits_processors` '
...@@ -443,7 +529,6 @@ class EngineArgs: ...@@ -443,7 +529,6 @@ class EngineArgs:
) )
parallel_group.add_argument( parallel_group.add_argument(
'--distributed-executor-backend', '--distributed-executor-backend',
choices=['ray', 'mp', 'uni', 'external_launcher'],
**parallel_kwargs["distributed_executor_backend"]) **parallel_kwargs["distributed_executor_backend"])
parallel_group.add_argument( parallel_group.add_argument(
'--pipeline-parallel-size', '-pp', '--pipeline-parallel-size', '-pp',
...@@ -454,46 +539,40 @@ class EngineArgs: ...@@ -454,46 +539,40 @@ class EngineArgs:
**parallel_kwargs["data_parallel_size"]) **parallel_kwargs["data_parallel_size"])
parallel_group.add_argument( parallel_group.add_argument(
'--enable-expert-parallel', '--enable-expert-parallel',
action='store_true',
**parallel_kwargs["enable_expert_parallel"]) **parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument( parallel_group.add_argument(
'--max-parallel-loading-workers', '--max-parallel-loading-workers',
**parallel_kwargs["max_parallel_loading_workers"]) **parallel_kwargs["max_parallel_loading_workers"])
parallel_group.add_argument( parallel_group.add_argument(
'--ray-workers-use-nsight', '--ray-workers-use-nsight',
action='store_true',
**parallel_kwargs["ray_workers_use_nsight"]) **parallel_kwargs["ray_workers_use_nsight"])
parallel_group.add_argument( parallel_group.add_argument(
'--disable-custom-all-reduce', '--disable-custom-all-reduce',
action='store_true',
**parallel_kwargs["disable_custom_all_reduce"]) **parallel_kwargs["disable_custom_all_reduce"])
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to ``--max-model-len``. On CUDA devices, '
'only block sizes up to 32 are supported. '
'On HPU devices, block size defaults to 128.')
parser.add_argument( # KV cache arguments
"--enable-prefix-caching", cache_kwargs = get_kwargs(CacheConfig)
action=argparse.BooleanOptionalAction, cache_group = parser.add_argument_group(
default=EngineArgs.enable_prefix_caching, title="CacheConfig",
help="Enables automatic prefix caching. " description=CacheConfig.__doc__,
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
)
parser.add_argument(
"--prefix-caching-hash-algo",
type=str,
choices=["builtin", "sha256"],
default=EngineArgs.prefix_caching_hash_algo,
help="Set the hash algorithm for prefix caching. "
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
"(collision resistant but with certain overheads).",
) )
cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
cache_group.add_argument('--gpu-memory-utilization',
**cache_kwargs["gpu_memory_utilization"])
cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
cache_group.add_argument('--kv-cache-dtype',
**cache_kwargs["cache_dtype"])
cache_group.add_argument('--num-gpu-blocks-override',
**cache_kwargs["num_gpu_blocks_override"])
cache_group.add_argument("--enable-prefix-caching",
**cache_kwargs["enable_prefix_caching"])
cache_group.add_argument("--prefix-caching-hash-algo",
**cache_kwargs["prefix_caching_hash_algo"])
cache_group.add_argument('--cpu-offload-gb',
**cache_kwargs["cpu_offload_gb"])
cache_group.add_argument('--calculate-kv-scales',
**cache_kwargs["calculate_kv_scales"])
parser.add_argument('--disable-sliding-window', parser.add_argument('--disable-sliding-window',
action='store_true', action='store_true',
help='Disables sliding window, ' help='Disables sliding window, '
...@@ -506,86 +585,11 @@ class EngineArgs: ...@@ -506,86 +585,11 @@ class EngineArgs:
'block manager v2) is now the default. ' 'block manager v2) is now the default. '
'Setting this flag to True or False' 'Setting this flag to True or False'
' has no effect on vLLM behavior.') ' has no effect on vLLM behavior.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
default=EngineArgs.seed, default=EngineArgs.seed,
help='Random seed for operations.') help='Random seed for operations.')
parser.add_argument('--swap-space',
type=float,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--cpu-offload-gb',
type=float,
default=0,
help='The space in GiB to offload to CPU, per GPU. '
'Default is 0, which means no offloading. Intuitively, '
'this argument can be seen as a virtual way to increase '
'the GPU memory size. For example, if you have one 24 GB '
'GPU and set this to 10, virtually you can think of it as '
'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
'which requires at least 26GB GPU memory. Note that this '
'requires fast CPU-GPU interconnect, as part of the model is '
'loaded from CPU memory to GPU memory on the fly in each '
'model forward pass.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a per-instance '
'limit, and only applies to the current vLLM instance.'
'It does not matter if you have another vLLM instance running '
'on the same GPU. For example, if you have two vLLM instances '
'running on the same GPU, you can set the GPU memory utilization '
'to 0.5 for each instance.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills.")
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument( parser.add_argument(
'--max-logprobs', '--max-logprobs',
type=int, type=int,
...@@ -598,7 +602,7 @@ class EngineArgs: ...@@ -598,7 +602,7 @@ class EngineArgs:
# Quantization settings. # Quantization settings.
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=nullable_str, type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None], choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
...@@ -649,162 +653,113 @@ class EngineArgs: ...@@ -649,162 +653,113 @@ class EngineArgs:
'Additionally for encoder-decoder models, if the ' 'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger ' 'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.') 'than this, we fall back to the eager mode.')
parser.add_argument('--tokenizer-pool-size',
type=int, # Tokenizer arguments
default=EngineArgs.tokenizer_pool_size, tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
help='Size of tokenizer pool to use for ' tokenizer_group = parser.add_argument_group(
'asynchronous tokenization. If 0, will ' title="TokenizerPoolConfig",
'use synchronous tokenization.') description=TokenizerPoolConfig.__doc__,
parser.add_argument('--tokenizer-pool-type', )
type=str, tokenizer_group.add_argument('--tokenizer-pool-size',
default=EngineArgs.tokenizer_pool_type, **tokenizer_kwargs["pool_size"])
help='Type of tokenizer pool to use for ' tokenizer_group.add_argument('--tokenizer-pool-type',
'asynchronous tokenization. Ignored ' **tokenizer_kwargs["pool_type"])
'if tokenizer_pool_size is 0.') tokenizer_group.add_argument('--tokenizer-pool-extra-config',
parser.add_argument('--tokenizer-pool-extra-config', **tokenizer_kwargs["extra_config"])
type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# Multimodal related configs # Multimodal related configs
parser.add_argument( multimodal_kwargs = get_kwargs(MultiModalConfig)
'--limit-mm-per-prompt', multimodal_group = parser.add_argument_group(
type=nullable_kvs, title="MultiModalConfig",
default=EngineArgs.limit_mm_per_prompt, description=MultiModalConfig.__doc__,
# The default value is given in )
# MultiModalConfig.get_default_limit_per_prompt multimodal_group.add_argument('--limit-mm-per-prompt',
help=('For each multimodal plugin, limit how many ' **multimodal_kwargs["limit_per_prompt"])
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to '
'1 (V0) or 999 (V1) for each modality.'))
parser.add_argument( parser.add_argument(
'--mm-processor-kwargs', '--mm-processor-kwargs',
default=None, default=None,
type=json.loads, type=json.loads,
help=('Overrides for the multimodal input mapping/processing, ' help=('Overrides for the multi-modal processor obtained from '
'e.g., image processor. For example: ``{"num_crops": 4}``.')) '``AutoProcessor.from_pretrained``. The available overrides '
'depend on the model that is being run.'
'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
parser.add_argument( parser.add_argument(
'--disable-mm-preprocessor-cache', '--disable-mm-preprocessor-cache',
action='store_true', action='store_true',
help='If true, then disables caching of the multi-modal ' help='If True, disable caching of the processed multi-modal '
'preprocessor/mapper. (not recommended)') 'inputs.')
# LoRA related configs # LoRA related configs
parser.add_argument('--enable-lora', lora_kwargs = get_kwargs(LoRAConfig)
action='store_true', lora_group = parser.add_argument_group(
help='If True, enable handling of LoRA adapters.') title="LoRAConfig",
parser.add_argument('--enable-lora-bias', description=LoRAConfig.__doc__,
action='store_true', )
help='If True, enable bias for LoRA adapters.') lora_group.add_argument(
parser.add_argument('--max-loras', '--enable-lora',
type=int, action=argparse.BooleanOptionalAction,
default=EngineArgs.max_loras, help='If True, enable handling of LoRA adapters.')
help='Max number of LoRAs in a single batch.') lora_group.add_argument('--enable-lora-bias',
parser.add_argument('--max-lora-rank', **lora_kwargs["bias_enabled"])
type=int, lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
default=EngineArgs.max_lora_rank, lora_group.add_argument('--max-lora-rank',
help='Max LoRA rank.') **lora_kwargs["max_lora_rank"])
parser.add_argument('--merge-lora', lora_group.add_argument('--merge-lora',
type=bool, action=argparse.BooleanOptionalAction,
default=False,
help='If set to True, the weights of the base layer will be merged with the weights of Lora.') help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
parser.add_argument('--lora-target-modules', lora_group.add_argument('--lora-target-modules',
nargs='*', **lora_kwargs["lora_target_modules"])
default=None, lora_group.add_argument('--lora-extra-vocab-size',
help='List of lora module name, If not specified, modules will be chosen according to the model architecture.') **lora_kwargs["lora_extra_vocab_size"])
parser.add_argument( lora_group.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
'--lora-dtype', '--lora-dtype',
type=str, **lora_kwargs["lora_dtype"],
default=EngineArgs.lora_dtype, )
choices=['auto', 'float16', 'bfloat16'], lora_group.add_argument('--long-lora-scaling-factors',
help=('Data type for LoRA. If auto, will default to ' **lora_kwargs["long_lora_scaling_factors"])
'base model dtype.')) lora_group.add_argument('--max-cpu-loras',
parser.add_argument( **lora_kwargs["max_cpu_loras"])
'--long-lora-scaling-factors', lora_group.add_argument('--fully-sharded-loras',
type=nullable_str, **lora_kwargs["fully_sharded_loras"])
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can ' # PromptAdapter related configs
'be different from base model scaling factor ' prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
'- see eg. Long LoRA) to allow for multiple ' prompt_adapter_group = parser.add_argument_group(
'LoRA adapters trained with those scaling ' title="PromptAdapterConfig",
'factors to be used at the same time. If not ' description=PromptAdapterConfig.__doc__,
'specified, only adapters trained with the ' )
'base model scaling factor are allowed.')) prompt_adapter_group.add_argument(
parser.add_argument( '--enable-prompt-adapter',
'--max-cpu-loras', action=argparse.BooleanOptionalAction,
type=int, help='If True, enable handling of PromptAdapters.')
default=EngineArgs.max_cpu_loras, prompt_adapter_group.add_argument(
help=('Maximum number of LoRAs to store in CPU memory. ' '--max-prompt-adapters',
'Must be >= than max_loras.')) **prompt_adapter_kwargs["max_prompt_adapters"])
parser.add_argument( prompt_adapter_group.add_argument(
'--fully-sharded-loras', '--max-prompt-adapter-token',
action='store_true', **prompt_adapter_kwargs["max_prompt_adapter_token"])
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. ' # Device arguments
'Enabling this will use the fully sharded layers. ' device_kwargs = get_kwargs(DeviceConfig)
'At high sequence length, max rank or ' device_group = parser.add_argument_group(
'tensor parallel size, this is likely faster.')) title="DeviceConfig",
parser.add_argument('--enable-prompt-adapter', description=DeviceConfig.__doc__,
action='store_true', )
help='If True, enable handling of PromptAdapters.') device_group.add_argument("--device", **device_kwargs["device"])
parser.add_argument('--max-prompt-adapters',
type=int, # Speculative arguments
default=EngineArgs.max_prompt_adapters, speculative_group = parser.add_argument_group(
help='Max number of PromptAdapters in a batch.') title="SpeculativeConfig",
parser.add_argument('--max-prompt-adapter-token', description=SpeculativeConfig.__doc__,
type=int, )
default=EngineArgs.max_prompt_adapter_token, speculative_group.add_argument(
help='Max number of PromptAdapters tokens') '--speculative-config',
parser.add_argument("--device", type=json.loads,
type=str, default=None,
default=EngineArgs.device, help='The configurations for speculative decoding.'
choices=DEVICE_OPTIONS, ' Should be a JSON string.')
help='Device type for vLLM execution.')
parser.add_argument('--num-scheduler-steps',
type=int,
default=1,
help=('Maximum number of forward steps per '
'scheduler call.'))
parser.add_argument(
'--multi-step-stream-outputs',
action=StoreBoolean,
default=EngineArgs.multi_step_stream_outputs,
nargs="?",
const="True",
help='If False, then multi-step will stream outputs at the end '
'of all steps')
parser.add_argument(
'--scheduler-delay-factor',
type=float,
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous '
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
action=StoreBoolean,
default=EngineArgs.enable_chunked_prefill,
nargs="?",
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument( parser.add_argument(
'--num-speculative-heads', '--num-speculative-heads',
type=int, type=int,
...@@ -819,13 +774,6 @@ class EngineArgs: ...@@ -819,13 +774,6 @@ class EngineArgs:
help="The pattern(s) to ignore when loading the model." help="The pattern(s) to ignore when loading the model."
"Default to `original/**/*` to avoid repeated loading of llama's " "Default to `original/**/*` to avoid repeated loading of llama's "
"checkpoints.") "checkpoints.")
parser.add_argument(
'--preemption-mode',
type=str,
default=None,
help='If \'recompute\', the engine performs preemption by '
'recomputing; If \'swap\', the engine performs preemption by '
'block swapping.')
parser.add_argument( parser.add_argument(
"--served-model-name", "--served-model-name",
...@@ -881,22 +829,47 @@ class EngineArgs: ...@@ -881,22 +829,47 @@ class EngineArgs:
help="Disable async output processing. This may result in " help="Disable async output processing. This may result in "
"lower performance.") "lower performance.")
parser.add_argument( # Scheduler arguments
'--scheduling-policy', scheduler_kwargs = get_kwargs(SchedulerConfig)
choices=['fcfs', 'priority'], scheduler_group = parser.add_argument_group(
default="fcfs", title="SchedulerConfig",
help='The scheduling policy to use. "fcfs" (first come first served' description=SchedulerConfig.__doc__,
', i.e. requests are handled in order of arrival; default) ' )
'or "priority" (requests are handled based on given ' scheduler_group.add_argument(
'priority (lower value means earlier handling) and time of ' '--max-num-batched-tokens',
'arrival deciding any ties).') **scheduler_kwargs["max_num_batched_tokens"])
scheduler_group.add_argument('--max-num-seqs',
parser.add_argument( **scheduler_kwargs["max_num_seqs"])
'--scheduler-cls', scheduler_group.add_argument(
default=EngineArgs.scheduler_cls, "--max-num-partial-prefills",
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" ' **scheduler_kwargs["max_num_partial_prefills"])
'is the default scheduler. Can be a class directly or the path to ' scheduler_group.add_argument(
'a class of form "mod.custom_class".') "--max-long-partial-prefills",
**scheduler_kwargs["max_long_partial_prefills"])
scheduler_group.add_argument(
"--long-prefill-token-threshold",
**scheduler_kwargs["long_prefill_token_threshold"])
scheduler_group.add_argument('--num-lookahead-slots',
**scheduler_kwargs["num_lookahead_slots"])
scheduler_group.add_argument('--scheduler-delay-factor',
**scheduler_kwargs["delay_factor"])
scheduler_group.add_argument('--preemption-mode',
**scheduler_kwargs["preemption_mode"])
scheduler_group.add_argument('--num-scheduler-steps',
**scheduler_kwargs["num_scheduler_steps"])
scheduler_group.add_argument(
'--multi-step-stream-outputs',
**scheduler_kwargs["multi_step_stream_outputs"])
scheduler_group.add_argument('--scheduling-policy',
**scheduler_kwargs["policy"])
scheduler_group.add_argument(
'--enable-chunked-prefill',
**scheduler_kwargs["enable_chunked_prefill"])
scheduler_group.add_argument(
"--disable-chunked-mm-input",
**scheduler_kwargs["disable_chunked_mm_input"])
parser.add_argument('--scheduler-cls',
**scheduler_kwargs["scheduler_cls"])
parser.add_argument( parser.add_argument(
'--override-neuron-config', '--override-neuron-config',
...@@ -923,10 +896,11 @@ class EngineArgs: ...@@ -923,10 +896,11 @@ class EngineArgs:
'testing only. level 3 is the recommended level ' 'testing only. level 3 is the recommended level '
'for production.\n' 'for production.\n'
'To specify the full compilation config, ' 'To specify the full compilation config, '
'use a JSON string.\n' 'use a JSON string, e.g. ``{"level": 3, '
'"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n'
'Following the convention of traditional ' 'Following the convention of traditional '
'compilers, using -O without space is also ' 'compilers, using ``-O`` without space is also '
'supported. -O3 is equivalent to -O 3.') 'supported. ``-O3`` is equivalent to ``-O 3``.')
parser.add_argument('--kv-transfer-config', parser.add_argument('--kv-transfer-config',
type=KVTransferConfig.from_cli, type=KVTransferConfig.from_cli,
...@@ -948,7 +922,7 @@ class EngineArgs: ...@@ -948,7 +922,7 @@ class EngineArgs:
'class without changing the existing functions.') 'class without changing the existing functions.')
parser.add_argument( parser.add_argument(
"--generation-config", "--generation-config",
type=nullable_str, type=optional_type(str),
default="auto", default="auto",
help="The folder path to the generation config. " help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from " "Defaults to 'auto', the generation config will be loaded from "
...@@ -975,15 +949,6 @@ class EngineArgs: ...@@ -975,15 +949,6 @@ class EngineArgs:
help="Enable sleep mode for the engine. " help="Enable sleep mode for the engine. "
"(only cuda platform is supported)") "(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
parser.add_argument( parser.add_argument(
"--additional-config", "--additional-config",
type=json.loads, type=json.loads,
...@@ -1001,16 +966,6 @@ class EngineArgs: ...@@ -1001,16 +966,6 @@ class EngineArgs:
"If enabled, the model will be able to generate reasoning content." "If enabled, the model will be able to generate reasoning content."
) )
parser.add_argument(
"--reasoning-parser",
type=str,
choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "
"using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``.")
parser.add_argument( parser.add_argument(
"--disable-cascade-attn", "--disable-cascade-attn",
action="store_true", action="store_true",
...@@ -1021,20 +976,6 @@ class EngineArgs: ...@@ -1021,20 +976,6 @@ class EngineArgs:
"Note that even if this is set to False, cascade attention will be " "Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.") "only used when the heuristic tells that it's beneficial.")
parser.add_argument(
"--disable-chunked-mm-input",
action=StoreBoolean,
default=EngineArgs.disable_chunked_mm_input,
nargs="?",
const="True",
help="Disable multimodal input chunking attention for V1. "
"If set to true and chunked prefill is enabled, we do not want to"
" partially schedule a multimodal item. This ensures that if a "
"request has a mixed prompt (like text tokens TTTT followed by "
"image tokens IIIIIIIIII) where only some image tokens can be "
"scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
"as TTTT in one step and IIIIIIIIII in the next.")
return parser return parser
@classmethod @classmethod
...@@ -1228,11 +1169,6 @@ class EngineArgs: ...@@ -1228,11 +1169,6 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers, max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce, disable_custom_all_reduce=self.disable_custom_all_reduce,
tokenizer_pool_config=TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight, ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group, placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend, distributed_executor_backend=self.distributed_executor_backend,
...@@ -1308,8 +1244,6 @@ class EngineArgs: ...@@ -1308,8 +1244,6 @@ class EngineArgs:
if self.qlora_adapter_name_or_path is not None and \ if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "": self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[ self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
...@@ -1390,7 +1324,7 @@ class EngineArgs: ...@@ -1390,7 +1324,7 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
if self.preemption_mode != EngineArgs.preemption_mode: if self.preemption_mode != SchedulerConfig.preemption_mode:
_raise_or_fallback(feature_name="--preemption-mode", _raise_or_fallback(feature_name="--preemption-mode",
recommend_to_remove=True) recommend_to_remove=True)
return False return False
...@@ -1401,34 +1335,28 @@ class EngineArgs: ...@@ -1401,34 +1335,28 @@ class EngineArgs:
recommend_to_remove=True) recommend_to_remove=True)
return False return False
if self.scheduling_policy != EngineArgs.scheduling_policy: if self.scheduling_policy != SchedulerConfig.policy:
_raise_or_fallback(feature_name="--scheduling-policy", _raise_or_fallback(feature_name="--scheduling-policy",
recommend_to_remove=False) recommend_to_remove=False)
return False return False
if self.num_scheduler_steps != EngineArgs.num_scheduler_steps: if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
_raise_or_fallback(feature_name="--num-scheduler-steps", _raise_or_fallback(feature_name="--num-scheduler-steps",
recommend_to_remove=True) recommend_to_remove=True)
return False return False
if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor: if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
_raise_or_fallback(feature_name="--scheduler-delay-factor", _raise_or_fallback(feature_name="--scheduler-delay-factor",
recommend_to_remove=True) recommend_to_remove=True)
return False return False
if self.additional_config != EngineArgs.additional_config: # remove backend options when doing this check
_raise_or_fallback(feature_name="--additional-config", if self.guided_decoding_backend.split(':')[0] \
recommend_to_remove=False) not in get_args(GuidedDecodingBackendV1):
return False _raise_or_fallback(
feature_name=
# Xgrammar and Guidance are supported. f"--guided-decoding-backend={self.guided_decoding_backend}",
SUPPORTED_GUIDED_DECODING = [ recommend_to_remove=False)
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
"guidance:disable-any-whitespace", "auto"
]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend",
recommend_to_remove=False)
return False return False
# Need at least Ampere for now (FA support required). # Need at least Ampere for now (FA support required).
...@@ -1452,7 +1380,7 @@ class EngineArgs: ...@@ -1452,7 +1380,7 @@ class EngineArgs:
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False supported = False
if fp8_attention and will_use_fa: if fp8_attention and will_use_fa:
from vllm.vllm_flash_attn.fa_utils import ( from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8) flash_attn_supports_fp8)
supported = flash_attn_supports_fp8() supported = flash_attn_supports_fp8()
if not supported: if not supported:
...@@ -1495,9 +1423,9 @@ class EngineArgs: ...@@ -1495,9 +1423,9 @@ class EngineArgs:
# No Concurrent Partial Prefills so far. # No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills if (self.max_num_partial_prefills
!= EngineArgs.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
or self.max_long_partial_prefills or self.max_long_partial_prefills
!= EngineArgs.max_long_partial_prefills): != SchedulerConfig.max_long_partial_prefills):
_raise_or_fallback(feature_name="Concurrent Partial Prefill", _raise_or_fallback(feature_name="Concurrent Partial Prefill",
recommend_to_remove=False) recommend_to_remove=False)
return False return False
...@@ -1517,7 +1445,7 @@ class EngineArgs: ...@@ -1517,7 +1445,7 @@ class EngineArgs:
if speculative_method: if speculative_method:
if speculative_method in ("ngram", "[ngram]"): if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True is_ngram_enabled = True
elif speculative_method == "eagle": elif speculative_method in ("eagle", "eagle3"):
is_eagle_enabled = True is_eagle_enabled = True
else: else:
speculative_model = self.speculative_config.get("model") speculative_model = self.speculative_config.get("model")
...@@ -1529,16 +1457,17 @@ class EngineArgs: ...@@ -1529,16 +1457,17 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No Disaggregated Prefill so far. # No XFormers so far.
if self.kv_transfer_config != EngineArgs.kv_transfer_config:
_raise_or_fallback(feature_name="--kv-transfer-config",
recommend_to_remove=False)
return False
# No FlashInfer or XFormers so far.
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", "FLASH_ATTN_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" "FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
...@@ -1640,9 +1569,7 @@ class EngineArgs: ...@@ -1640,9 +1569,7 @@ class EngineArgs:
self.enable_prefix_caching = False self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching. # VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo is None: if self.prefix_caching_hash_algo == "sha256":
self.prefix_caching_hash_algo = "builtin"
elif self.prefix_caching_hash_algo == "sha256":
raise ValueError( raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. " "sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.") "Please use 'builtin'.")
...@@ -1661,10 +1588,6 @@ class EngineArgs: ...@@ -1661,10 +1588,6 @@ class EngineArgs:
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
self.enable_prefix_caching = True self.enable_prefix_caching = True
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
# V1 should use the new scheduler by default. # V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default # Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls: if self.scheduler_cls == EngineArgs.scheduler_cls:
...@@ -1681,13 +1604,13 @@ class EngineArgs: ...@@ -1681,13 +1604,13 @@ class EngineArgs:
# values for non-H100/H200 GPUs. # values for non-H100/H200 GPUs.
try: try:
from vllm.platforms import current_platform from vllm.platforms import current_platform
device_name = current_platform.get_device_name().lower() device_memory = current_platform.get_device_total_memory()
except Exception: except Exception:
# This is only used to set default_max_num_batched_tokens # This is only used to set default_max_num_batched_tokens
device_name = "no-device" device_memory = 0
if "h100" in device_name or "h200" in device_name: if device_memory >= 70 * GiB_bytes:
# For H100 and H200, we use larger default values. # For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = { default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 16384, UsageContext.LLM_CLASS: 16384,
UsageContext.OPENAI_API_SERVER: 8192, UsageContext.OPENAI_API_SERVER: 8192,
......
...@@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine):
tokenizer = await self.get_tokenizer_async(lora_request) tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer) self._validate_token_prompt(prompt, tokenizer=tokenizer)
preprocessed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \ if isinstance(params, SamplingParams) and \
params.guided_decoding is not None: params.guided_decoding is not None:
...@@ -526,10 +525,15 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -526,10 +525,15 @@ class _AsyncLLMEngine(LLMEngine):
) )
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
async def collective_rpc_async(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
raise NotImplementedError
async def build_guided_decoding_logits_processor_async( async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer, sampling_params: SamplingParams, tokenizer: AnyTokenizer,
...@@ -1167,6 +1171,10 @@ class AsyncLLMEngine(EngineClient): ...@@ -1167,6 +1171,10 @@ class AsyncLLMEngine(EngineClient):
exception=asyncio.CancelledError, exception=asyncio.CancelledError,
verbose=self.log_requests) verbose=self.log_requests)
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
return self.engine.get_vllm_config()
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
return self.engine.get_model_config() return self.engine.get_model_config()
...@@ -1234,6 +1242,17 @@ class AsyncLLMEngine(EngineClient): ...@@ -1234,6 +1242,17 @@ class AsyncLLMEngine(EngineClient):
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request) self.engine.add_lora(lora_request)
async def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
"""
Perform a collective RPC call to the given path.
"""
return await self.engine.collective_rpc_async(method, timeout, args,
kwargs)
# TODO(v1): Remove this class proxy when V1 goes default. # TODO(v1): Remove this class proxy when V1 goes default.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
......
...@@ -30,8 +30,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group ...@@ -30,8 +30,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import ( from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
PromptType, SingletonInputs)
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -56,7 +55,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, ...@@ -56,7 +55,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import (Counter, Device, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_kwargs,
...@@ -67,7 +66,6 @@ from vllm.worker.model_runner_base import InputProcessingError ...@@ -67,7 +66,6 @@ from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any) _R = TypeVar("_R", default=Any)
...@@ -206,7 +204,7 @@ class LLMEngine: ...@@ -206,7 +204,7 @@ class LLMEngine:
return outputs_ return outputs_
tokenizer: Optional[BaseTokenizerGroup] tokenizer: Optional[TokenizerGroup]
def __init__( def __init__(
self, self,
...@@ -215,7 +213,6 @@ class LLMEngine: ...@@ -215,7 +213,6 @@ class LLMEngine:
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
) -> None: ) -> None:
...@@ -276,11 +273,7 @@ class LLMEngine: ...@@ -276,11 +273,7 @@ class LLMEngine:
self.tokenizer, self.tokenizer,
mm_registry) mm_registry)
self.input_registry = input_registry self.model_executor = executor_class(vllm_config=vllm_config)
self.input_processor = input_registry.create_input_processor(
self.model_config)
self.model_executor = executor_class(vllm_config=vllm_config, )
if self.model_config.runner_type != "pooling": if self.model_config.runner_type != "pooling":
self._initialize_kv_caches() self._initialize_kv_caches()
...@@ -322,11 +315,6 @@ class LLMEngine: ...@@ -322,11 +315,6 @@ class LLMEngine:
self.parallel_config.disable_custom_all_reduce, self.parallel_config.disable_custom_all_reduce,
}) })
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [ self.cached_scheduler_outputs = [
SchedulerOutputState() SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size) for _ in range(self.parallel_config.pipeline_parallel_size)
...@@ -540,21 +528,12 @@ class LLMEngine: ...@@ -540,21 +528,12 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
def get_tokenizer_group( def get_tokenizer_group(self) -> TokenizerGroup:
self, if self.tokenizer is None:
group_type: Type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group return self.tokenizer
def get_tokenizer( def get_tokenizer(
self, self,
...@@ -562,11 +541,10 @@ class LLMEngine: ...@@ -562,11 +541,10 @@ class LLMEngine:
) -> AnyTokenizer: ) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request) return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup: def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs( return init_tokenizer_from_configs(
model_config=self.model_config, model_config=self.model_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
lora_config=self.lora_config) lora_config=self.lora_config)
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -781,12 +759,11 @@ class LLMEngine: ...@@ -781,12 +759,11 @@ class LLMEngine:
prompt, prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request)) tokenizer=self.get_tokenizer(lora_request=lora_request))
preprocessed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -917,6 +894,10 @@ class LLMEngine: ...@@ -917,6 +894,10 @@ class LLMEngine:
scheduler.abort_seq_group( scheduler.abort_seq_group(
request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
def get_vllm_config(self) -> VllmConfig:
"""Gets the vllm configuration."""
return self.vllm_config
def get_model_config(self) -> ModelConfig: def get_model_config(self) -> ModelConfig:
"""Gets the model configuration.""" """Gets the model configuration."""
return self.model_config return self.model_config
...@@ -1965,8 +1946,6 @@ class LLMEngine: ...@@ -1965,8 +1946,6 @@ class LLMEngine:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
def check_health(self) -> None: def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:
...@@ -2075,7 +2054,7 @@ class LLMEngine: ...@@ -2075,7 +2054,7 @@ class LLMEngine:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len: if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor( mm_processor = mm_registry.create_processor(
......
...@@ -140,16 +140,13 @@ class Metrics: ...@@ -140,16 +140,13 @@ class Metrics:
name="vllm:generation_tokens_total", name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames) labelnames=labelnames)
buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy()
buckets.sort()
self.histogram_iteration_tokens = self._histogram_cls( self.histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total", name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.", documentation="Histogram of number of tokens per engine_step.",
labelnames=labelnames, labelnames=labelnames,
buckets=buckets) buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
])
self.histogram_time_to_first_token = self._histogram_cls( self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds", name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.", documentation="Histogram of time to first token in seconds.",
......
...@@ -93,6 +93,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -93,6 +93,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Get the configs. # Get the configs.
self.vllm_config = engine_config
self.model_config = engine_config.model_config self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config self.decoding_config = engine_config.decoding_config
...@@ -100,7 +101,6 @@ class MQLLMEngineClient(EngineClient): ...@@ -100,7 +101,6 @@ class MQLLMEngineClient(EngineClient):
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config, model_config=self.model_config,
scheduler_config=engine_config.scheduler_config, scheduler_config=engine_config.scheduler_config,
parallel_config=engine_config.parallel_config,
lora_config=engine_config.lora_config) lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config, self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer) self.tokenizer)
...@@ -377,6 +377,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -377,6 +377,9 @@ class MQLLMEngineClient(EngineClient):
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request) return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config return self.decoding_config
......
...@@ -178,7 +178,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -178,7 +178,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# generates a fixed number of tokens without evaluating stopping # generates a fixed number of tokens without evaluating stopping
# conditions within the block. This can cause an eos token to be # conditions within the block. This can cause an eos token to be
# unintentionally ignored. # unintentionally ignored.
if not sampling_params.ignore_eos: if not sampling_params.ignore_eos and self.detokenizer:
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
# Avoiding .index calls as exception throwing in the happy path # Avoiding .index calls as exception throwing in the happy path
# is expensive. # is expensive.
......
...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional from typing import AsyncGenerator, List, Mapping, Optional
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
...@@ -220,6 +220,11 @@ class EngineClient(ABC): ...@@ -220,6 +220,11 @@ class EngineClient(ABC):
""" """
... ...
@abstractmethod
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
...
@abstractmethod @abstractmethod
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
......
...@@ -27,10 +27,11 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam, ...@@ -27,10 +27,11 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam) ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import ( from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio) InputAudio)
from pydantic import TypeAdapter
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin) ProcessorMixin)
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -482,11 +483,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -482,11 +483,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"): if modality in ("image", "image_embeds"):
if model_type == "chatglm": if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>" return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "phi3_v": if model_type in ("phi3_v", "phi4mm"):
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>" return f"<|image_{current_count}|>"
if model_type == "phi4mm":
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma", if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
...@@ -506,20 +504,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -506,20 +504,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|image|>" return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
if model_type == "molmo": if model_type == "molmo":
return "" return ""
if model_type == "aria": if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>" return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3": if model_type == "gemma3":
return "<start_of_image>" return "<start_of_image>"
if model_type == "kimi_vl":
return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501
raise TypeError(f"Unknown {modality} model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type in ("ultravox", "granite_speech"):
return "<|audio|>" return "<|audio|>"
if model_type == "phi4mm": if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model) return f"<|audio_{current_count}|>"
if model_type == "qwen2_audio": if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: " return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>") f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo": if model_type == "minicpmo":
...@@ -528,6 +530,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -528,6 +530,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "video": elif modality == "video":
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)" return "(<video>./</video>)"
if model_type.startswith("llava"): if model_type.startswith("llava"):
...@@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], ...@@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again # No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam) _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) # Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio] _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
...@@ -1092,7 +1097,11 @@ def _parse_chat_message_content( ...@@ -1092,7 +1097,11 @@ def _parse_chat_message_content(
if role == 'assistant': if role == 'assistant':
parsed_msg = _AssistantParser(message) parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg: # The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
if ("tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool": elif role == "tool":
parsed_msg = _ToolParser(message) parsed_msg = _ToolParser(message)
...@@ -1189,14 +1198,25 @@ def apply_hf_chat_template( ...@@ -1189,14 +1198,25 @@ def apply_hf_chat_template(
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
"does not define one.") "does not define one.")
return tokenizer.apply_chat_template( try:
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] return tokenizer.apply_chat_template(
chat_template=hf_chat_template, conversation=conversation, # type: ignore[arg-type]
tokenize=tokenize, tools=tools, # type: ignore[arg-type]
**kwargs, chat_template=hf_chat_template,
) tokenize=tokenize,
**kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template")
raise ValueError from e
def apply_mistral_chat_template( def apply_mistral_chat_template(
tokenizer: MistralTokenizer, tokenizer: MistralTokenizer,
...@@ -1205,6 +1225,8 @@ def apply_mistral_chat_template( ...@@ -1205,6 +1225,8 @@ def apply_mistral_chat_template(
tools: Optional[list[dict[str, Any]]], tools: Optional[list[dict[str, Any]]],
**kwargs: Any, **kwargs: Any,
) -> list[int]: ) -> list[int]:
from mistral_common.exceptions import MistralCommonException
# The return value of resolve_mistral_chat_template is always None, # The return value of resolve_mistral_chat_template is always None,
# and we won't use it. # and we won't use it.
resolve_mistral_chat_template( resolve_mistral_chat_template(
...@@ -1222,5 +1244,16 @@ def apply_mistral_chat_template( ...@@ -1222,5 +1244,16 @@ def apply_mistral_chat_template(
# if input does not comply with the expected format. # if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be # We convert those assertion errors to ValueErrors so they can be
# are properly caught in the preprocessing_input step # are properly caught in the preprocessing_input step
except AssertionError as e: except (AssertionError, MistralCommonException) as e:
raise ValueError from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat "
"template")
raise ValueError from e raise ValueError from e
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.latency import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
""" The `latency` subcommand for vllm bench. """
def __init__(self):
self.name = "latency"
super().__init__()
@property
def help(self) -> str:
return "Benchmark the latency of a single batch of requests."
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)
def cmd_init() -> list[CLISubcommand]:
return [BenchmarkLatencySubcommand()]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import vllm.entrypoints.cli.benchmark.latency
import vllm.entrypoints.cli.benchmark.serve import vllm.entrypoints.cli.benchmark.serve
import vllm.entrypoints.cli.benchmark.throughput
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
# TODO: Add the rest of the benchmark subcommands here,
# e.g., throughput, latency, etc.
BENCHMARK_CMD_MODULES = [ BENCHMARK_CMD_MODULES = [
vllm.entrypoints.cli.benchmark.latency,
vllm.entrypoints.cli.benchmark.serve, vllm.entrypoints.cli.benchmark.serve,
vllm.entrypoints.cli.benchmark.throughput,
] ]
......
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.throughput import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
""" The `throughput` subcommand for vllm bench. """
def __init__(self):
self.name = "throughput"
super().__init__()
@property
def help(self) -> str:
return "Benchmark offline inference throughput."
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)
def cmd_init() -> list[CLISubcommand]:
return [BenchmarkThroughputSubcommand()]
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
class CollectEnvSubcommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "collect-env"
super().__init__()
@staticmethod
def cmd(args: argparse.Namespace) -> None:
"""Collect information about the environment."""
collect_env_main()
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser(
"collect-env",
help="Start collecting environment information.",
description="Start collecting environment information.",
usage="vllm collect-env")
return make_arg_parser(serve_parser)
def cmd_init() -> list[CLISubcommand]:
return [CollectEnvSubcommand()]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment