Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
......@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__)
......@@ -402,6 +402,36 @@ class MPClient(EngineCoreClient):
self.utility_results: dict[int, AnyFuture] = {}
def _wait_for_engine_startup(self):
    """Block until every core engine has sent its ready message.

    Polls the input socket together with each engine's process handle and
    ticks engines off as their ``READY`` messages arrive.

    Raises:
        RuntimeError: If anything other than the input socket alone becomes
            readable (treated as a core process having exited), if a
            message arrives from an unknown or already-seen engine, or if
            an engine reports something other than ``READY``.
    """
    # The input socket may be sync or async; shadow it to get a sync handle.
    ready_socket = zmq.Socket.shadow(self.input_socket)
    engines = self.resources.core_engines

    # Indices of the engines we are still waiting to hear from.
    pending = {engine.index for engine in engines}

    poller = zmq.Poller()
    poller.register(ready_socket, zmq.POLLIN)
    for engine in engines:
        poller.register(engine.proc_handle, zmq.POLLIN)

    while pending:
        ready = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not ready:
            # Poll period elapsed with nothing to read; keep waiting.
            logger.debug("Waiting for %d core engine proc(s) to start: %s",
                         len(pending), pending)
            continue

        # Any event that is not solely the input socket means one of the
        # core processes exited.
        engine_exited = len(ready) > 1 or ready[0][0] != ready_socket
        if engine_exited:
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above.")

        eng_id_bytes, msg = ready_socket.recv_multipart()
        eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
        if eng_id not in pending:
            raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
        if msg != b'READY':
            raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")

        logger.info("Core engine process %d ready.", eng_id)
        pending.discard(eng_id)
def _init_core_engines(
self,
vllm_config: VllmConfig,
......@@ -472,8 +502,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
frames = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
......@@ -494,10 +524,10 @@ class SyncMPClient(MPClient):
return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.core_engine.send_multipart(msg)
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
*self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future()
......@@ -599,8 +629,8 @@ class AsyncMPClient(MPClient):
async def process_outputs_socket():
while True:
(frame, ) = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
frames = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
......@@ -625,12 +655,20 @@ class AsyncMPClient(MPClient):
assert self.outputs_queue is not None
return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
def _send_input(self,
request_type: EngineCoreRequestType,
request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
if engine is None:
engine = self.core_engine
self._ensure_output_queue_task()
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine)
def _send_input_message(self, message: tuple[bytestr, ...],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message
return self.input_socket.send_multipart(message, copy=False)
async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method,
......@@ -646,9 +684,9 @@ class AsyncMPClient(MPClient):
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args)))
await self._send_input_message(message, engine)
self._ensure_output_queue_task()
return await future
......@@ -721,7 +759,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
*self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
......@@ -755,7 +793,7 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
request.prompt = None
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
......@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
......@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
# variable VLLM_MM_INPUT_CACHE_GIB.
class MMInputCacheServer:
class MirroredProcessingCache:
def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update(
def get_and_update_p0(
self,
mm_inputs: list[MultiModalKwargs],
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> list[MultiModalKwargs]:
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = []
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_hash in self.mm_cache:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def get_and_update_p1(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
else:
......
# SPDX-License-Identifier: Apache-2.0
import time
from collections.abc import Mapping
from typing import Optional, Union
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
......@@ -46,6 +48,8 @@ class Processor:
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
......@@ -73,6 +77,7 @@ class Processor:
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
......@@ -83,6 +88,26 @@ class Processor:
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,
......@@ -136,9 +161,6 @@ class Processor:
f" != {engine_level_backend}")
else:
params.guided_decoding.backend = engine_level_backend
import vllm.platforms
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
# Request content validation
if engine_level_backend.startswith("xgrammar"):
......@@ -181,6 +203,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
from vllm.platforms import current_platform
current_platform.validate_request(
prompt=prompt,
params=params,
)
self._validate_lora(lora_request)
self._validate_params(params)
if priority != 0:
......@@ -228,7 +255,7 @@ class Processor:
self.tokenizer.get_lora_tokenizer(lora_request))
# Multimodal related.
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
......@@ -253,20 +280,28 @@ class Processor:
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
sorted_mm_inputs = []
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
]))
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
sorted_mm_inputs = [
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs
return EngineCoreRequest(
request_id=request_id,
prompt=decoder_inputs.get("prompt"),
......@@ -285,41 +320,64 @@ class Processor:
lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
max_input_id = max(prompt_ids)
max_allowed = self.tokenizer.get_lora_tokenizer(
lora_request).max_token_id
if max_input_id > max_allowed:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
if len(prompt_ids) >= self.model_config.max_model_len:
raise ValueError(
f"Prompt length of {len(prompt_ids)} is longer than the "
f"maximum model length of {self.model_config.max_model_len}.")
self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
if self.model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
if len(prompt_ids) > max_prompt_len:
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). "
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper
if model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
......
......@@ -119,10 +119,9 @@ class MultiprocExecutor(Executor):
timeout=dequeue_timeout)
if status != WorkerProc.ResponseStatus.SUCCESS:
if isinstance(result, Exception):
raise result
else:
raise RuntimeError("Worker failed")
# Exceptions do not apply %-style formatting to their arguments (unlike
# logging calls), so interpolate the worker's error text explicitly.
raise RuntimeError(
    f"Worker failed with error '{result}', please check the"
    " stack trace above for the root cause")
responses[w.rank] = result
......@@ -327,7 +326,7 @@ class WorkerProc:
logger.debug("Worker interrupted.")
except Exception:
# worker_busy_loop sends exceptions exceptons to Executor
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil.Process().parent().send_signal(signal.SIGUSR1)
......@@ -378,9 +377,11 @@ class WorkerProc:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
e.add_note(traceback.format_exc())
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
# No positional args are supplied, so the message must not contain a
# %-placeholder (it would be logged literally); the exception details
# come from exc_info.
logger.exception("WorkerProc hit an exception.", exc_info=e)
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue
self.worker_response_mq.enqueue(
......
......@@ -239,7 +239,8 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
640.0, 2560.0
],
labelnames=labelnames).labels(*labelvalues)
......@@ -249,13 +250,13 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
0.75, 1.0, 2.5
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=labelnames).labels(*labelvalues)
request_latency_buckets = [
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
]
self.histogram_e2e_time_request = \
prometheus_client.Histogram(
......
......@@ -3,17 +3,16 @@
import enum
from typing import TYPE_CHECKING, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
class Request:
......@@ -23,9 +22,9 @@ class Request:
request_id: str,
prompt: Optional[str],
prompt_token_ids: list[int],
multi_modal_inputs: Optional[list["MultiModalKwargs"]],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list["PlaceholderRange"]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
......@@ -75,6 +74,11 @@ class Request:
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None:
assert isinstance(request.mm_inputs, list)
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
"mm_inputs was not updated in EngineCore.add_request")
return cls(
request_id=request.request_id,
prompt=request.prompt,
......@@ -121,7 +125,7 @@ class Request:
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"]
num_tokens = self.mm_positions[input_id].length
return num_tokens
@property
......
......@@ -230,9 +230,19 @@ class Sampler(nn.Module):
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
# Get vocabulary size from logits
vocab_size = logits.shape[-1]
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
# Check token_id bounds to ensure within vocabulary
if token_id < 0 or token_id >= vocab_size:
raise ValueError(
f"token_id {token_id} in logit_bias contains "
f"out-of-vocab token id. Vocabulary size: "
f"{vocab_size}")
logits[i, token_id] += bias
return logits
......
......@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.worker.gpu_input_batch import InputBatch
......@@ -24,19 +23,15 @@ class TPUSupportedSamplingMetadata:
# This class exposes a more xla-friendly interface than SamplingMetadata
# on TPU, in particular all arguments should be traceable and no optionals
# are allowed, to avoid graph recompilation on Nones.
temperature: torch.Tensor
temperature: torch.Tensor = None
min_p: torch.Tensor
min_p: torch.Tensor = None
# Still too slow on forward_native!
top_k: torch.Tensor = None
top_p: torch.Tensor = None
# Greedy sampling flag for compiling single xla graph.
all_greedy: torch.Tensor = None
# Generator not supported by xla
generators: dict[int,
torch.Generator] = field(default_factory=lambda: dict())
all_greedy: bool = True
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs = None
......@@ -57,64 +52,66 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask = None
bad_words_token_ids = None
indices_do_sample: torch.Tensor = None
# Generator not supported by xla
_generators: dict[int,
torch.Generator] = field(default_factory=lambda: dict())
@property
def generators(self) -> dict[int, torch.Generator]:
# Generator not supported by torch/xla. This field must be immutable.
return self._generators
@classmethod
def from_input_batch(
cls, input_batch: InputBatch,
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
cls,
input_batch: InputBatch,
padded_num_reqs: int,
xla_device: torch.device,
generate_params_if_all_greedy: bool = False
) -> "TPUSupportedSamplingMetadata":
"""
Copy sampling tensors slices from `input_batch` to on device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
Args:
input_batch: The input batch containing sampling parameters.
padded_num_reqs: The padded number of requests.
xla_device: The XLA device.
generate_params_if_all_greedy: If True, generate sampling parameters
even if all requests are greedy. This is useful for cases where
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
# Early return to avoid unnecessary cpu to tpu copy
if (input_batch.all_greedy is True
and generate_params_if_all_greedy is False):
return cls(all_greedy=True)
num_reqs = input_batch.num_reqs
padded_num_reqs = len(indices_do_sample)
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
fill_val) -> torch.Tensor:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
# Pad value is the default one.
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
# consistent. We can't have flags to skip copies or we'll end up
# recompiling.
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
fill_slice(input_batch.temperature_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["temperature"])
# TODO Temporarily disabled until sampling options are enabled
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
# fill_slice(input_batch.top_p_cpu_tensor)
# fill_slice(input_batch.top_k_cpu_tensor)
fill_slice(input_batch.min_p_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["min_p"])
xm.mark_step()
xm.wait_device_ops()
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return cls(
temperature=input_batch.temperature[:padded_num_reqs],
# Scalar tensor for xla-friendly tracing.
all_greedy=torch.tensor(input_batch.all_greedy,
dtype=torch.bool,
device=input_batch.device),
temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
to(xla_device),
all_greedy=input_batch.all_greedy,
# TODO enable more and avoid returning None values
top_p=None, # input_batch.top_p[:padded_num_reqs],
top_k=None, # input_batch.top_k[:padded_num_reqs],
min_p=input_batch.min_p[:padded_num_reqs],
generators=input_batch.generators,
indices_do_sample=indices_do_sample)
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device))
# SPDX-License-Identifier: Apache-2.0
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional
from typing import Any, Optional, Union
import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE = 512
class MsgpackEncoder:
"""Encoder with custom torch tensor serialization."""
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
def encode(self, obj: Any) -> bytes:
return self.encoder.encode(obj)
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
def encode_into(self, obj: Any, buf: bytearray) -> None:
self.encoder.encode_into(obj, buf)
Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
# This is used as a local stash of buffers that we can then access from
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
# pass custom data to the hook otherwise.
self.aux_buffers: Optional[list[bytestr]] = None
def encode(self, obj: Any) -> Sequence[bytestr]:
    """Encode ``obj`` and return the list of resulting buffers.

    The first element is the top-level msgpack buffer. Any further elements
    are backing buffers of tensors / numpy arrays appended by the encoding
    hook, returned alongside instead of being copied into the main buffer.
    """
    # Slot 0 is a placeholder until the top-level buffer has been produced;
    # the hook appends to this same list through `self.aux_buffers`.
    frames = [b'']
    self.aux_buffers = frames
    try:
        frames[0] = self.encoder.encode(obj)
        return frames
    finally:
        # Always clear the stash, even if encoding failed.
        self.aux_buffers = None
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
    """Encode ``obj`` into ``buf`` and return the list of buffers.

    ``buf`` itself is the first element of the returned list; any backing
    buffers appended by the encoding hook follow it.
    """
    frames = [buf]
    self.aux_buffers = frames
    try:
        self.encoder.encode_into(obj, buf)
        return frames
    finally:
        # Always clear the stash, even if encoding failed.
        self.aux_buffers = None
def enc_hook(self, obj: Any) -> Any:
    """Convert an object msgpack cannot encode natively into one it can.

    Tensors and numpy arrays go through ``_encode_ndarray``; plain
    functions are cloudpickled; everything else is pickled.
    """
    if isinstance(obj, torch.Tensor):
        return self._encode_ndarray(obj.numpy())

    # Object ('O') and void ('V') kind ndarrays are not handled here; they
    # fall through to the pickle path below.
    if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
        return self._encode_ndarray(obj)

    if isinstance(obj, FunctionType):
        # `pickle` is generally faster than cloudpickle, but can have
        # problems serializing methods.
        pickled_fn = cloudpickle.dumps(obj)
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, pickled_fn)

    pickled_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickled_obj)
def _encode_ndarray(
    self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
    """Serialize an ndarray as a tuple of native types.

    Returns:
        ``(dtype string, shape, data)`` where ``data`` is either the array
        bytes inlined as an extension value (scalars and arrays smaller
        than ``MIN_NOCOPY_BUF_SIZE``) or the index of the array's buffer in
        ``self.aux_buffers``.
    """
    assert self.aux_buffers is not None

    # Use the array's own buffer when it is C-contiguous; otherwise take a
    # bytes copy.
    if obj.data.c_contiguous:
        raw = obj.data
    else:
        raw = obj.tobytes()

    encode_inline = not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE
    if encode_inline:
        # Encode small arrays and scalars inline. Using this extension type
        # ensures we can avoid copying when decoding.
        data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, raw)
    else:
        # Otherwise stash the backing buffer and encode only its index, to
        # avoid a copy.
        data = len(self.aux_buffers)
        self.aux_buffers.append(raw)

    return obj.dtype.str, obj.shape, data
class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when decoding tensors / numpy arrays.
"""
def __init__(self, t: Optional[Any] = None):
args = () if t is None else (t, )
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
def decode(self, obj: Any):
return self.decoder.decode(obj)
def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_TENSOR:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported")
self.decoder = msgpack.Decoder(*args,
ext_hook=self.ext_hook,
dec_hook=self.dec_hook)
self.aux_buffers: Sequence[bytestr] = ()
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
    """Decode either a single buffer or a sequence of buffers.

    For a sequence, the first element is decoded as the top-level message
    and the whole sequence is exposed to the decoding hooks through
    ``self.aux_buffers`` for the duration of the call.
    """
    single_buffer_types = (bytes, bytearray, memoryview, zmq.Frame)
    if not isinstance(bufs, single_buffer_types):
        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Always reset the stash, even if decoding failed.
            self.aux_buffers = ()

    # TODO - The check above can become `isinstance(bufs, bytestr)`
    # as of Python 3.10.
    return self.decoder.decode(bufs)
def dec_hook(self, t: type, obj: Any) -> Any:
    """Convert the native-typed ``obj`` into the requested type ``t``.

    Only ``np.ndarray`` and ``torch.Tensor`` (and their subclasses) are
    converted; for any other ``t`` the object is returned unchanged.
    """
    if not isclass(t):
        return obj
    if issubclass(t, np.ndarray):
        return self._decode_ndarray(obj)
    if issubclass(t, torch.Tensor):
        array = self._decode_ndarray(obj)
        return torch.from_numpy(array)
    return obj
def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
def ext_hook(self, code: int, data: memoryview) -> Any:
    """Decode a msgpack extension value according to its type code.

    Raises:
        NotImplementedError: If ``code`` is not one of the known custom
            extension type codes.
    """
    if code == CUSTOM_TYPE_RAW_VIEW:
        # Raw array bytes: hand the view back untouched.
        return data
    elif code == CUSTOM_TYPE_PICKLE:
        # NOTE(review): unpickling can execute arbitrary code; this relies
        # on the incoming data being trusted.
        loads = pickle.loads
    elif code == CUSTOM_TYPE_CLOUDPICKLE:
        loads = cloudpickle.loads
    else:
        raise NotImplementedError(
            f"Extension type code {code} is not supported")
    return loads(data)
......@@ -4,8 +4,11 @@ import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
......@@ -21,8 +24,12 @@ class EagleProposer:
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=device,
dtype=torch.int32)
def propose(
self,
......@@ -54,7 +61,9 @@ class EagleProposer:
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
seq_lens = target_positions[last_token_indices] + 1
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
......@@ -98,7 +107,7 @@ class EagleProposer:
hidden_states = sample_hidden_states
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size]
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
......@@ -176,26 +185,28 @@ class EagleProposer:
return cu_num_tokens, token_indices
def load_model(self, target_model: nn.Module) -> None:
self.model = DummyEagleModel()
self.model.get_input_embeddings = target_model.get_input_embeddings
self.model.compute_logits = target_model.compute_logits
# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
input_embeddings = self.get_input_embeddings(input_ids)
return hidden_states + input_embeddings # Dummy return.
loader = get_model_loader(self.vllm_config.load_config)
target_layer_num = self.vllm_config.model_config.get_num_layers(
self.vllm_config.parallel_config)
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
# FIXME(lily): This does not handle distributed inference.
target_device = self.vllm_config.device_config.device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = EagleLlamaForCausalLM(
model_config=draft_model_config,
start_layer_id=target_layer_num).to(target_device)
self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
self.model.lm_head = target_model.lm_head
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
......
......@@ -46,7 +46,8 @@ class GuidanceBackend(StructuredOutputBackend):
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(
tokenizer, self.vocab_size)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
......@@ -163,7 +164,6 @@ def validate_guidance_grammar(
tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
tp, grm = get_structured_output_key(sampling_params)
guidance_grm = serialize_guidance_grammar(tp, grm)
err = llguidance.LLMatcher.validate_grammar(guidance_grm,
tokenizer=tokenizer)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
if err:
raise ValueError(f"Grammar error: {err}")
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
......@@ -76,7 +77,12 @@ class XgrammarBackend(StructuredOutputBackend):
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
self.compiler = xgr.GrammarCompiler(
tokenizer_info,
max_threads=8,
cache_enabled=True,
cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
......
......@@ -41,8 +41,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
return True
# Unsupported keywords for strings
if obj.get("type") == "string" and any(
key in obj for key in ("minLength", "maxLength", "format")):
if obj.get("type") == "string" and "format" in obj:
return True
# Unsupported keywords for objects
......
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from multiprocessing import Process
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
......@@ -115,18 +115,15 @@ class BackgroundProcHandle:
process_kwargs["output_path"] = output_path
# Run busy loop in background process.
self.proc = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self.proc: Process = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path)
self.proc.start()
def wait_for_startup(self):
# Wait for startup.
if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.")
def fileno(self):
return self.proc.sentinel
def shutdown(self):
self._finalizer()
......@@ -134,7 +131,7 @@ class BackgroundProcHandle:
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
def shutdown(proc: Process, input_path: str, output_path: str):
# Shutdown the process.
if proc.is_alive():
proc.terminate()
......
......@@ -19,7 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
......@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING:
import xgrammar as xgr
......@@ -482,14 +484,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.commit(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
max_num_scheduled_tokens = 0
for i, req_id in enumerate(self.input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens[i] = num_tokens
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
......@@ -830,19 +828,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
mm_inputs = list[MultiModalKwargs]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
......@@ -878,16 +878,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs(
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
......@@ -895,8 +902,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
......@@ -918,8 +925,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
def get_model(self) -> nn.Module:
return self.model
......@@ -979,15 +994,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
encoder_outputs = []
mm_embeds = []
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
......@@ -1009,9 +1024,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
if encoder_outputs:
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, encoder_outputs)
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
......@@ -1172,9 +1187,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions
target_hidden_states = hidden_states
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
......
......@@ -15,13 +15,14 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
......@@ -30,13 +31,14 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from .utils import sanity_check_mm_encoder_outputs
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -174,10 +176,12 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_paddings(
self.num_tokens_paddings = _get_token_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
......@@ -262,11 +266,6 @@ class TPUModelRunner:
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
self.requests[req_id] = CachedRequestState(
req_id=req_id,
......@@ -275,7 +274,7 @@ class TPUModelRunner:
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
generator=generator,
generator=None,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
......@@ -505,21 +504,48 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
return attn_metadata, logits_indices, padded_num_reqs
def _scatter_placeholders(
    self,
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Spread `embeds` over the placeholder rows selected by `is_embed`.

    Args:
        embeds: Embedding rows to place; its last dimension is reused as
            the width of the result.
        is_embed: Boolean mask with one entry per placeholder row, or
            None when there is nothing to scatter.

    Returns:
        `embeds` itself when `is_embed` is None. Otherwise a new tensor
        with `is_embed.shape[0]` rows in which the masked rows hold
        `embeds` and every other row is filled with NaN.
    """
    if is_embed is None:
        return embeds

    # NaN marks the rows that carry no embedding.
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders
def _gather_placeholders(
    self,
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Pick the rows of `placeholders` selected by the `is_embed` mask.

    With no mask, `placeholders` is returned untouched.
    """
    return placeholders if is_embed is None else placeholders[is_embed]
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
mm_inputs = list[MultiModalKwargs]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
......@@ -555,16 +581,23 @@ class TPUModelRunner:
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs(
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
......@@ -572,8 +605,8 @@ class TPUModelRunner:
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
......@@ -595,8 +628,16 @@ class TPUModelRunner:
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
@torch.no_grad()
def execute_model(
......@@ -607,25 +648,26 @@ class TPUModelRunner:
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
encoder_outputs = []
mm_embeds = []
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
scheduler_output)
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if encoder_outputs:
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
self.input_ids, encoder_outputs)
self.input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None
......@@ -637,21 +679,19 @@ class TPUModelRunner:
input_ids = self.input_ids
inputs_embeds = None
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, logits_indices)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
hidden_states = self.select_hidden_states(hidden_states,
logits_indices)
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, padded_num_reqs, self.device)
selected_token_ids = self.sample_from_hidden(hidden_states,
tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
......@@ -751,17 +791,15 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm.mark_step()
xm.wait_device_ops()
model = ModelWrapperV1(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
self.model = model
self.sampler = TPUSampler()
@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
def _dummy_run(self, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
......@@ -812,65 +850,81 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0):
out = self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def capture_model(self) -> None:
"""Compile the model."""
def _precompile_backbone(self) -> None:
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
self._dummy_run(num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("model")
self._update_num_xla_graphs("model backbone")
logger.info("Compiling sampling with different input shapes.")
def _precompile_select_hidden_states(self) -> None:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger.info(
"Compiling select_hidden_states with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
for num_tokens in self.num_tokens_paddings:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dummy_hidden = torch.zeros((num_tokens, hsize),
device=self.device,
dtype=self._hidden_states_dtype)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True:
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = out.cpu()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
num_reqs_to_sample + 1, self.max_num_reqs)
torch._dynamo.mark_dynamic(dummy_hidden, 0)
for num_reqs in self.num_reqs_paddings:
indices = torch.zeros(num_reqs,
dtype=torch.int32,
device=self.device)
torch._dynamo.mark_dynamic(indices, 0)
self.select_hidden_states(dummy_hidden, indices)
logger.info(" -- num_tokens: %d", num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")
def _precompile_sample_from_hidden(self) -> None:
    """Run `sample_from_hidden` once per request-padding size.

    For every size in `self.num_reqs_paddings` the sampler is invoked on
    a zero-filled hidden-state tensor, once with `all_greedy` False and
    once with it True, and the elapsed wall time is logged at the end.
    """
    logger.info("Compiling sampling with different input shapes.")
    start = time.perf_counter()
    hsize = self.model_config.get_hidden_size()
    for num_reqs in self.num_reqs_paddings:
        dummy_hidden = torch.zeros((num_reqs, hsize),
                                   device=self.device,
                                   dtype=self._hidden_states_dtype)
        # The first dimension of dummy_hidden cannot be mark_dynamic because
        # some operations in the sampler require it to be static.
        for all_greedy in [False, True]:
            generate_params_if_all_greedy = not all_greedy
            sampling_metadata = (
                TPUSupportedSamplingMetadata.from_input_batch(
                    self.input_batch,
                    num_reqs,
                    self.device,
                    generate_params_if_all_greedy,
                ))
            # Force the flag so both sampling paths get exercised.
            sampling_metadata.all_greedy = all_greedy
            self.sample_from_hidden(dummy_hidden, sampling_metadata)
        logger.info(" -- num_seqs: %d", num_reqs)
    xm.wait_device_ops()
    end = time.perf_counter()
    logger.info("Compilation finished in %.2f [secs].", end - start)
    self._update_num_xla_graphs("sampling")
def capture_model(self) -> None:
"""
Precompile all the subgraphs with possible input shapes.
"""
# TODO: precompile encoder
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_sample_from_hidden()
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
......@@ -910,73 +964,39 @@ class TPUModelRunner:
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.sampler = TPUSampler()
def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return hidden_states
def reset_dynamo_cache(self):
    """Clear the dynamo cache entry and cached bytecode of the inner model.

    The inner model is `.model` of the language model for multimodal
    models, and `.model` of `self.model` otherwise. Nothing happens unless
    it is a `TorchCompileWrapperWithCustomDispatcher`.
    """
    outer = (self.model.get_language_model()
             if self.is_multimodal_model else self.model)
    compiled_model = outer.model
    if not isinstance(compiled_model,
                      TorchCompileWrapperWithCustomDispatcher):
        return
    logger.info("Clear dynamo cache and cached dynamo bytecode.")
    torch._dynamo.eval_frame.remove_from_cache(
        compiled_model.original_code_object)
    compiled_model.compiled_codes.clear()
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def select_hidden_states(self, hidden_states, indices_do_sample):
return hidden_states[indices_do_sample]
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
logits = self.model.compute_logits(sample_hidden_states, None)
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
out_tokens = self.sampler(logits,
sampling_metadata).sampled_token_ids
return out_tokens
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits = self.model.compute_logits(hidden_states, None)
return logits
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
......@@ -984,17 +1004,26 @@ class ModelWrapperV1(nn.Module):
return self.model.get_input_embeddings(*args, **kwargs)
def _get_padded_number(n: int, multiple: int) -> int:
    """Round `n` up to the next multiple of `multiple`."""
    bumped = n + multiple - 1
    whole_chunks = bumped // multiple
    return whole_chunks * multiple
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    """Build the increasing list of request-count padding sizes.

    Starts at the larger of `MIN_NUM_SEQS` and `min_req_size`, and obtains
    each following size from `_get_padded_num_reqs_with_upper_limit`,
    stopping once a size repeats or exceeds `max_req_size`.
    `min_req_size` must be a positive power of two.
    """
    logger.info("Preparing request paddings:")
    # A positive power of two has exactly one bit set.
    is_power_of_two = (min_req_size > 0
                       and min_req_size & (min_req_size - 1) == 0)
    assert is_power_of_two
    sizes: list[int] = []
    candidate = max(MIN_NUM_SEQS, min_req_size)
    while candidate <= max_req_size:
        if sizes and sizes[-1] == candidate:
            # The upper limit clamped us back onto the last size: done.
            break
        sizes.append(candidate)
        logger.info(" %d", candidate)
        candidate = _get_padded_num_reqs_with_upper_limit(
            candidate + 1, max_req_size)
    return sizes
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Pad a request count to a power of two, capped at `upper_limit`.

    Values up to `MIN_NUM_SEQS` pad to `MIN_NUM_SEQS`; larger values pad to
    the smallest power of two that is not below `x`. The result is then
    clamped so it never exceeds `upper_limit`.
    """
    # (x - 1).bit_length() is the exponent of the next power of two >= x.
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)
def _get_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
def _get_token_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
......@@ -1004,18 +1033,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
first increase the size to twice,
then increase the padding size by padding_gap.
"""
# assert min_token_size is power of 2
assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
paddings = []
num = min_token_size
if padding_gap == 0:
logger.info("Using exponential paddings:")
logger.info("Using exponential token paddings:")
while num <= max_token_size:
logger.info(" %d", num)
paddings.append(num)
num *= 2
else:
logger.info("Using incremental paddings:")
logger.info("Using incremental token paddings:")
while num <= padding_gap:
logger.info(" %d", num)
paddings.append(num)
......
......@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches)
self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
self.scheduler_config.max_num_batched_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self.model_runner.reset_dynamo_cache()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
......@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens.

    Positions where ``is_embed`` is true receive the rows of ``embeds`` in
    order; every other position is filled with NaN.

    See also :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
            tokens need to be filled with multimodal embeddings, or `None`
            when every placeholder position holds an embedding.
            Shape: `(num_placeholders,)`, with `num_embeds` true entries.

    Returns:
        `embeds` itself (no copy) when `is_embed` is `None`; otherwise a new
        tensor of shape `(num_placeholders, embed_dim)` created via
        `embeds.new_full`, so it shares the dtype and device of `embeds`.
    """
    if is_embed is None:
        return embeds

    # Start from an all-NaN buffer so that positions not covered by the mask
    # are clearly distinguishable from real embedding values.
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    # Boolean-mask assignment writes the rows of `embeds` into the selected
    # positions along dim 0.
    placeholders[is_embed] = embeds
    return placeholders
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the inverse operation of :func:`scatter_mm_placeholders`.

    Args:
        placeholders: The tensor holding one row per placeholder position.
        is_embed: A boolean mask selecting the rows of `placeholders` that
            hold embeddings, or `None` when every row does.

    Returns:
        `placeholders` itself (no copy) when `is_embed` is `None`; otherwise
        the rows of `placeholders` selected by `is_embed`.
    """
    if is_embed is None:
        return placeholders

    # Boolean-mask indexing keeps only the rows flagged as embeddings.
    return placeholders[is_embed]
......@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
......@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
@dataclasses.dataclass(frozen=True)
......@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if (model_input.attn_metadata is not None
and model_input.attn_metadata.prefill_metadata is None
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
......@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests: List[LoRARequest] = []
dummy_lora_requests_per_seq: List[LoRARequest] = []
if self.lora_config:
dummy_lora_requests = self._add_dummy_loras(
self.lora_config.max_loras)
assert len(dummy_lora_requests) == self.lora_config.max_loras
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
......@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables=None,
encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=decoder_dummy_data.
......
......@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
......@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import (bind_kv_cache, is_pin_memory_available,
make_tensor_with_pad)
from vllm.worker.model_runner_base import (
......@@ -100,7 +103,10 @@ def subtuple(obj: object,
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if type(obj) is dict:
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename,
' '.join(fields))
......@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine: int = 0
lora_ids: Optional[List[int]] = None
async_callback: Optional[Callable] = None
is_first_multi_step: bool = True
is_last_step: bool = True
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
......@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded": self.batch_size_padded,
"virtual_engine": self.virtual_engine,
"lora_ids": self.lora_ids,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
......@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self._set_gc_threshold()
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
......@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
output=None,
) -> PrepareDecodeMetadata:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
......@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
if output is None:
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
......@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
num_fully_occupied_blocks = position // self.block_size
block_table = block_table[:num_fully_occupied_blocks + 1]
if len(block_table) == 0:
block_number = _PAD_BLOCK_ID
else:
......@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
if output is None:
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
else:
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size]
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
......@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, None, warmup_mode=True)
is_single_step = \
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_prompt or is_single_step:
self.execute_model(inputs, None, warmup_mode=True)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
is_last_step=False)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
inputs = dataclasses.replace(inputs,
is_first_multi_step=False,
is_last_step=True)
self.execute_model(inputs,
None,
warmup_mode=True,
num_steps=2,
seqs=seqs)
torch.hpu.synchronize()
if profiler:
profiler.step()
......@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
warmup_mode=False,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"num_steps > 1 is not supported in HPUModelRunner")
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
return []
# last multi-step
output = self._decode_sampler_outputs(
model_input) if self.is_driver_worker else []
torch.hpu.synchronize()
if model_input.is_first_multi_step:
# first multi-step
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
real_batch_size = model_input.real_batch_size
batch_size_padded = model_input.batch_size_padded
assert input_tokens is not None
assert input_positions is not None
assert sampling_metadata is not None
assert attn_metadata is not None
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids,
attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update(
{"bypass_hpu_graphs": not use_graphs})
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
real_batch_size = model_input.real_batch_size
batch_size_padded = model_input.batch_size_padded
assert input_tokens is not None
assert input_positions is not None
assert sampling_metadata is not None
assert attn_metadata is not None
is_prompt = attn_metadata.is_prompt
assert is_prompt is not None
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = ("model_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
if num_steps > 1:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
self.model.model.sampler.include_gpu_probs_tensor = True
cache_orig_output_tokens_len: List[Dict] = []
def try_revert_dummy_output_tokens():
if len(cache_orig_output_tokens_len) > 0:
# Reuse the original output token ids length
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
for j, data in seq_group_metadata.seq_data.items():
orig_output_tokens_len = \
cache_orig_output_tokens_len[i][j]
data.output_token_ids = \
data.output_token_ids[:orig_output_tokens_len]
for i in range(num_steps):
if i != 0 and not self.is_driver_worker:
broadcast_data = broadcast_tensor_dict(src=0)
if 'early_exit' in broadcast_data and broadcast_data[
'early_exit']:
return [output] if num_steps == 1 else []
execute_model_kwargs.update({
"input_ids":
broadcast_data["input_ids"],
"positions":
broadcast_data["positions"],
"attn_metadata":
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.
selected_token_indices)
if self.lora_config:
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
# Compute the logits.
with self.profiler.record_event(
'internal',
('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
continue
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None
if self.lora_config:
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids, attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})
htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = ("model_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
if num_steps > 1:
output = output.sampled_token_ids
self.cached_step_outputs.append(
output.detach().clone())
htorch.core.mark_step()
if i < num_steps - 1:
if i == 0:
if model_input.async_callback is not None:
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
seq_group_metadata_list = \
ctx.seq_group_metadata_list
elif seqs is not None:
seq_group_metadata_list = seqs
else:
raise RuntimeError(
"seq_group_metadata_list is uninitialized")
for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
# Skip empty steps
seq_group_metadata.state.current_step += (
num_steps - 2)
# Cache the original output token ids
cache_orig_output_tokens_len.append({})
for j, data in seq_group_metadata.seq_data.items():
cache_orig_output_tokens_len[i][j] = \
len(data.output_token_ids)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
max_output_len = sampling_metadata.seq_groups[
0].sampling_params.max_tokens
if len(data.output_token_ids) < max_output_len - 1:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token = (540, )
data.output_token_ids += (dummy_token)
else:
broadcast_tensor_dict({'early_exit': True},
src=0)
if num_steps == 1:
return [output]
else:
try_revert_dummy_output_tokens()
return []
result = self._prepare_decode(seq_group_metadata_list,
output=output)
execute_model_kwargs.update({
"input_ids":
result.input_tokens,
"positions":
result.input_positions,
"attn_metadata":
self.trim_attn_metadata(result.attn_metadata)
})
model_kwargs_broadcast_data = {
"input_ids": result.input_tokens,
"positions": result.input_positions,
"attn_metadata": vars(result.attn_metadata)
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
else:
try_revert_dummy_output_tokens()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
return [output] if self.is_driver_worker else []
else:
return []
return output if type(output) is list else [output]
def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=False)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
model_event_name = 'model_executable'
with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.selected_token_indices
)
return sampler_outputs
if self.lora_config:
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))
# Compute the logits.
with self.profiler.record_event(
'internal', ('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
output.outputs = output.outputs[:real_batch_size]
htorch.core.mark_step()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
event_end = self.profiler.get_timestamp_us()
counters = self.profiler_counter_helper.get_counter_dict(
cache_config=self.cache_config,
duration=event_end - self.event_start,
seq_len=seq_len,
batch_size_padded=batch_size_padded,
real_batch_size=real_batch_size,
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
return [output]
def _make_decode_output(
    self,
    next_token_ids: List[List[int]],
    seq_groups: List[SequenceGroupToSample],
) -> SamplerOutput:
    """Wrap sampled token ids into a ``SamplerOutput``.

    Rows of ``next_token_ids`` are consumed in order, one per sequence id,
    walking ``seq_groups`` and each group's ``seq_ids`` in sequence. Only the
    first element of each row is used as the sampled token. Every token is
    reported with the same ``Logprob(0.0)`` object.
    """
    # A single Logprob instance is shared by all generated outputs.
    shared_logprob = Logprob(0.0)
    group_outputs = []
    # Index of the next unread row in `next_token_ids`; it keeps advancing
    # across sequence groups.
    cursor = 0
    for group in seq_groups:
        samples = []
        for seq_id in group.seq_ids:
            token_id = next_token_ids[cursor][0]
            cursor += 1
            logprobs = {token_id: shared_logprob}
            samples.append(SequenceOutput(seq_id, token_id, logprobs))
        group_outputs.append(CompletionSequenceGroupOutput(samples, None))
    return SamplerOutput(group_outputs)
def shutdown_inc(self):
can_finalize_inc = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment