Commit 711aa9d5 authored by zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
......@@ -47,7 +47,10 @@ class MultiConnector(KVConnectorBase_V1):
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
temp_config.kv_transfer_config = KVTransferConfig(**ktc)
engine_id = ktc.get("engine_id",
vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id)
self._connectors.append(
KVConnectorFactory.create_connector_v1(temp_config, role))
......@@ -187,7 +190,7 @@ class MultiConnector(KVConnectorBase_V1):
async_saves += 1
if txfer_params is not None:
if kv_txfer_params is not None:
#TODO we can probably change this to merge the dicts here,
# TODO we can probably change this to merge the dicts here,
# checking for key clashes.
raise RuntimeError(
"Only one connector can produce KV transfer params")
......
......@@ -79,7 +79,8 @@ class ReqMeta:
class NixlConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[ReqId, ReqMeta] = {}
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
def add_new_req(
self,
......@@ -87,7 +88,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
):
self.requests[request_id] = ReqMeta(
self.reqs_to_recv[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
......@@ -194,10 +195,12 @@ class NixlConnectorScheduler:
vllm_config.parallel_config.tensor_parallel_size)
logger.info("Initializing NIXL Scheduler %s", engine_id)
# Requests that need to start recv.
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
def get_num_new_matched_tokens(
self, request: "Request",
......@@ -284,6 +287,9 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
meta.reqs_to_send = self._reqs_need_send
self._reqs_need_send = {}
return meta
def request_finished(
......@@ -325,6 +331,11 @@ class NixlConnectorScheduler:
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks = len(computed_block_ids) > 0
if delay_free_blocks:
# Prefill request on remote. It will be read from D upon completion
self._reqs_need_send[request.request_id] = time.perf_counter(
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
......@@ -394,14 +405,8 @@ class NixlConnectorWorker:
# In progress transfers.
# [req_id -> list[handle]]
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self._done_recving_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)
self._done_sending_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
......@@ -475,8 +480,13 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
def _nixl_handshake(
self,
host: str,
port: int,
remote_tp_size: int,
expected_engine_id: str,
) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
......@@ -485,26 +495,6 @@ class NixlConnectorWorker:
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
def handshake(path: str, rank: int) -> str:
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, rank, remote_tp_size)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
return remote_agent_name
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
......@@ -512,8 +502,32 @@ class NixlConnectorWorker:
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s", path,
p_remote_rank)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}.")
# Register Remote agent.
remote_agent_name = self.add_remote_agent(metadata, p_remote_rank,
remote_tp_size)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
# Remote rank -> agent name.
return {p_remote_rank: handshake(path, p_remote_rank)}
return {p_remote_rank: remote_agent_name}
def _background_nixl_handshake(self, req_id: str,
remote_engine_id: EngineId, meta: ReqMeta):
......@@ -522,7 +536,7 @@ class NixlConnectorWorker:
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host, meta.remote_port,
meta.tp_size)
meta.tp_size, remote_engine_id)
self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
......@@ -725,10 +739,10 @@ class NixlConnectorWorker:
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
return self._remote_agents[engine_id][remote_tp_rank]
if engine_id in self._tp_size:
assert self._tp_size[engine_id] == remote_tp_size
else:
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
else:
assert self._tp_size[engine_id] == remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
......@@ -808,15 +822,9 @@ class NixlConnectorWorker:
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving.
In TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in a worker creates
the done_sending and done_recving sets that are sent to the
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
are done before adding to finished, Ranks 1 to N-1 communicate
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
......@@ -826,50 +834,17 @@ class NixlConnectorWorker:
"and %s requests done recving", self.tp_rank,
len(done_sending), len(done_recving))
if self.world_size == 1:
return done_sending, done_recving
# Rank 0: get finished from all other ranks.
if self.tp_rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving:
self._done_recving_count[req_id] += 1
# Keep track of how many other ranks have finished.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
self._done_recving_count[req_id] += 1
else:
self._done_sending_count[req_id] += 1
# Return ids that finished on all ranks to the scheduler.
all_done_recving: set[str] = set()
for req_id in list(self._done_recving_count.keys()):
if self._done_recving_count[req_id] == self.world_size:
del self._done_recving_count[req_id]
all_done_recving.add(req_id)
all_done_sending: set[str] = set()
for req_id in list(self._done_sending_count.keys()):
if self._done_sending_count[req_id] == self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
while self._reqs_to_send:
req_id, expires = next(iter(self._reqs_to_send.items()))
# Sorted dict, oldest requests are put first so we can exit early.
if now < expires:
break
del self._reqs_to_send[req_id]
done_sending.add(req_id)
return all_done_sending, all_done_recving
# Ranks 1 to N-1: send finished ids to Rank 0.
else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)
# Unused as only Rank 0 results are sent to scheduler.
return done_sending, done_recving
return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
"""
......@@ -887,6 +862,7 @@ class NixlConnectorWorker:
tp_ratio):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
del self._reqs_to_send[req_id]
return notified_req_ids
def _pop_done_transfers(
......@@ -921,7 +897,7 @@ class NixlConnectorWorker:
Start loading by triggering non-blocking nixl_xfer.
We check for these trnxs to complete in each step().
"""
for req_id, meta in metadata.requests.items():
for req_id, meta in metadata.reqs_to_recv.items():
remote_engine_id = meta.remote_engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
......@@ -943,6 +919,9 @@ class NixlConnectorWorker:
while not self._ready_requests.empty():
self._read_blocks_for_req(*self._ready_requests.get_nowait())
# Add to requests that are waiting to be read and track expiration.
self._reqs_to_send.update(metadata.reqs_to_send)
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
......
......@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(
request_id + "#" + layer_name, kv_layer, remote_address,
request.slot_mapping,
isinstance(attn_metadata, MLACommonMetadata))
def wait_for_save(self):
if self.is_producer:
......@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
forward_context: ForwardContext = get_forward_context()
no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids,
forward_context)
no_compile_layers)
# ==============================
# Scheduler-side methods
......@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_ids=block_ids,
block_size=self._block_size)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self._requests_need_load.clear()
return meta
......
......@@ -8,7 +8,8 @@ import time
import typing
from collections import deque
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass
from typing import Any, Optional
import msgpack
import torch
......@@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32
......@@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None)
@dataclass
class SendQueueItem:
    """A single tensor transfer waiting to be pushed to a remote peer.

    Items are created in ``send_tensor`` and consumed by ``send_sync``,
    either immediately or after sitting in the ``send_queue`` deque.
    """

    # Identifier of the form "<request_id>#<layer_name>".
    tensor_id: str
    # Address of the peer that should receive the tensor.
    remote_address: str
    # KV layer tensor; it is sliced with ``slot_mapping`` just before sending.
    tensor: torch.Tensor
    # Slot indices selecting which entries of ``tensor`` are transmitted.
    slot_mapping: torch.Tensor
    # True when ``tensor`` uses the MLA layout (no leading K/V axis).
    is_mla: bool
class P2pNcclEngine:
def __init__(self,
......@@ -112,24 +119,26 @@ class P2pNcclEngine:
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
1024**3) # GB
mem_pool_size_gb = float(
self.config.get_from_extra_config("mem_pool_size_gb",
DEFAULT_MEM_POOL_SIZE_GB))
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
1024**3)) # GB
# The sending type includes three mutually exclusive options:
# PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config("send_type", "PUT")
self.send_type = self.config.get_from_extra_config(
"send_type", "PUT_ASYNC")
if self.send_type == "GET":
# tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {}
else:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque()
self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async,
self._send_thread = threading.Thread(target=self.send_async,
daemon=True)
self._send_thread.start()
......@@ -146,13 +155,12 @@ class P2pNcclEngine:
"nccl_num_channels", "8")
self._listener_thread = threading.Thread(
target=self._listen_for_requests, daemon=True)
target=self.listen_for_requests, daemon=True)
self._listener_thread.start()
self._ping_thread = None
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping,
daemon=True)
self._ping_thread = threading.Thread(target=self.ping, daemon=True)
self._ping_thread.start()
logger.info(
......@@ -162,7 +170,7 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect(self, remote_address: typing.Optional[str] = None):
def create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
......@@ -184,7 +192,7 @@ class P2pNcclEngine:
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address]
......@@ -194,44 +202,54 @@ class P2pNcclEngine:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
slot_mapping: torch.Tensor = None,
is_mla: bool = False,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
else:
if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address,
tensor=tensor,
slot_mapping=slot_mapping,
is_mla=is_mla)
if self.send_type == "PUT":
return self.send_sync(item)
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
tensor_id, tensor_size, tensor.shape, self.rank,
self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def recv_tensor(
......@@ -267,7 +285,7 @@ class P2pNcclEngine:
return None
if remote_address not in self.socks:
self._create_connect(remote_address)
self.create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
......@@ -282,121 +300,121 @@ class P2pNcclEngine:
remote_address, tensor_id, data["ret"])
return None
tensor = torch.empty(data["shape"],
dtype=getattr(torch, data["dtype"]),
device=self.device)
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(torch, data["dtype"]),
device=self.device)
self._recv(comm, tensor, rank ^ 1, self.recv_stream)
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor
def _listen_for_requests(self):
def listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart(
[remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self._recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart(
[remote_address, b"1"])
tensor = None
if self.router_socket not in socks:
continue
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(),
rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address,
remote_address.decode(), data)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype":
str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self._have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self._send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart([remote_address, b"1"])
tensor = None
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address, remote_address.decode(),
data)
def _have_sent_tensor_id(self, tensor_id: str):
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
def have_sent_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as sent, grouped under its owning request.

    The request id is the portion of ``tensor_id`` before the first ``#``.
    """
    owner, _, _ = tensor_id.partition('#')
    # Create the per-request set on first use, then register the tensor.
    sent_ids = self.send_request_id_to_tensor_ids.setdefault(owner, set())
    sent_ids.add(tensor_id)
def _have_received_tensor_id(self, tensor_id: str):
def have_received_tensor_id(self, tensor_id: str):
    """Record ``tensor_id`` as received, grouped under its owning request.

    The request id is the portion of ``tensor_id`` before the first ``#``.
    """
    owner, _, _ = tensor_id.partition('#')
    # Create the per-request set on first use, then register the tensor.
    received_ids = self.recv_request_id_to_tensor_ids.setdefault(owner, set())
    received_ids.add(tensor_id)
def _send_async(self):
def send_async(self):
while True:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft()
item = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
self._send_sync(tensor_id, tensor, remote_address)
self.send_sync(item)
def wait_for_sent(self):
if self.send_type == "PUT_ASYNC":
......@@ -409,22 +427,21 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank)
def _send_sync(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
def send_sync(self, item: SendQueueItem) -> bool:
if item.remote_address is None:
return False
if remote_address not in self.socks:
self._create_connect(remote_address)
if item.remote_address not in self.socks:
self.create_connect(item.remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
with self.send_stream:
tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
item.slot_mapping)
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = {
"cmd": "PUT",
"tensor_id": tensor_id,
"tensor_id": item.tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
......@@ -435,20 +452,21 @@ class P2pNcclEngine:
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address, rank, data, tensor.shape,
self.zmq_address, item.remote_address, rank, data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
self.have_sent_tensor_id(item.tensor_id)
return True
def get_finished(
self, finished_req_ids: set[str], forward_context: "ForwardContext"
self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
......@@ -463,7 +481,7 @@ class P2pNcclEngine:
# Clear the buffer upon request completion.
for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers:
for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
......@@ -472,7 +490,6 @@ class P2pNcclEngine:
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
......@@ -485,7 +502,7 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None
def _ping(self):
def ping(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
......@@ -499,7 +516,7 @@ class P2pNcclEngine:
sock.send(msgpack.dumps(data))
time.sleep(3)
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
......@@ -512,7 +529,7 @@ class P2pNcclEngine:
comm, cudaStream_t(stream.cuda_stream))
stream.synchronize()
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
......@@ -531,3 +548,21 @@ class P2pNcclEngine:
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
@staticmethod
def extract_kv_from_layer(
is_mla: bool,
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if is_mla:
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
......@@ -240,6 +240,8 @@ class GroupCoordinator:
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
elif current_platform.is_xpu():
self.device = torch.device(f"xpu:{local_rank}")
elif current_platform.is_out_of_tree():
self.device = torch.device(
f"{current_platform.device_name}:{local_rank}")
......@@ -270,6 +272,9 @@ class GroupCoordinator:
self.use_custom_op_call = (current_platform.is_cuda_alike()
or current_platform.is_tpu())
self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr(
torch.ops._C, "init_shm_manager"))
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
......@@ -381,6 +386,12 @@ class GroupCoordinator:
dim: int) -> torch.Tensor:
return self.device_communicator.all_gather(input_, dim)
def all_gatherv(self,
                input_: Union[torch.Tensor, list[torch.Tensor]],
                dim: int = 0,
                sizes: Optional[list[int]] = None):
    """Forward an ``all_gatherv`` call to the device communicator.

    ``input_``, ``dim`` and ``sizes`` are passed positionally and unchanged
    to ``self.device_communicator.all_gatherv``; its result is returned
    as-is.
    """
    communicator = self.device_communicator
    return communicator.all_gatherv(input_, dim, sizes)
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
......@@ -399,6 +410,12 @@ class GroupCoordinator:
else:
return self._reduce_scatter_out_place(input_, dim)
def reduce_scatterv(self,
                    input_: torch.Tensor,
                    dim: int = -1,
                    sizes: Optional[list[int]] = None) -> torch.Tensor:
    """Forward a ``reduce_scatterv`` call to the device communicator.

    ``input_``, ``dim`` and ``sizes`` are passed positionally and unchanged
    to ``self.device_communicator.reduce_scatterv``; its result is returned
    as-is.
    """
    communicator = self.device_communicator
    return communicator.reduce_scatterv(input_, dim, sizes)
def _reduce_scatter_out_place(self, input_: torch.Tensor,
dim: int) -> torch.Tensor:
return self.device_communicator.reduce_scatter(input_, dim)
......@@ -649,6 +666,11 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst)
return None
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
......@@ -704,6 +726,10 @@ class GroupCoordinator:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
return self.device_communicator.recv_tensor_dict( # type: ignore
src)
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
for key, value in recv_metadata_list:
......@@ -1318,13 +1344,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
def is_global_first_rank() -> bool:
"""
Check if the current process is the first rank globally across all
Check if the current process is the first rank globally across all
parallelism strategies (PP, TP, DP, EP, etc.).
Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
or `get_pp_group().is_first_rank`, this function checks the global rank
across all parallelism dimensions.
Returns:
bool: True if this is the global first rank (rank 0), False otherwise.
Returns True if distributed is not initialized (single process).
......@@ -1353,7 +1379,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
Args:
pg: The process group to analyze
Returns:
int: The total number of nodes
"""
......
......@@ -10,16 +10,16 @@ import functools
import json
import sys
import threading
import warnings
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
Literal, Optional, Type, TypeVar, Union, cast, get_args,
get_origin)
import regex as re
import torch
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeIs, deprecated
from typing_extensions import TypeIs
import vllm.envs as envs
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
......@@ -27,26 +27,33 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules, Device, DeviceConfig,
DistributedExecutorBackend, GuidedDecodingBackend,
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
KVTransferConfig, LoadConfig, LoadFormat,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, TokenizerPoolConfig,
VllmConfig, get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
# yapf: enable
if TYPE_CHECKING:
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.usage.usage_lib import UsageContext
else:
ExecutorBase = Any
QuantizationMethods = Any
UsageContext = Any
logger = init_logger(__name__)
# object is used to allow for special typing forms
......@@ -59,8 +66,6 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _parse_type(val: str) -> T:
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
......@@ -81,47 +86,11 @@ def optional_type(
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse a CLI value that is either a plain string or a JSON object.

    A value that looks like a JSON object (``{...}``, optionally wrapped
    in whitespace and possibly spanning several lines) is handed to
    ``optional_type(json.loads)``. Anything else is returned unchanged as
    a string.

    Args:
        val: Raw argument text.

    Returns:
        ``str(val)`` for non-JSON-looking input, otherwise the result of
        ``optional_type(json.loads)(val)``.
    """
    # Only one guard is kept: the pattern that tolerates surrounding
    # whitespace. ``(?s)`` lets ``.`` match newlines so pretty-printed
    # JSON is still recognised as an object.
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return str(val)
    return optional_type(json.loads)(val)
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding
    whitespace before use; every value must parse as an integer.

    Args:
        val: String value to be parsed, e.g. ``"image=2,video=1"``.

    Returns:
        Dictionary mapping each key to its integer value.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or the same key is
            given two different values.
    """
    parsed: dict[str, int] = {}
    for entry in val.split(","):
        fields = [field.lower().strip() for field in entry.split("=")]
        if len(fields) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")

        key, value = fields
        try:
            number = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # A repeated key is accepted only when it repeats the same value.
        if parsed.setdefault(key, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")

    return parsed
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether ``type_hint`` is ``type`` or is parameterised from it.

    The hint matches when it is the very same object as ``type``, or when
    ``get_origin(type_hint)`` is ``type`` (e.g. ``list[int]`` vs ``list``).
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type
......@@ -171,6 +140,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints
def is_online_quantization(quantization: Any) -> bool:
    """Return whether ``quantization`` is listed as an online method.

    The only value currently listed is ``"inc"``.
    """
    online_methods = ("inc", )
    return quantization in online_methods
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
......@@ -199,14 +172,17 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name] = {"default": default, "help": help}
# Set other kwargs based on the type hints
json_tip = """\n\nShould either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:\n\n
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
Additionally, list elements can be passed individually using '+':
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
json_tip = """Should either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Additionally, list elements can be passed individually using `+`:
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`"""
if dataclass_cls is not None:
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
......@@ -218,7 +194,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
raise argparse.ArgumentTypeError(repr(e)) from e
kwargs[name]["type"] = parse_dataclass
kwargs[name]["help"] += json_tip
kwargs[name]["help"] += f"\n\n{json_tip}"
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
......@@ -254,7 +230,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip
kwargs[name]["help"] += f"\n\n{json_tip}"
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = str
......@@ -320,9 +296,11 @@ class EngineArgs:
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_eplb: bool = ParallelConfig.enable_eplb
......@@ -338,7 +316,6 @@ class EngineArgs:
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = ModelConfig.disable_sliding_window
disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
use_v2_block_manager: bool = True
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
......@@ -350,6 +327,7 @@ class EngineArgs:
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = ModelConfig.max_logprobs
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
disable_log_stats: bool = False
revision: Optional[str] = ModelConfig.revision
code_revision: Optional[str] = ModelConfig.code_revision
......@@ -362,15 +340,9 @@ class EngineArgs:
enforce_eager: bool = ModelConfig.enforce_eager
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
# The following three fields are deprecated and will be removed in a future
# release. Setting them will have no effect. Please remove them from your
# configurations.
tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: dict = \
get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str,
Any]] = get_field(MultiModalConfig,
"media_io_kwargs")
......@@ -383,19 +355,14 @@ class EngineArgs:
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
default_mm_loras: Optional[Dict[str, str]] = \
LoRAConfig.default_mm_loras
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
......@@ -428,7 +395,6 @@ class EngineArgs:
speculative_config: Optional[Dict[str, Any]] = None
num_speculative_heads: Optional[int] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
otlp_traces_endpoint: Optional[str] = \
......@@ -462,7 +428,6 @@ class EngineArgs:
additional_config: dict[str, Any] = \
get_field(VllmConfig, "additional_config")
enable_reasoning: Optional[bool] = None # DEPRECATED
reasoning_parser: str = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
......@@ -471,6 +436,10 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
......@@ -478,13 +447,6 @@ class EngineArgs:
if isinstance(self.compilation_config, (int, dict)):
self.compilation_config = CompilationConfig.from_cli(
str(self.compilation_config))
if self.qlora_adapter_name_or_path is not None:
warnings.warn(
"The `qlora_adapter_name_or_path` is deprecated "
"and will be removed in v0.10.0. ",
DeprecationWarning,
stacklevel=2,
)
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
......@@ -531,6 +493,8 @@ class EngineArgs:
**model_kwargs["max_seq_len_to_capture"])
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode",
**model_kwargs["logprobs_mode"])
model_group.add_argument("--disable-sliding-window",
**model_kwargs["disable_sliding_window"])
model_group.add_argument("--disable-cascade-attn",
......@@ -597,14 +561,6 @@ class EngineArgs:
**load_kwargs["ignore_patterns"])
load_group.add_argument("--use-tqdm-on-load",
**load_kwargs["use_tqdm_on_load"])
load_group.add_argument(
"--qlora-adapter-name-or-path",
type=str,
default=None,
help="The `--qlora-adapter-name-or-path` has no effect, do not set"
" it, and it will be removed in v0.10.0.",
deprecated=True,
)
load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"])
......@@ -625,15 +581,6 @@ class EngineArgs:
guided_decoding_group.add_argument(
"--guided-decoding-disable-additional-properties",
**guided_decoding_kwargs["disable_additional_properties"])
guided_decoding_group.add_argument(
"--enable-reasoning",
action=argparse.BooleanOptionalAction,
deprecated=True,
help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
"of v0.9.0. Use `--reasoning-parser` to specify the reasoning "
"parser backend instead. This flag (`--enable-reasoning`) will be "
"removed in v0.10.0. When `--reasoning-parser` is specified, "
"reasoning mode is automatically enabled.")
guided_decoding_group.add_argument(
"--reasoning-parser",
# `choices` is a special case here because it's not static
......@@ -662,6 +609,11 @@ class EngineArgs:
type=int,
help='Data parallel rank of this instance. '
'When set, enables external load balancer mode.')
parallel_group.add_argument('--data-parallel-start-rank',
'-dpr',
type=int,
help='Starting data parallel rank '
'for secondary nodes.')
parallel_group.add_argument('--data-parallel-size-local',
'-dpl',
type=int,
......@@ -683,6 +635,9 @@ class EngineArgs:
default='mp',
help='Backend for data parallel, either '
'"mp" or "ray".')
parallel_group.add_argument(
"--data-parallel-hybrid-lb",
**parallel_kwargs["data_parallel_hybrid_lb"])
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
......@@ -736,19 +691,6 @@ class EngineArgs:
cache_group.add_argument("--calculate-kv-scales",
**cache_kwargs["calculate_kv_scales"])
# Tokenizer arguments
tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
tokenizer_group = parser.add_argument_group(
title="TokenizerPoolConfig",
description=TokenizerPoolConfig.__doc__,
)
tokenizer_group.add_argument("--tokenizer-pool-size",
**tokenizer_kwargs["pool_size"])
tokenizer_group.add_argument("--tokenizer-pool-type",
**tokenizer_kwargs["pool_type"])
tokenizer_group.add_argument("--tokenizer-pool-extra-config",
**tokenizer_kwargs["extra_config"])
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
multimodal_group = parser.add_argument_group(
......@@ -765,6 +707,9 @@ class EngineArgs:
multimodal_group.add_argument(
"--disable-mm-preprocessor-cache",
**multimodal_kwargs["disable_mm_preprocessor_cache"])
multimodal_group.add_argument(
"--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"])
# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
......@@ -789,39 +734,12 @@ class EngineArgs:
"--lora-dtype",
**lora_kwargs["lora_dtype"],
)
lora_group.add_argument("--long-lora-scaling-factors",
**lora_kwargs["long_lora_scaling_factors"])
lora_group.add_argument("--max-cpu-loras",
**lora_kwargs["max_cpu_loras"])
lora_group.add_argument("--fully-sharded-loras",
**lora_kwargs["fully_sharded_loras"])
# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
"--enable-prompt-adapter",
action=argparse.BooleanOptionalAction,
help="If True, enable handling of PromptAdapters.")
prompt_adapter_group.add_argument(
"--max-prompt-adapters",
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
"--max-prompt-adapter-token",
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Device arguments
device_kwargs = get_kwargs(DeviceConfig)
device_group = parser.add_argument_group(
title="DeviceConfig",
description=DeviceConfig.__doc__,
)
device_group.add_argument("--device",
**device_kwargs["device"],
deprecated=True)
lora_group.add_argument("--default-mm-loras",
**lora_kwargs["default_mm_loras"])
# Speculative arguments
speculative_group = parser.add_argument_group(
......@@ -911,6 +829,8 @@ class EngineArgs:
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
scheduler_group.add_argument("--async-scheduling",
**scheduler_kwargs["async_scheduling"])
# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
......@@ -928,18 +848,15 @@ class EngineArgs:
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=True,
deprecated=True,
help='[DEPRECATED] block manager v1 has been '
'removed and SelfAttnBlockSpaceManager (i.e. '
'block manager v2) is now the default. '
'Setting this flag to True or False'
' has no effect on vLLM behavior.')
parser.add_argument('--disable-log-stats',
action='store_true',
help='Disable logging statistics.')
parser.add_argument('--enable-prompt-adapter',
action='store_true',
deprecated=True,
help='[DEPRECATED] Prompt adapter has been '
'removed. Setting this flag to True or False'
' has no effect on vLLM behavior.')
return parser
......@@ -985,12 +902,14 @@ class EngineArgs:
enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
logprobs_mode=self.logprobs_mode,
disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
......@@ -1007,14 +926,33 @@ class EngineArgs:
enable_chunked_prefill=self.enable_chunked_prefill
)
def validate_tensorizer_args(self):
    """Mirror top-level tensorizer options into the nested config.

    Every key of ``self.model_loader_extra_config`` that also appears in
    ``TensorizerConfig._fields`` is copied into
    ``self.model_loader_extra_config["tensorizer_config"]``. The
    top-level entries are left in place.
    """
    from vllm.model_executor.model_loader.tensorizer import (
        TensorizerConfig)
    extra_config = self.model_loader_extra_config
    for option in extra_config:
        if option not in TensorizerConfig._fields:
            continue
        # The nested dict is looked up only for matching keys, so a
        # missing "tensorizer_config" entry matters only in that case.
        extra_config["tensorizer_config"][option] = extra_config[option]
def create_load_config(self) -> LoadConfig:
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
if self.load_format == "tensorizer":
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
self.model_loader_extra_config["tensorizer_config"] = {}
self.model_loader_extra_config["tensorizer_config"][
"tensorizer_dir"] = self.model
self.validate_tensorizer_args()
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
device="cpu"
if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
......@@ -1056,6 +994,7 @@ class EngineArgs:
def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
......@@ -1070,7 +1009,6 @@ class EngineArgs:
If VLLM_USE_V1 is specified by the user but the VllmConfig
is incompatible, we raise an error.
"""
from vllm.platforms import current_platform
current_platform.pre_register_and_update()
device_config = DeviceConfig(
......@@ -1097,9 +1035,16 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context, model_config)
# Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
if current_platform.is_cpu(
) and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC, CpuArchEnum.ARM):
logger.info(
"Chunked prefill is not supported for ARM and POWER CPUs; "
"disabling it for V1 backend.")
self.enable_chunked_prefill = False
else:
self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
......@@ -1138,15 +1083,41 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
assert not headless or not self.data_parallel_hybrid_lb, (
"data_parallel_hybrid_lb is not applicable in "
"headless mode")
data_parallel_external_lb = self.data_parallel_rank is not None
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank "
"is set")
data_parallel_size_local = 1
# Use full external lb if we have local_size of 1.
self.data_parallel_hybrid_lb = False
elif self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
if self.data_parallel_start_rank and not headless:
# Infer hybrid LB mode.
self.data_parallel_hybrid_lb = True
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
# Use full external lb if we have local_size of 1.
data_parallel_external_lb = True
self.data_parallel_hybrid_lb = False
if data_parallel_size_local == self.data_parallel_size:
# Disable hybrid LB mode if set for a single node
self.data_parallel_hybrid_lb = False
self.data_parallel_rank = self.data_parallel_start_rank or 0
else:
assert not self.data_parallel_hybrid_lb, (
"data_parallel_size_local must be set to use "
"data_parallel_hybrid_lb.")
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size
......@@ -1173,6 +1144,26 @@ class EngineArgs:
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port
if self.async_scheduling:
# Async scheduling does not work with the uniprocess backend.
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "mp"
logger.info("Using mp-based distributed executor backend "
"for async scheduling.")
if self.distributed_executor_backend == "uni":
raise ValueError("Async scheduling is not supported with "
"uni-process backend.")
if self.pipeline_parallel_size > 1:
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")
# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if self.speculative_config is not None:
raise ValueError(
"Currently, speculative decoding is not supported with "
"async scheduling.")
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
......@@ -1183,6 +1174,7 @@ class EngineArgs:
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts,
......@@ -1216,7 +1208,6 @@ class EngineArgs:
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
raise ValueError("Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1")
from vllm.platforms import current_platform
if current_platform.is_cpu():
logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
"currently not supported for CPUs and has been "
......@@ -1254,15 +1245,21 @@ class EngineArgs:
long_prefill_token_threshold=self.long_prefill_token_threshold,
disable_hybrid_kv_cache_manager=self.
disable_hybrid_kv_cache_manager,
async_scheduling=self.async_scheduling,
)
if not model_config.is_multimodal_model and self.default_mm_loras:
raise ValueError(
"Default modality-specific LoRA(s) were provided for a "
"non multimodal model")
lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None,
......@@ -1274,11 +1271,6 @@ class EngineArgs:
load_config = self.create_load_config()
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback,
......@@ -1289,8 +1281,8 @@ class EngineArgs:
)
observability_config = ObservabilityConfig(
show_hidden_metrics_for_version=self.
show_hidden_metrics_for_version,
show_hidden_metrics_for_version=(
self.show_hidden_metrics_for_version),
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
)
......@@ -1306,7 +1298,6 @@ class EngineArgs:
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
......@@ -1366,7 +1357,6 @@ class EngineArgs:
# Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available
# (e.g. in a Ray actor without GPUs).
from vllm.platforms import current_platform
if (current_platform.is_cuda()
and current_platform.get_device_capability()
and current_platform.get_device_capability().major < 8):
......@@ -1376,34 +1366,16 @@ class EngineArgs:
# No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto":
fp8_attention = self.kv_cache_dtype.startswith("fp8")
will_use_fa = (
current_platform.is_cuda()
and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if current_platform.is_rocm():
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
supported = current_platform.is_kv_cache_dtype_supported(
self.kv_cache_dtype)
int8_attention = self.kv_cache_dtype.startswith("int8")
if int8_attention:
supported = True
if not supported:
_raise_or_fallback(feature_name="--kv-cache-dtype",
recommend_to_remove=False)
return False
# No Prompt Adapter so far.
if self.enable_prompt_adapter:
_raise_or_fallback(feature_name="--enable-prompt-adapter",
recommend_to_remove=False)
return False
# No text embedding inputs so far.
if self.enable_prompt_embeds:
_raise_or_fallback(feature_name="--enable-prompt-embeds",
......@@ -1437,28 +1409,12 @@ class EngineArgs:
return False
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
is_ngram_enabled = False
is_eagle_enabled = False
is_medusa_enabled = False
if self.speculative_config is not None:
# This is supported but experimental (handled below).
speculative_method = self.speculative_config.get("method")
if speculative_method:
if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True
elif speculative_method == "medusa":
is_medusa_enabled = True
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
if speculative_model in ("ngram", "[ngram]"):
is_ngram_enabled = True
if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
# Other speculative decoding methods are not supported yet.
_raise_or_fallback(feature_name="Speculative Decoding",
recommend_to_remove=False)
return False
if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"):
raise NotImplementedError(
"Speculative decoding with draft model is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")
# No XFormers so far.
V1_BACKENDS = [
......@@ -1534,7 +1490,6 @@ class EngineArgs:
# Enable chunked prefill by default for long context (> 32K)
# models to avoid OOM errors in initial memory profiling phase.
elif use_long_context:
from vllm.platforms import current_platform
is_gpu = current_platform.is_cuda()
use_sliding_window = (model_config.get_sliding_window()
is not None)
......@@ -1542,7 +1497,6 @@ class EngineArgs:
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and model_config.runner_type != "pooling"):
self.enable_chunked_prefill = True
logger.warning(
......@@ -1636,7 +1590,6 @@ class EngineArgs:
# as the platform that vLLM is running on (e.g. the case of scaling
# vLLM with Ray) and has no GPUs. In this case we use the default
# values for non-H100/H200 GPUs.
from vllm.platforms import current_platform
try:
device_memory = current_platform.get_device_total_memory()
device_name = current_platform.get_device_name().lower()
......@@ -1647,6 +1600,7 @@ class EngineArgs:
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression.
from vllm.usage.usage_lib import UsageContext
if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = {
......@@ -1685,13 +1639,14 @@ class EngineArgs:
# cpu specific default values.
if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096,
UsageContext.OPENAI_API_SERVER: 2048,
UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
}
default_max_num_seqs = {
UsageContext.LLM_CLASS: 128,
UsageContext.OPENAI_API_SERVER: 32,
UsageContext.LLM_CLASS: 256 * world_size,
UsageContext.OPENAI_API_SERVER: 128 * world_size,
}
use_context_value = usage_context.value if usage_context else None
......@@ -1739,7 +1694,6 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
from vllm.platforms import current_platform
current_platform.pre_register_and_update(parser)
return parser
......
......@@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -435,9 +434,9 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Async version of
......@@ -467,7 +466,7 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs,
)
if isinstance(params, SamplingParams) and \
......@@ -489,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
......@@ -859,9 +857,9 @@ class AsyncLLMEngine(EngineClient):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
......@@ -886,9 +884,9 @@ class AsyncLLMEngine(EngineClient):
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
)
return stream.generator()
......@@ -900,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
......@@ -918,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
......@@ -979,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
):
......@@ -996,6 +990,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
......@@ -1070,6 +1065,7 @@ class AsyncLLMEngine(EngineClient):
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:
......
......@@ -45,7 +45,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
......@@ -227,7 +226,6 @@ class LLMEngine:
self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
......@@ -242,18 +240,18 @@ class LLMEngine:
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init and self.model_config.tokenizer_mode != "cpm":
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
elif self.model_config.tokenizer_mode == "cpm":
self.tokenizer = CPM9GTokenizer(self.model_config.model, trust_remote_code=True)
self.detokenizer = Detokenizer(self.tokenizer, self.model_config.tokenizer_mode)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
......@@ -302,8 +300,6 @@ class LLMEngine:
# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
......@@ -556,9 +552,6 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _add_processed_request(
self,
......@@ -567,7 +560,6 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
......@@ -583,7 +575,6 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
......@@ -601,11 +592,10 @@ class LLMEngine:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
lora_request)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))
seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
......@@ -616,7 +606,6 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
......@@ -626,7 +615,6 @@ class LLMEngine:
params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
else:
......@@ -655,7 +643,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Add a request to the engine's request pool.
......@@ -676,7 +663,6 @@ class LLMEngine:
the current monotonic time.
lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request.
Only applicable with priority scheduling.
......@@ -741,7 +727,6 @@ class LLMEngine:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
......@@ -750,7 +735,6 @@ class LLMEngine:
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
......@@ -763,7 +747,6 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
......@@ -791,17 +774,15 @@ class LLMEngine:
if self.vllm_config.speculative_config is not None:
draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority,
draft_size=draft_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
encoder_seq=encoder_seq,
priority=priority,
draft_size=draft_size)
return seq_group
......@@ -812,7 +793,6 @@ class LLMEngine:
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
) -> SequenceGroup:
......@@ -820,15 +800,13 @@ class LLMEngine:
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
encoder_seq=encoder_seq,
priority=priority)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
......@@ -1816,13 +1794,6 @@ class LLMEngine:
num_generation_tokens_from_prefill_groups)
num_tokens_iter = (num_generation_tokens_iter +
num_prompt_tokens_iter)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and isinstance(model_output[0], SamplerOutput) and (
model_output[0].spec_decode_worker_metrics is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
return Stats(
now=now,
......@@ -1844,7 +1815,6 @@ class LLMEngine:
num_tokens_iter=num_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
num_preemption_iter=num_preemption_iter,
# Request stats
......@@ -1878,16 +1848,6 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None:
self.model_executor.start_profile()
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Type, Union, cast
......@@ -19,9 +18,6 @@ if ray is not None:
else:
ray_metrics = None
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__)
prometheus_client.disable_created_metrics()
......@@ -199,30 +195,6 @@ class Metrics:
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculative decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames,
multiprocess_mode="sum")
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
# --8<-- [end:metrics-definitions]
......@@ -391,9 +363,6 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
......@@ -435,10 +404,6 @@ class LoggingStatLogger(StatLoggerBase):
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None:
log_fn(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))
self._reset(stats, prompt_throughput, generation_throughput)
......@@ -447,21 +412,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None
self.last_prompt_throughput = prompt_throughput
self.last_generation_throughput = generation_throughput
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
......@@ -579,33 +532,14 @@ class PrometheusStatLogger(StatLoggerBase):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
if self.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
self.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
self.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
self.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
self.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
self.spec_decode_metrics.emitted_tokens)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
# Info type metrics are syntactic sugar for a gauge permanently set to 1
......
......@@ -16,10 +16,9 @@ do this in Python code and lazily import prometheus_client.
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from typing import List
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
......@@ -65,8 +64,6 @@ class Stats:
running_lora_adapters: List[str]
max_lora: str
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class StatLoggerBase(ABC):
"""Base class for StatLogger."""
......@@ -77,7 +74,6 @@ class StatLoggerBase(ABC):
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
@abstractmethod
def log(self, stats: Stats) -> None:
......@@ -86,9 +82,3 @@ class StatLoggerBase(ABC):
@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
raise NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics
......@@ -10,7 +10,6 @@ from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import Device
......@@ -33,7 +32,6 @@ class RPCProcessRequest:
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
def __init__(
......@@ -43,7 +41,6 @@ class RPCProcessRequest:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
super().__init__()
......@@ -53,7 +50,6 @@ class RPCProcessRequest:
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
......
......@@ -45,7 +45,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
......@@ -453,7 +452,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
......@@ -470,8 +468,6 @@ class MQLLMEngineClient(EngineClient):
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
......@@ -479,8 +475,7 @@ class MQLLMEngineClient(EngineClient):
return cast(
AsyncGenerator[RequestOutput, None],
self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request, priority))
lora_request, trace_headers, priority))
def encode(
self,
......@@ -526,7 +521,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
......@@ -580,7 +574,6 @@ class MQLLMEngineClient(EngineClient):
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
......
......@@ -322,14 +322,12 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
self.engine.add_request(request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
priority=request.priority)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
......
......@@ -104,11 +104,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED)
for output in outputs:
if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
sequence_group.metrics.spec_token_acceptance_counts[
output.step_index] += 1
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
......
......@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
......@@ -55,7 +54,6 @@ class EngineClient(ABC):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
......@@ -324,3 +322,9 @@ class EngineClient(ABC):
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
...
async def scale_elastic_ep(self,
new_data_parallel_size: int,
drain_timeout: int = 300) -> None:
"""Scale the engine to a new data-parallel size.

This base implementation does no work and always raises
NotImplementedError.

Args:
new_data_parallel_size: The requested data-parallel size
(presumably the number of data-parallel ranks after
scaling -- confirm against implementations).
drain_timeout: Drain timeout, defaulting to 300. The unit is not
shown here (presumably seconds -- confirm).

Raises:
NotImplementedError: Always, in this implementation.
"""
raise NotImplementedError
......@@ -4,7 +4,7 @@
import asyncio
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
......@@ -28,6 +28,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
from openai.types.responses import ResponseInputImageParam
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
# yapf: enable
......@@ -38,7 +39,6 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
......@@ -52,6 +52,12 @@ from vllm.utils import deprecate_kwargs, random_uuid
logger = init_logger(__name__)
# Generic, model-independent placeholder string for each modality name.
# BaseMultiModalContentParser._add_placeholder looks a modality up here and
# uses the resulting string as the key under which the model-specific
# placeholders for that modality are collected.
MODALITY_PLACEHOLDERS_MAP = {
"image": "<##IMAGE##>",
"audio": "<##AUDIO##>",
"video": "<##VIDEO##>",
}
class AudioURL(TypedDict, total=False):
url: Required[str]
......@@ -145,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
video_url: Required[str]
class CustomThinkCompletionContentParam(TypedDict, total=False):
"""A Think Completion Content Param that accepts a plain text and a boolean.

The TypedDict is declared with total=False, so only the keys marked
Required (`thinking` and `type`) must be present; `closed` may be omitted.

Example:
{
"thinking": "I am thinking about the answer",
"closed": True,
"type": "thinking"
}
"""
thinking: Required[str]
"""The thinking content."""
closed: bool
"""Whether the thinking is closed."""
type: Required[Literal["thinking"]]
"""The thinking type."""
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
......@@ -153,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]
CustomChatCompletionContentSimpleVideoParam, str,
CustomThinkCompletionContentParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
......@@ -354,6 +382,7 @@ def resolve_mistral_chat_template(
"so it will be ignored.")
return None
@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
......@@ -517,6 +546,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@cached_property
def model_cls(self):
"""Return the model class that get_model_cls resolves for self.model_config.

Declared as a cached_property, so the lookup runs at most once per
instance and the result is reused afterwards.
"""
# Imported inside the method instead of at module level
# (presumably to defer loading the model loader -- confirm).
from vllm.model_executor.model_loader import get_model_cls
return get_model_cls(self.model_config)
@property
......@@ -633,15 +663,22 @@ class BaseMultiModalContentParser(ABC):
def __init__(self) -> None:
super().__init__()
# multimodal placeholder_string : count
self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
def _add_placeholder(self, placeholder: Optional[str]):
# stores the list of model placeholders for each corresponding
# general MM placeholder:
# {
# "<##IMAGE##>": ["<image>", "<image>", "<image>"],
# "<##AUDIO##>": ["<audio>", "<audio>"]
# }
self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(self, modality: ModalityStr,
placeholder: Optional[str]):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder:
self._placeholder_counts[placeholder] += 1
self._placeholder_storage[mod_placeholder].append(placeholder)
def mm_placeholder_counts(self) -> dict[str, int]:
return dict(self._placeholder_counts)
def mm_placeholder_storage(self) -> dict[str, list]:
return dict(self._placeholder_storage)
@abstractmethod
def parse_image(self, image_url: str) -> None:
......@@ -685,7 +722,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
......@@ -700,17 +737,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None:
placeholder = self._tracker.add("image", image_pil)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
......@@ -723,7 +760,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
video = self._connector.fetch_video(video_url=video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
self._add_placeholder("video", placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
......@@ -741,7 +778,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
......@@ -760,20 +797,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None:
future: asyncio.Future[Image.Image] = asyncio.Future()
future.set_result(image_pil)
placeholder = self._tracker.add("image", future)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
......@@ -786,7 +823,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
video = self._connector.fetch_video_async(video_url=video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
self._add_placeholder("video", placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
......@@ -856,12 +893,40 @@ def load_chat_template(
return _cached_load_chat_template(chat_template, is_literal=is_literal)
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
texts: list[str]) -> str:
for idx, elem in enumerate(texts):
if elem in placeholder_storage:
texts[idx] = placeholder_storage[elem].pop(0)
return "\n".join(texts)
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
text_prompt: str) -> str:
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
texts: list[str],
interleave_strings: bool
) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it look like
# {
# "<|image|>": 2,
# "<|audio|>": 1
# }
placeholder_counts = Counter(
[v for elem in placeholder_storage.values() for v in elem]
)
if interleave_strings:
text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
else:
text_prompt = "\n".join(texts)
# Pass the interleaved text on to the check below in case the user placed
# image placeholders manually but forgot to disable the 'interleave_strings'
# flag
# Look through the text prompt to check for missing placeholders
missing_placeholders: list[str] = []
for placeholder in placeholder_counts:
......@@ -870,6 +935,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
if placeholder_counts[placeholder] < 0:
logger.error(
"Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) "
"when manually placing image placeholders.", interleave_strings
)
logger.debug("Input prompt: %s", text_prompt)
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
......@@ -877,8 +949,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
# NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False
return "\n".join(missing_placeholders + [text_prompt])
......@@ -888,11 +960,14 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
# Parsers built with ``partial(cast, T)`` only relabel a content part as ``T``;
# unlike the ``TypeAdapter`` parsers below, they do not call a validator.
# NOTE(review): assumes ``cast`` is ``typing.cast`` (a runtime no-op) - confirm.
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
# Need to validate url objects
# These parsers run ``TypeAdapter(...).validate_python`` on the part instead
# of a plain cast.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter(
ResponseInputImageParam).validate_python
# Type alias for the value shapes a content part may take: a string, a
# string-to-string dict, an ``InputAudio`` or a ``PILImage``.
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
# Define a mapping from part types to their corresponding parsing functions.
......@@ -902,6 +977,12 @@ MM_PARSER_MAP: dict[
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"thinking":
lambda part: _ThinkParser(part).get("thinking", None),
"input_text":
lambda part: _TextParser(part).get("text", None),
"input_image":
lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
......@@ -986,6 +1067,7 @@ def _parse_chat_message_content_parts(
mm_tracker: BaseMultiModalItemTracker,
*,
wrap_dicts: bool,
interleave_strings: bool,
) -> list[ConversationMessage]:
content = list[_ContentPart]()
......@@ -996,6 +1078,7 @@ def _parse_chat_message_content_parts(
part,
mm_parser,
wrap_dicts=wrap_dicts,
interleave_strings=interleave_strings
)
if parse_res:
content.append(parse_res)
......@@ -1005,11 +1088,14 @@ def _parse_chat_message_content_parts(
return [ConversationMessage(role=role,
content=content)] # type: ignore
texts = cast(list[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
mm_placeholder_storage = mm_parser.mm_placeholder_storage()
if mm_placeholder_storage:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage,
texts,
interleave_strings)
else:
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
......@@ -1018,6 +1104,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
interleave_strings: bool,
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
......@@ -1028,10 +1115,8 @@ def _parse_chat_message_content_part(
"""
if isinstance(part, str): # Handle plain text parts
return part
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
......@@ -1040,41 +1125,44 @@ def _parse_chat_message_content_part(
"with empty / unparsable content.", part, part_type)
return None
if part_type in ("text", "refusal"):
if part_type in ("text", "input_text", "refusal", "thinking"):
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
else:
return str_content
modality = None
if part_type == "image_pil":
image_content = cast(Image.Image, content)
mm_parser.parse_image_pil(image_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_url":
modality = "image"
elif part_type in ("image_url", "input_image"):
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_embeds":
modality = "image"
elif part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
modality = "image"
elif part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio":
modality = "audio"
elif part_type == "input_audio":
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None
if part_type == "video_url":
modality = "audio"
elif part_type == "video_url":
str_content = cast(str, content)
mm_parser.parse_video(str_content)
return {'type': 'video'} if wrap_dicts else None
modality = "video"
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
raise NotImplementedError(f"Unknown part type: {part_type}")
return {'type': modality} if wrap_dicts else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
)
# No need to validate using Pydantic again
......@@ -1086,6 +1174,7 @@ def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
interleave_strings: bool,
) -> list[ConversationMessage]:
role = message["role"]
content = message.get("content")
......@@ -1101,6 +1190,7 @@ def _parse_chat_message_content(
content, # type: ignore
mm_tracker,
wrap_dicts=(content_format == "openai"),
interleave_strings=interleave_strings,
)
for result_msg in result:
......@@ -1153,6 +1243,11 @@ def parse_chat_messages(
msg,
mm_tracker,
content_format,
interleave_strings=(
content_format == "string"
and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings
)
)
conversation.extend(sub_messages)
......@@ -1176,6 +1271,11 @@ def parse_chat_messages_futures(
msg,
mm_tracker,
content_format,
interleave_strings=(
content_format == "string"
and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings
)
)
conversation.extend(sub_messages)
......
......@@ -7,17 +7,6 @@ to avoid certain eager import breakage.'''
from __future__ import annotations
import importlib.metadata
import signal
import sys
def register_signal_handlers():
def signal_handler(sig, frame):
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTSTP, signal_handler)
def main():
......
......@@ -55,7 +55,7 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
try:
input_message = input("> ")
except EOFError:
return
break
conversation.append({"role": "user", "content": input_message})
chat_completion = client.chat.completions.create(model=model_name,
......@@ -118,7 +118,7 @@ class ChatCommand(CLISubcommand):
try:
input_message = input("> ")
except EOFError:
return
break
conversation.append({"role": "user", "content": input_message})
chat_completion = client.chat.completions.create(
......@@ -170,7 +170,10 @@ class CompleteCommand(CLISubcommand):
print("Please enter prompt to complete:")
while True:
input_prompt = input("> ")
try:
input_prompt = input("> ")
except EOFError:
break
completion = client.completions.create(model=model_name,
prompt=input_prompt)
output = completion.choices[0].text
......
......@@ -45,9 +45,6 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.data_parallel_start_rank:
raise ValueError("data_parallel_start_rank is only "
"applicable in headless mode")
if args.api_server_count > 1:
run_multi_api_server(args)
else:
......@@ -65,36 +62,6 @@ class ServeSubcommand(CLISubcommand):
help="Start the vLLM OpenAI Compatible API server.",
description="Start the vLLM OpenAI Compatible API server.",
usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag",
type=str,
nargs='?',
help="The model tag to serve "
"(optional if specified in config)")
serve_parser.add_argument(
"--headless",
action='store_true',
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
serve_parser.add_argument(
'--data-parallel-start-rank',
'-dpr',
type=int,
default=0,
help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument('--api-server-count',
'-asc',
type=int,
default=1,
help='How many API server processes to run.')
serve_parser.add_argument(
"--config",
type=str,
default='',
required=False,
help="Read CLI options from a config file. "
"Must be a YAML with the following options: "
"https://docs.vllm.ai/en/latest/configuration/serve_args.html")
serve_parser = make_arg_parser(serve_parser)
show_filtered_argument_or_group_from_help(serve_parser, ["serve"])
......@@ -114,13 +81,14 @@ def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
headless=True)
if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")
if engine_args.data_parallel_rank is not None:
raise ValueError("data_parallel_rank is not applicable in "
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in "
"headless mode")
parallel_config = vllm_config.parallel_config
......@@ -150,7 +118,7 @@ def run_headless(args: argparse.Namespace):
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=args.data_parallel_start_rank,
start_index=vllm_config.parallel_config.data_parallel_rank,
local_start_index=0,
vllm_config=vllm_config,
local_client=False,
......@@ -197,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True
if vllm_config.parallel_config.data_parallel_hybrid_lb:
    # Hybrid load balancing cannot be combined with several API server
    # processes yet, so fail early with an explicit message.
    # The literals are joined with a separating space; previously they
    # rendered as "... > 0is not yet supported.". The threshold text
    # matches the condition under which this code path is used
    # (api_server_count > 1).
    raise NotImplementedError(
        "Hybrid load balancing with --api-server-count > 1 "
        "is not yet supported.")
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment