Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
......@@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.utils import GiB_bytes, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
......@@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
......@@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"]
length = mm_positions[curr_mm_idx]["length"]
offset = mm_positions[curr_mm_idx].offset
length = mm_positions[curr_mm_idx].length
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
......@@ -460,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return ret
def estimate_max_model_len(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> int:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The estimated maximum model length that can fit in the available memory.
"""
# Define a function to check if a given model length fits in memory
def fits_in_memory(model_len: int) -> bool:
# Modify the max_model_len for this calculation
vllm_config.model_config.max_model_len = model_len
# Calculate memory needed for the given model length
memory_needed = sum(
(layer_spec.max_memory_usage_bytes(vllm_config)
for layer_spec in kv_cache_spec.values()),
start=0,
)
return memory_needed <= available_memory
# Binary search for the maximum model length
current_max = vllm_config.model_config.max_model_len
left, right = 1, current_max
# If even the smallest model length doesn't fit, return 0
if not fits_in_memory(left):
return 0
# Binary search for the maximum model length that fits
result = 1
while left <= right:
mid = (left + right) // 2
if fits_in_memory(mid):
result = mid
left = mid + 1
else:
right = mid - 1
return result
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
......@@ -487,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
if needed_memory > available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
available_memory)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = " Based on the available memory,"
f" the estimated maximum model length is {estimated_max_len}."
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"memory ({available_memory/GiB_bytes:.2f} GiB)."
f"{estimated_msg} "
f" Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")
......
......@@ -7,7 +7,8 @@ from collections import deque
from collections.abc import Iterable
from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
......@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
lora_config: Optional[LoRAConfig],
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
speculative_config: SpeculativeConfig = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
......@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)
self.num_lookahead_tokens = 0
if speculative_config and speculative_config.method == "eagle":
self.num_lookahead_tokens = \
speculative_config.num_speculative_tokens
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
......@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens)
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
......@@ -505,8 +514,8 @@ class Scheduler(SchedulerInterface):
assert mm_positions is not None
assert len(mm_positions) > 0
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
......@@ -522,6 +531,17 @@ class Scheduler(SchedulerInterface):
if self.encoder_cache_manager.has_cache(request, i):
# The encoder input is already computed and cached.
continue
# If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would
# only cover part of the mm input, roll back to before the mm item.
if (self.scheduler_config.disable_chunked_mm_input
and num_computed_tokens < start_pos
and (num_computed_tokens + num_new_tokens)
< (start_pos + num_encoder_tokens)):
num_new_tokens = start_pos - num_computed_tokens
break
if (not self.encoder_cache_manager.can_allocate(request, i)
or num_encoder_tokens > encoder_budget):
# The encoder cache is full or the encoder budget is exhausted.
......@@ -596,8 +616,8 @@ class Scheduler(SchedulerInterface):
if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions["offset"]
num_tokens = mm_positions["length"]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
......
......@@ -2,6 +2,7 @@
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
......@@ -52,7 +53,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: list[int]
mm_inputs: Optional[list[MultiModalKwargs]]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
......
......@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
......@@ -98,6 +98,7 @@ class EngineCore:
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
kv_cache_config=kv_cache_config,
speculative_config=vllm_config.speculative_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
......@@ -105,7 +106,7 @@ class EngineCore:
)
# Setup MM Input Mapper.
self.mm_input_cache_server = MMInputCacheServer(
self.mm_input_cache_server = MirroredProcessingCache(
vllm_config.model_config)
# Setup batch queue for pipeline parallelism.
......@@ -173,7 +174,7 @@ class EngineCore:
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
......@@ -318,6 +319,11 @@ class EngineCoreProc(EngineCore):
):
super().__init__(vllm_config, executor_class, log_stats)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.global_unfinished_reqs = False
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
......@@ -327,22 +333,16 @@ class EngineCoreProc(EngineCore):
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
args=(input_path, engine_index),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True).start()
self.global_unfinished_reqs = False
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod
def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process."""
......@@ -377,9 +377,6 @@ class EngineCoreProc(EngineCore):
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop()
except SystemExit:
......@@ -476,24 +473,32 @@ class EngineCoreProc(EngineCore):
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def process_input_socket(self, input_path: str):
def process_input_socket(self, input_path: str, engine_index: int):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
identity = engine_index.to_bytes(length=2, byteorder="little")
with zmq_socket_ctx(input_path,
zmq.DEALER,
identity=identity,
bind=False) as socket:
# Send ready message to front-end once input socket is connected.
socket.send(b'READY')
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
......@@ -510,8 +515,8 @@ class EngineCoreProc(EngineCore):
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send(buffer, copy=False)
buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
......@@ -619,4 +624,4 @@ class DPEngineCoreProc(EngineCoreProc):
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
local_unfinished)
\ No newline at end of file
......@@ -8,7 +8,7 @@ import threading
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from collections.abc import Awaitable
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
......@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__)
......@@ -35,6 +35,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
STARTUP_POLL_PERIOD_MS = 10000
class EngineCoreClient(ABC):
"""
......@@ -261,15 +263,13 @@ class CoreEngine:
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
ctx: Union[zmq.Context, zmq.asyncio.Context],
input_path: str,
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
self.index = index
self.identity = index.to_bytes(length=2, byteorder="little")
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
......@@ -291,14 +291,9 @@ class CoreEngine:
# Ensure socket is closed if process fails to start.
self.close()
def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)
@dataclass
......@@ -309,6 +304,7 @@ class BackgroundResources:
ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
shutdown_path: Optional[str] = None
def __call__(self):
......@@ -321,6 +317,8 @@ class BackgroundResources:
# aren't explicitly closed first.
if self.output_socket is not None:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
......@@ -387,21 +385,56 @@ class MPClient(EngineCoreClient):
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
index, local_dp_rank)
vllm_config, executor_class, log_stats, input_path, self.
output_path, index, local_dp_rank)
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
# Wait for engine core process(es) to start.
for engine in self.resources.core_engines:
engine.proc_handle.wait_for_startup()
self._wait_for_engine_startup()
self.utility_results: dict[int, AnyFuture] = {}
def _wait_for_engine_startup(self):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket = zmq.Socket.shadow(self.input_socket)
# Wait for engine core process(es) to send ready messages.
identities = set(eng.index for eng in self.resources.core_engines)
poller = zmq.Poller()
poller.register(sync_input_socket, zmq.POLLIN)
for eng in self.resources.core_engines:
poller.register(eng.proc_handle, zmq.POLLIN)
while identities:
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
logger.debug("Waiting for %d core engine proc(s) to start: %s",
len(identities), identities)
continue
if len(events) > 1 or events[0][0] != sync_input_socket:
# One of the core processes exited.
raise RuntimeError("Engine core initialization failed. "
"See root cause above.")
eng_id_bytes, msg = sync_input_socket.recv_multipart()
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
if eng_id not in identities:
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
if msg != b'READY':
raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
logger.info("Core engine process %d ready.", eng_id)
identities.discard(eng_id)
def _init_core_engines(
self,
vllm_config: VllmConfig,
......@@ -472,8 +505,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
frames = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
......@@ -494,9 +527,10 @@ class SyncMPClient(MPClient):
return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.core_engine.send_multipart(msg)
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
*self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
......@@ -599,8 +633,8 @@ class AsyncMPClient(MPClient):
async def process_outputs_socket():
while True:
(frame, ) = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
frames = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
......@@ -625,30 +659,34 @@ class AsyncMPClient(MPClient):
assert self.outputs_queue is not None
return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
def _send_input(self,
request_type: EngineCoreRequestType,
request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
if engine is None:
engine = self.core_engine
self._ensure_output_queue_task()
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine)
def _send_input_message(self, message: tuple[bytestr, ...],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message
return self.input_socket.send_multipart(message, copy=False)
async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
async def _call_utility_async(self, method: str, *args,
engine: CoreEngine) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args)))
await self._send_input_message(message, engine)
self._ensure_output_queue_task()
return await future
......@@ -657,6 +695,7 @@ class AsyncMPClient(MPClient):
# tokenized.
request.prompt = None
await self._send_input(EngineCoreRequestType.ADD, request)
self._ensure_output_queue_task()
async def abort_requests_async(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
......@@ -721,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
*self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
......@@ -755,21 +794,21 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
request.prompt = None
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await chosen_engine.send_multipart(msg)
await self._send_input_message(msg, chosen_engine)
else:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
engine.send_multipart(msg if engine is
chosen_engine else self.start_dp_msg)
for engine in self.core_engines
self._send_input_message(
msg if engine is chosen_engine else self.start_dp_msg,
engine) for engine in self.core_engines
])
self._ensure_output_queue_task()
......@@ -794,7 +833,7 @@ class DPAsyncMPClient(AsyncMPClient):
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
engine.send_multipart(self.start_dp_msg)
self._send_input_message(self.start_dp_msg, engine)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
......@@ -820,5 +859,5 @@ class DPAsyncMPClient(AsyncMPClient):
async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
self.encoder.encode(request_ids)))
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
......@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
......@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
# variable VLLM_MM_INPUT_CACHE_GIB.
class MMInputCacheServer:
class MirroredProcessingCache:
def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update(
def get_and_update_p0(
self,
mm_inputs: list[MultiModalKwargs],
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> list[MultiModalKwargs]:
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = []
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_hash in self.mm_cache:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def get_and_update_p1(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
else:
......
# SPDX-License-Identifier: Apache-2.0
import time
from collections.abc import Mapping
from typing import Optional, Union
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
......@@ -46,6 +48,8 @@ class Processor:
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
......@@ -73,6 +77,7 @@ class Processor:
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
......@@ -83,6 +88,26 @@ class Processor:
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,
......@@ -136,9 +161,6 @@ class Processor:
f" != {engine_level_backend}")
else:
params.guided_decoding.backend = engine_level_backend
import vllm.platforms
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
# Request content validation
if engine_level_backend.startswith("xgrammar"):
......@@ -181,6 +203,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
)
self._validate_lora(lora_request)
self._validate_params(params)
if priority != 0:
......@@ -228,7 +255,7 @@ class Processor:
self.tokenizer.get_lora_tokenizer(lora_request))
# Multimodal related.
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
......@@ -253,20 +280,28 @@ class Processor:
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
sorted_mm_inputs = []
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
]))
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
sorted_mm_inputs = [
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs
return EngineCoreRequest(
request_id=request_id,
prompt=decoder_inputs.get("prompt"),
......@@ -285,41 +320,64 @@ class Processor:
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
max_input_id = max(prompt_ids)
max_allowed = self.tokenizer.get_lora_tokenizer(
lora_request).max_token_id
if max_input_id > max_allowed:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
if len(prompt_ids) >= self.model_config.max_model_len:
raise ValueError(
f"Prompt length of {len(prompt_ids)} is longer than the "
f"maximum model length of {self.model_config.max_model_len}.")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
if self.model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
if len(prompt_ids) > max_prompt_len:
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). "
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper
if model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
......
......@@ -119,10 +119,9 @@ class MultiprocExecutor(Executor):
timeout=dequeue_timeout)
if status != WorkerProc.ResponseStatus.SUCCESS:
if isinstance(result, Exception):
raise result
else:
raise RuntimeError("Worker failed")
raise RuntimeError(
"Worker failed with error %s, please check the"
" stack trace above for the root cause", result)
responses[w.rank] = result
......@@ -327,7 +326,7 @@ class WorkerProc:
logger.debug("Worker interrupted.")
except Exception:
# worker_busy_loop sends exceptions exceptons to Executor
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil.Process().parent().send_signal(signal.SIGUSR1)
......@@ -378,9 +377,11 @@ class WorkerProc:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
e.add_note(traceback.format_exc())
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue
self.worker_response_mq.enqueue(
......
......@@ -239,7 +239,8 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
640.0, 2560.0
],
labelnames=labelnames).labels(*labelvalues)
......@@ -249,13 +250,13 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
0.75, 1.0, 2.5
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=labelnames).labels(*labelvalues)
request_latency_buckets = [
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
]
self.histogram_e2e_time_request = \
prometheus_client.Histogram(
......
......@@ -3,17 +3,16 @@
import enum
from typing import TYPE_CHECKING, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
class Request:
......@@ -23,9 +22,9 @@ class Request:
request_id: str,
prompt: Optional[str],
prompt_token_ids: list[int],
multi_modal_inputs: Optional[list["MultiModalKwargs"]],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list["PlaceholderRange"]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
......@@ -75,6 +74,11 @@ class Request:
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None:
assert isinstance(request.mm_inputs, list)
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
"mm_inputs was not updated in EngineCore.add_request")
return cls(
request_id=request.request_id,
prompt=request.prompt,
......@@ -121,7 +125,7 @@ class Request:
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"]
num_tokens = self.mm_positions[input_id].length
return num_tokens
@property
......
......@@ -230,9 +230,19 @@ class Sampler(nn.Module):
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
# Get vocabulary size from logits
vocab_size = logits.shape[-1]
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
# Check token_id bounds to ensure within vocabulary
if token_id < 0 or token_id >= vocab_size:
raise ValueError(
f"token_id {token_id} in logit_bias contains "
f"out-of-vocab token id. Vocabulary size: "
f"{vocab_size}")
logits[i, token_id] += bias
return logits
......
......@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.worker.gpu_input_batch import InputBatch
......@@ -24,19 +23,15 @@ class TPUSupportedSamplingMetadata:
# This class exposes a more xla-friendly interface than SamplingMetadata
# on TPU, in particular all arguments should be traceable and no optionals
# are allowed, to avoid graph recompilation on Nones.
temperature: torch.Tensor
temperature: torch.Tensor = None
min_p: torch.Tensor
min_p: torch.Tensor = None
# Still too slow on forward_native!
top_k: torch.Tensor = None
top_p: torch.Tensor = None
# Greedy sampling flag for compiling single xla graph.
all_greedy: torch.Tensor = None
# Generator not supported by xla
generators: dict[int,
torch.Generator] = field(default_factory=lambda: dict())
all_greedy: bool = True
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs = None
......@@ -57,64 +52,66 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask = None
bad_words_token_ids = None
indices_do_sample: torch.Tensor = None
# Generator not supported by xla
_generators: dict[int,
torch.Generator] = field(default_factory=lambda: dict())
@property
def generators(self) -> dict[int, torch.Generator]:
# Generator not supported by torch/xla. This field must be immutable.
return self._generators
@classmethod
def from_input_batch(
cls, input_batch: InputBatch,
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
cls,
input_batch: InputBatch,
padded_num_reqs: int,
xla_device: torch.device,
generate_params_if_all_greedy: bool = False
) -> "TPUSupportedSamplingMetadata":
"""
Copy sampling tensors slices from `input_batch` to on device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
Args:
input_batch: The input batch containing sampling parameters.
padded_num_reqs: The padded number of requests.
xla_device: The XLA device.
generate_params_if_all_greedy: If True, generate sampling parameters
even if all requests are greedy. this is useful for cases where
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
# Early return to avoid unnecessary cpu to tpu copy
if (input_batch.all_greedy is True
and generate_params_if_all_greedy is False):
return cls(all_greedy=True)
num_reqs = input_batch.num_reqs
padded_num_reqs = len(indices_do_sample)
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
fill_val) -> torch.Tensor:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
# Pad value is the default one.
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
# consistent. We can't have flags to skip copies or we'll end up
# recompiling.
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
fill_slice(input_batch.temperature_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["temperature"])
# TODO Temporarily disabled until sampling options are enabled
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
# fill_slice(input_batch.top_p_cpu_tensor)
# fill_slice(input_batch.top_k_cpu_tensor)
fill_slice(input_batch.min_p_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["min_p"])
xm.mark_step()
xm.wait_device_ops()
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return cls(
temperature=input_batch.temperature[:padded_num_reqs],
# Scalar tensor for xla-friendly tracing.
all_greedy=torch.tensor(input_batch.all_greedy,
dtype=torch.bool,
device=input_batch.device),
temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
to(xla_device),
all_greedy=input_batch.all_greedy,
# TODO enable more and avoid returning None values
top_p=None, # input_batch.top_p[:padded_num_reqs],
top_k=None, # input_batch.top_k[:padded_num_reqs],
min_p=input_batch.min_p[:padded_num_reqs],
generators=input_batch.generators,
indices_do_sample=indices_do_sample)
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device))
# SPDX-License-Identifier: Apache-2.0
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional
from typing import Any, Optional, Union
import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE = 512
class MsgpackEncoder:
"""Encoder with custom torch tensor serialization."""
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
def encode(self, obj: Any) -> bytes:
return self.encoder.encode(obj)
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
def encode_into(self, obj: Any, buf: bytearray) -> None:
self.encoder.encode_into(obj, buf)
Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
# This is used as a local stash of buffers that we can then access from
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
# pass custom data to the hook otherwise.
self.aux_buffers: Optional[list[bytestr]] = None
def encode(self, obj: Any) -> Sequence[bytestr]:
try:
self.aux_buffers = bufs = [b'']
bufs[0] = self.encoder.encode(obj)
# This `bufs` list allows us to collect direct pointers to backing
# buffers of tensors and np arrays, and return them along with the
# top-level encoded buffer instead of copying their data into the
# new buffer.
return bufs
finally:
self.aux_buffers = None
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
try:
self.aux_buffers = [buf]
bufs = self.aux_buffers
self.encoder.encode_into(obj, buf)
return bufs
finally:
self.aux_buffers = None
def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return self._encode_ndarray(obj.numpy())
# Fall back to pickle for object or void kind ndarrays.
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
return self._encode_ndarray(obj)
if isinstance(obj, FunctionType):
# `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods.
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE,
pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
def _encode_ndarray(
self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
# Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(arr_data)
# We serialize the ndarray as a tuple of native types.
# The data is either inlined if small, or an index into a list of
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data
class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self, t: Optional[Any] = None):
args = () if t is None else (t, )
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
def decode(self, obj: Any):
return self.decoder.decode(obj)
def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_TENSOR:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported")
self.decoder = msgpack.Decoder(*args,
ext_hook=self.ext_hook,
dec_hook=self.dec_hook)
self.aux_buffers: Sequence[bytestr] = ()
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
# TODO - This check can become `isinstance(bufs, bytestr)`
# as of Python 3.10.
return self.decoder.decode(bufs)
self.aux_buffers = bufs
try:
return self.decoder.decode(bufs[0])
finally:
self.aux_buffers = ()
def dec_hook(self, t: type, obj: Any) -> Any:
# Given native types in `obj`, convert to type `t`.
if isclass(t):
if issubclass(t, np.ndarray):
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
return obj
def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_RAW_VIEW:
return data
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(
f"Extension type code {code} is not supported")
......@@ -4,8 +4,11 @@ import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
......@@ -21,8 +24,12 @@ class EagleProposer:
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=device,
dtype=torch.int32)
def propose(
self,
......@@ -54,7 +61,9 @@ class EagleProposer:
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
seq_lens = target_positions[last_token_indices] + 1
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
......@@ -98,7 +107,7 @@ class EagleProposer:
hidden_states = sample_hidden_states
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size]
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
......@@ -176,26 +185,28 @@ class EagleProposer:
return cu_num_tokens, token_indices
def load_model(self, target_model: nn.Module) -> None:
self.model = DummyEagleModel()
self.model.get_input_embeddings = target_model.get_input_embeddings
self.model.compute_logits = target_model.compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
input_embeddings = self.get_input_embeddings(input_ids)
return hidden_states + input_embeddings # Dummy return.
loader = get_model_loader(self.vllm_config.load_config)
target_layer_num = self.vllm_config.model_config.get_num_layers(
self.vllm_config.parallel_config)
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
# FIXME(lily): This does not handle with distributed inference.
target_device = self.vllm_config.device_config.device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = EagleLlamaForCausalLM(
model_config=draft_model_config,
start_layer_id=target_layer_num).to(target_device)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
self.model.lm_head = target_model.lm_head
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
......
......@@ -46,7 +46,8 @@ class GuidanceBackend(StructuredOutputBackend):
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(
tokenizer, self.vocab_size)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
......@@ -163,7 +164,6 @@ def validate_guidance_grammar(
tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
tp, grm = get_structured_output_key(sampling_params)
guidance_grm = serialize_guidance_grammar(tp, grm)
err = llguidance.LLMatcher.validate_grammar(guidance_grm,
tokenizer=tokenizer)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
if err:
raise ValueError(f"Grammar error: {err}")
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
......@@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
self.compiler = xgr.GrammarCompiler(
tokenizer_info,
max_threads=8,
cache_enabled=True,
cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
......
......@@ -41,8 +41,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
return True
# Unsupported keywords for strings
if obj.get("type") == "string" and any(
key in obj for key in ("minLength", "maxLength", "format")):
if obj.get("type") == "string" and "format" in obj:
return True
# Unsupported keywords for objects
......
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from multiprocessing import Process
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
......@@ -105,28 +105,22 @@ class BackgroundProcHandle:
process_kwargs: dict[Any, Any],
):
context = get_mp_context()
self.reader, writer = context.Pipe(duplex=False)
assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
assert ("input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path
# Run busy loop in background process.
self.proc = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self.proc: Process = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path)
self.proc.start()
def wait_for_startup(self):
# Wait for startup.
if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.")
def fileno(self):
return self.proc.sentinel
def shutdown(self):
self._finalizer()
......@@ -134,7 +128,7 @@ class BackgroundProcHandle:
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
def shutdown(proc: Process, input_path: str, output_path: str):
# Shutdown the process.
if proc.is_alive():
proc.terminate()
......@@ -206,4 +200,4 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
\ No newline at end of file
......@@ -19,7 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
......@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING:
import xgrammar as xgr
......@@ -482,14 +484,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.commit(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
max_num_scheduled_tokens = 0
for i, req_id in enumerate(self.input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens[i] = num_tokens
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
......@@ -830,19 +828,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
mm_inputs = list[MultiModalKwargs]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
......@@ -878,16 +878,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs(
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
......@@ -895,8 +902,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
......@@ -918,8 +925,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
def get_model(self) -> nn.Module:
return self.model
......@@ -979,15 +994,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
encoder_outputs = []
mm_embeds = []
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
......@@ -1009,9 +1024,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
if encoder_outputs:
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, encoder_outputs)
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
......@@ -1172,9 +1187,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions
target_hidden_states = hidden_states
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
......
......@@ -15,13 +15,14 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
......@@ -30,13 +31,14 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from .utils import sanity_check_mm_encoder_outputs
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -174,10 +176,12 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_paddings(
self.num_tokens_paddings = _get_token_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
......@@ -262,11 +266,6 @@ class TPUModelRunner:
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
......@@ -275,7 +274,7 @@ class TPUModelRunner:
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
generator=None,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
......@@ -505,21 +504,48 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
return attn_metadata, logits_indices, padded_num_reqs
def _scatter_placeholders(
self,
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return embeds
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def _gather_placeholders(
self,
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return placeholders
return placeholders[is_embed]
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
mm_inputs = list[MultiModalKwargs]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
......@@ -555,16 +581,23 @@ class TPUModelRunner:
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs(
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
......@@ -572,8 +605,8 @@ class TPUModelRunner:
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
......@@ -595,8 +628,16 @@ class TPUModelRunner:
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
@torch.no_grad()
def execute_model(
......@@ -607,25 +648,26 @@ class TPUModelRunner:
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
encoder_outputs = []
mm_embeds = []
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
scheduler_output)
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if encoder_outputs:
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
self.input_ids, encoder_outputs)
self.input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None
......@@ -637,21 +679,19 @@ class TPUModelRunner:
input_ids = self.input_ids
inputs_embeds = None
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, logits_indices)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
hidden_states = self.select_hidden_states(hidden_states,
logits_indices)
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, padded_num_reqs, self.device)
selected_token_ids = self.sample_from_hidden(hidden_states,
tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
......@@ -751,17 +791,15 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm.mark_step()
xm.wait_device_ops()
model = ModelWrapperV1(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
self.model = model
self.sampler = TPUSampler()
@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
def _dummy_run(self, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
......@@ -812,65 +850,81 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0):
out = self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def capture_model(self) -> None:
"""Compile the model."""
def _precompile_backbone(self) -> None:
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
self._dummy_run(num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("model")
self._update_num_xla_graphs("model backbone")
logger.info("Compiling sampling with different input shapes.")
def _precompile_select_hidden_states(self) -> None:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger.info(
"Compiling select_hidden_states with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
for num_tokens in self.num_tokens_paddings:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dummy_hidden = torch.zeros((num_tokens, hsize),
device=self.device,
dtype=self._hidden_states_dtype)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True:
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = out.cpu()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
num_reqs_to_sample + 1, self.max_num_reqs)
torch._dynamo.mark_dynamic(dummy_hidden, 0)
for num_reqs in self.num_reqs_paddings:
indices = torch.zeros(num_reqs,
dtype=torch.int32,
device=self.device)
torch._dynamo.mark_dynamic(indices, 0)
self.select_hidden_states(dummy_hidden, indices)
logger.info(" -- num_tokens: %d", num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")
def _precompile_sample_from_hidden(self) -> None:
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_reqs in self.num_reqs_paddings:
dummy_hidden = torch.zeros((num_reqs, hsize),
device=self.device,
dtype=self._hidden_states_dtype)
# The first dimension of dummy_hidden cannot be mark_dynamic because
# some operations in the sampler require it to be static.
for all_greedy in [False, True]:
generate_params_if_all_greedy = not all_greedy
sampling_metadata = (
TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch,
num_reqs,
self.device,
generate_params_if_all_greedy,
))
sampling_metadata.all_greedy = all_greedy
self.sample_from_hidden(dummy_hidden, sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("sampling")
def capture_model(self) -> None:
"""
Precompile all the subgraphs with possible input shapes.
"""
# TODO: precompile encoder
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_sample_from_hidden()
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
......@@ -910,73 +964,39 @@ class TPUModelRunner:
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.sampler = TPUSampler()
def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return hidden_states
def reset_dynamo_cache(self):
if self.is_multimodal_model:
compiled_model = self.model.get_language_model().model
else:
compiled_model = self.model.model
if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
logger.info("Clear dynamo cache and cached dynamo bytecode.")
torch._dynamo.eval_frame.remove_from_cache(
compiled_model.original_code_object)
compiled_model.compiled_codes.clear()
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def select_hidden_states(self, hidden_states, indices_do_sample):
return hidden_states[indices_do_sample]
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
logits = self.model.compute_logits(sample_hidden_states, None)
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
out_tokens = self.sampler(logits,
sampling_metadata).sampled_token_ids
return out_tokens
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits = self.model.compute_logits(hidden_states, None)
return logits
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
......@@ -984,17 +1004,26 @@ class ModelWrapperV1(nn.Module):
return self.model.get_input_embeddings(*args, **kwargs)
def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
logger.info("Preparing request paddings:")
# assert min_req_size is power of 2
assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
paddings: list = []
num = max(MIN_NUM_SEQS, min_req_size)
while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
paddings.append(num)
logger.info(" %d", num)
num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
return paddings
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit)
def _get_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
def _get_token_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
......@@ -1004,18 +1033,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
first increase the size to twice,
then increase the padding size by padding_gap.
"""
# assert min_token_size is power of 2
assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
paddings = []
num = min_token_size
if padding_gap == 0:
logger.info("Using exponential paddings:")
logger.info("Using exponential token paddings:")
while num <= max_token_size:
logger.info(" %d", num)
paddings.append(num)
num *= 2
else:
logger.info("Using incremental paddings:")
logger.info("Using incremental token paddings:")
while num <= padding_gap:
logger.info(" %d", num)
paddings.append(num)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment