Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
......@@ -4,7 +4,8 @@ import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
......@@ -220,10 +221,10 @@ class SchedulerSwappedInOutputs:
"""
# Selected sequences that are going to be swapped in and is in a
# decoding phase.
decode_seq_groups: List[SequenceGroup]
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are going to be swapped in and in a prefill
# phase. I.e., it means the prefill has been chunked.
prefill_seq_groups: List[SequenceGroup]
prefill_seq_groups: List[ScheduledSequenceGroup]
# The blocks to swap in.
blocks_to_swap_in: List[Tuple[int, int]]
# The blocks to copy.
......@@ -253,7 +254,7 @@ class SchedulerPrefillOutputs:
to be recomputed from scratch.
"""
# Selected sequences for prefill.
seq_groups: List[SequenceGroup]
seq_groups: List[ScheduledSequenceGroup]
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
......@@ -288,7 +289,9 @@ def scheduler_running_outputs_builder():
def scheduled_seq_group_builder():
return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
return ScheduledSequenceGroup(SequenceGroup("", [], -1),
token_chunk_size=0)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
class Scheduler:
......@@ -299,6 +302,7 @@ class Scheduler:
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback: Optional[Callable] = None,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
......@@ -364,10 +368,36 @@ class Scheduler:
self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
scheduled_seq_group_builder)
self._seq_group_metadata_cache: List[PyObjectCache] = []
self._scheduler_running_outputs_cache: List[PyObjectCache] = []
self._scheduled_seq_group_cache: List[PyObjectCache] = []
# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self.output_proc_callback = output_proc_callback
self.use_async_output_proc = self.output_proc_callback is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1
self.cache_id = 0
for i in range(self.num_cache_iters):
self._seq_group_metadata_cache.append(
PyObjectCache(seq_group_metadata_builder))
self._scheduler_running_outputs_cache.append(
PyObjectCache(scheduler_running_outputs_builder))
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))
# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
@property
def lora_enabled(self) -> bool:
......@@ -483,7 +513,7 @@ class Scheduler:
SchedulerRunningOutputs.
"""
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache.get_object()
self._scheduler_running_outputs_cache[self.cache_id].get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
......@@ -510,8 +540,12 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
running_queue = self.running
# Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
......@@ -521,6 +555,28 @@ class Scheduler:
break
running_queue.popleft()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len(
) > self.scheduler_config.max_model_len:
self._async_stopped.append(seq_group)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback is not None
self.output_proc_callback()
self.running = tmp
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
......@@ -556,7 +612,7 @@ class Scheduler:
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache.get_object()
self._scheduled_seq_group_cache[self.cache_id].get_object()
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
......@@ -579,8 +635,8 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
self._scheduler_running_outputs_cache.reset()
self._scheduled_seq_group_cache.reset()
self._scheduler_running_outputs_cache[self.next_cache_id].reset()
self._scheduled_seq_group_cache[self.next_cache_id].reset()
return ret
......@@ -737,7 +793,7 @@ class Scheduler:
SchedulerPrefillOutputs.
"""
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
waiting_queue = self.waiting
......@@ -971,16 +1027,21 @@ class Scheduler:
# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
# By default, vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
self.running.extend([s.seq_group for s in prefills.seq_groups])
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
......@@ -1031,17 +1092,28 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
no_beam_search = seq_group.sampling_params is None or (
seq_group.sampling_params.best_of == 1
and not seq_group.sampling_params.use_beam_search)
return no_beam_search
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter()
scheduler_outputs = self._schedule()
now = time.time()
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
allow_async_output_proc: bool = self.use_async_output_proc
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate(
......@@ -1050,6 +1122,11 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache[
self.cache_id].get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
......@@ -1057,7 +1134,9 @@ class Scheduler:
if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
encoder_seq_data = seq_group.get_encoder_seq().data
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
encoder_seq_data = encoder_seq.data
# Block table for cross-attention
# Also managed at SequenceGroup level
cross_block_table = self.block_manager.get_cross_block_table(
......@@ -1139,13 +1218,20 @@ class Scheduler:
)
seq_group_metadata_list.append(seq_group_metadata)
if allow_async_output_proc:
allow_async_output_proc = self._allow_async_output_proc(
seq_group)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)
self._seq_group_metadata_cache[self.next_cache_id].reset()
scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
......@@ -1158,7 +1244,12 @@ class Scheduler:
else:
seq_group.metrics.scheduler_time = scheduler_time
return seq_group_metadata_list, scheduler_outputs
# Move to next cache (if exists)
self.cache_id = self.next_cache_id
# Return results
return (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)
......@@ -1167,6 +1258,12 @@ class Scheduler:
"""Free a sequence from a block table."""
self.block_manager.free(seq)
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
"""Free finished seqs in a sequence group."""
for seq in seq_group.get_seqs():
if seq.is_finished():
self.free_seq(seq)
def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
......@@ -1179,8 +1276,24 @@ class Scheduler:
self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
# Free finished seqs
self._free_finished_seqs(seq_group)
self.running = remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if self._async_stopped:
for seq_group in self._async_stopped:
self._free_seq_group_cross_attn_blocks(seq_group)
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
self._async_stopped.clear()
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
......@@ -1347,10 +1460,27 @@ class Scheduler:
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
num_new_tokens = min(num_new_tokens,
budget.remaining_token_budget())
remaining_token_budget = budget.remaining_token_budget()
if self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block size
# to avoid partial block matching.
block_size = self.cache_config.block_size
reminder = budget.token_budget % block_size
if reminder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {reminder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
......@@ -4,6 +4,7 @@ import os
import pickle
import subprocess
import sys
import tempfile
from itertools import product
from typing import Dict, List, Optional, Sequence
......@@ -211,20 +212,27 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file.
input_bytes = pickle.dumps((batch_src, batch_tgt))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
result = pickle.loads(returned.stdout)
# use a temporary file to store the result
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps(
(batch_src, batch_tgt, output_file.name))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f:
......@@ -241,6 +249,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read())
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt)
sys.stdout.buffer.write(pickle.dumps(result))
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
import os
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
......@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
if current_platform.is_tpu():
import ray
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from vllm.executor import ray_utils
class TpuCommunicator:
......@@ -24,9 +27,29 @@ class TpuCommunicator:
# be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
num_nodes = len(ray.nodes())
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
num_nodes = num_nodes_in_pg
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
......
......@@ -2,8 +2,8 @@ import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
import torch
......@@ -16,6 +16,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
......@@ -147,6 +148,8 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None
def __post_init__(self):
if self.tokenizer is None:
......@@ -197,10 +200,11 @@ class EngineArgs:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
choices=['auto', 'slow', 'mistral'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
......@@ -317,9 +321,10 @@ class EngineArgs:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
choices=[8, 16, 32],
help='Token block size for contiguous chunks of '
'tokens.')
'tokens. This is ignored on neuron devices and '
'set to max-model-len')
parser.add_argument('--enable-prefix-caching',
action='store_true',
......@@ -732,6 +737,22 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.")
parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
parser.add_argument(
'--override-neuron-config',
type=lambda configs: {
str(key): value
for key, value in
(config.split(':') for config in configs.split(','))
},
default=None,
help="override or set neuron device configuration.")
return parser
@classmethod
......@@ -742,9 +763,9 @@ class EngineArgs:
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_engine_config(self, ) -> EngineConfig:
def create_engine_config(self) -> EngineConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if self.model.endswith(".gguf"):
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"
# bitsandbytes quantization needs a specific model loader
......@@ -791,9 +812,11 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
)
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config)
cache_config = CacheConfig(
block_size=self.block_size,
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
......@@ -910,6 +933,7 @@ class EngineArgs:
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
......
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
import torch
from typing_extensions import assert_never
import vllm.envs as envs
......@@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
PromptComponents, SchedulerOutputState)
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
......@@ -24,12 +22,12 @@ from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
......@@ -257,24 +255,11 @@ class RequestTracker:
return not self._new_requests.empty()
@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pipeline_parallel_size = \
self.parallel_config.pipeline_parallel_size
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(pipeline_parallel_size)
]
async def step_async(
self, virtual_engine: int
......@@ -293,19 +278,37 @@ class _AsyncLLMEngine(LLMEngine):
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
......@@ -333,14 +336,22 @@ class _AsyncLLMEngine(LLMEngine):
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []
# Finish the current step for all the sequence groups.
......@@ -349,77 +360,45 @@ class _AsyncLLMEngine(LLMEngine):
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# Clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
if output and allow_async_output_proc:
assert len(
output
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
else:
request_outputs = []
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
return request_outputs
def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0
def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None
def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None
# Multi-step case
return ctx.request_outputs
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output
return ctx.request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
......@@ -635,6 +614,17 @@ class AsyncLLMEngine:
self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs
if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
......@@ -702,6 +692,11 @@ class AsyncLLMEngine:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
......@@ -873,13 +868,27 @@ class AsyncLLMEngine:
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams.
finished = True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if not self.use_process_request_outputs_callback:
all_finished = self.process_request_outputs(request_outputs)
else:
# For callback case, we only need to detect when all
# requests are finished
all_finished = all(request_output.finished
for request_output in request_outputs)
return not all_finished
def process_request_outputs(self, request_outputs) -> bool:
# Put the outputs into the corresponding streams.
all_finished = True
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
finished = finished and request_output.finished
all_finished = all_finished and request_output.finished
return not finished
return all_finished
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
......
import functools
import time
from collections import deque
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union
import torch
from typing_extensions import TypeVar, assert_never
import vllm.envs as envs
......@@ -29,6 +33,7 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
......@@ -36,8 +41,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
......@@ -77,6 +81,28 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
allow_async_output_proc: bool = False
last_output: Optional[SamplerOutput] = None
@dataclass
class SchedulerContext:
output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
List[SequenceGroupMetadata], SchedulerOutputs,
bool,
bool]] = field(default_factory=lambda: deque())
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = field(
default_factory=lambda: [])
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
......@@ -162,11 +188,15 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
......@@ -176,7 +206,8 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)",
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
......@@ -184,6 +215,7 @@ class LLMEngine:
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
......@@ -205,7 +237,9 @@ class LLMEngine:
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
......@@ -224,6 +258,7 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
......@@ -307,13 +342,36 @@ class LLMEngine:
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.async_callbacks = [
functools.partial(self._process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size)
for _ in range(parallel_config.pipeline_parallel_size)
Scheduler(
scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if model_config.use_async_output_proc else None)
for v_id in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging.
......@@ -421,6 +479,13 @@ class LLMEngine:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
elif distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead.")
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
......@@ -1163,34 +1228,68 @@ class LLMEngine:
return
def _process_model_outputs(
self,
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
def _process_model_outputs(self, ctx: SchedulerContext) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
virtual_engine: The engine id to operate on
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
sampler_output: Used with multi-step execution to provide
sampler_output of each step
is_last_output: Used with multi-step execution to indicate
the last step (of each multi-step group)
Returns RequestOutputs that can be returned to the client.
"""
now = time.time()
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
output, num_seq_groups=len(scheduled_seq_groups))
if len(ctx.output_queue) == 0:
return None
# Get pending async postprocessor
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step) = ctx.output_queue.popleft()
assert outputs is not None
# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if len(outputs) > 1:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list))
else:
outputs_by_sequence_group = outputs
finished_before: List[int] = []
finished_now: List[int] = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if output is not None and len(output) > 0:
for o in output:
if seq_group.is_finished():
finished_before.append(i)
continue
if len(outputs) > 1:
output = outputs_by_sequence_group[i]
else:
output = [outputs_by_sequence_group[0][i]]
if not is_async:
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if outputs:
for o in outputs:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
......@@ -1205,30 +1304,105 @@ class LLMEngine:
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)
if seq_group.is_finished():
finished_now.append(i)
# Free the finished sequence groups.
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Generate outputs for the requests that finished this iteration
for i in finished_now:
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
# Create the outputs.
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
ctx.request_outputs.append(request_output)
# Free currently finished requests
if finished_now:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# For multi-step, do not create outputs each iteration
if not is_last_step:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
return
# Create the outputs
# Note: scheduled_seq_groups and seq_group_metadata_list
# must match with the indices
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
if i in finished_before or i in finished_now:
continue # Avoids double processing
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished()
if self.step_return_finished_only else True):
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
return request_outputs
ctx.request_outputs.append(request_output)
# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx.request_outputs)
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if is_async:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before)
# Tracing
self.do_tracing(scheduler_outputs)
return None
def _advance_to_next_step(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done inside output processor, but it is
required if the worker is to perform async forward pass to next step.
"""
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output, scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
continue
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)")
sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
......@@ -1286,16 +1460,60 @@ class LLMEngine:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
if self.scheduler_config.num_scheduler_steps > 1:
raise NotImplementedError(
"Multiple scheduler steps (multi-step) are only supported "
"through AsyncLLMEngine. ")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0.
virtual_engine = 0
# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
virtual_engine].get_and_reset_finished_requests_ids()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
......@@ -1303,23 +1521,74 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
output = []
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# Add results to the output_queue
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
if output and allow_async_output_proc:
assert len(output) == 1, (
"Async postprocessor expects only a single output set")
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Tracing
self.do_tracing(scheduler_outputs)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
return ctx.request_outputs
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
......@@ -1327,32 +1596,97 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()
return request_outputs
return ctx.request_outputs
def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0
def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs,
allow_async_output_proc: bool) -> None:
co = self.cached_scheduler_outputs[virtual_engine]
co.seq_group_metadata_list = seq_group_metadata_list
co.scheduler_outputs = scheduler_outputs
co.allow_async_output_proc = allow_async_output_proc
co.last_output = None
def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output
def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if not self.log_stats:
raise RuntimeError(
"Stat logging is disabled. Set `disable_log_stats=False` "
"argument to enable.")
if logger_name in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} already exists.")
self.stat_loggers[logger_name] = logger
def remove_logger(self, logger_name: str) -> None:
if not self.log_stats:
raise RuntimeError(
"Stat logging is disabled. Set `disable_log_stats=False` "
"argument to enable.")
if logger_name not in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} does not exist.")
del self.stat_loggers[logger_name]
def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
def do_log_stats(self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
finished_before: Optional[List[int]] = None) -> None:
"""Forced log when no requests active."""
if self.log_stats:
stats = self._get_stats(scheduler_outputs, model_output)
stats = self._get_stats(scheduler_outputs, model_output,
finished_before)
for logger in self.stat_loggers.values():
logger.log(stats)
def _get_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None,
finished_before: Optional[List[int]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.
Args:
......@@ -1417,6 +1751,10 @@ class LLMEngine:
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None:
# For async postprocessor, already finished sequences need to be
# not counted (to avoid double counting)
actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
num_generation_tokens_from_prefill_groups = 0.
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# the len of scheduler_outputs.scheduled_seq_groups is !=
......@@ -1425,6 +1763,11 @@ class LLMEngine:
for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
# Skip double logging when using async output proc
if finished_before and idx in finished_before:
actual_num_batched_tokens -= 1
continue
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
seq_group = scheduled_seq_group.seq_group
......@@ -1459,7 +1802,6 @@ class LLMEngine:
# Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
# Metadata
num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
......@@ -1483,7 +1825,7 @@ class LLMEngine:
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter = (
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
actual_num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)
# Spec decode, if enabled, emits specialized metrics from the worker in
......@@ -1633,7 +1975,26 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
prompt_key = "encoder_prompt_token_ids" \
if self.is_encoder_decoder_model() else "prompt_token_ids"
if not inputs.get(prompt_key):
raise ValueError("Prompt cannot be empty")
\ No newline at end of file
if self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
else:
prompt_ids = inputs.get("prompt_token_ids")
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
if self.model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). "
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
......@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
return SingleStepOutputProcessor(
scheduler_config,
detokenizer,
scheduler,
seq_counter,
stop_checker,
)
return SingleStepOutputProcessor(scheduler_config, detokenizer,
scheduler, seq_counter,
stop_checker)
else:
# Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import (
......@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
......
......@@ -4,6 +4,8 @@ from typing import Callable, List
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.single_step import (
single_step_process_prompt_logprob)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
......@@ -46,9 +48,16 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self._log_prompt_logprob_unsupported_warning_once()
"""Process prompt logprobs associated with each step of a multi-step-
scheduled computation.
Args:
seq_group: the outputs are associated with this :class:`SequenceGroup`
outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.
single_step_process_prompt_logprob(self, seq_group, output)
@staticmethod
@functools.lru_cache()
......@@ -57,37 +66,73 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
def process_outputs(self,
sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization,
including freeing finished sequences. It also handles cases where there
are tokens emitted after the EOS token.
This applies logic like stop condition checking and detokenization.
It also handles cases where there are tokens emitted after
the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINSIHED_ABORTED
# if a client disconnects from the api server.
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
if seqs is None:
seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED)
assert seqs, "expected running sequences"
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [output.samples[0] for output in outputs]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
if is_async:
# Async case: We process tokens one by one. Here, we know the token
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params)
else:
# Standard multi-step case
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples = [output.samples[0] for output in outputs]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params,
)
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
......@@ -125,20 +170,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs=output_logprob,
)
new_char_count = 0
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
self._process_decode_and_stop(seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params,
)
if seq.is_finished():
break
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
......@@ -15,6 +15,44 @@ from vllm.utils import Counter
logger = init_logger(__name__)
def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: SequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
prompt_logprobs = output.prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
assert hasattr(sg_output_proc, 'detokenizer')
if (seq_group.sampling_params.detokenize
and sg_output_proc.detokenizer):
sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
seq_group.prompt_logprobs.extend(prompt_logprobs)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
......@@ -29,14 +67,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time.
"""
def __init__(
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
stop_checker: StopChecker,
):
def __init__(self, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, scheduler: List[Scheduler],
seq_counter: Counter, stop_checker: StopChecker):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
......@@ -44,50 +77,49 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
return self._process_sequence_group_outputs(sequence_group, outputs[0],
is_async)
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Process prompt logprobs associated with one step of a single-step-
scheduled computation.
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0]
prompt_logprobs = output.prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
seq_group.prompt_logprobs.extend(prompt_logprobs)
single_step_process_prompt_logprob(self, seq_group, output)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.n == 1 and not sampling_params.use_beam_search:
if sampling_params.best_of == 1 and not sampling_params.use_beam_search:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
......@@ -104,6 +136,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler.free_seq(seq)
return
# TODO: Add support for async for beam search
assert not is_async
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
......
......@@ -2,7 +2,8 @@ from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import PoolerOutput, SequenceGroupOutput
def create_output_by_sequence_group(
......
......@@ -5,11 +5,11 @@ from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
......
import asyncio
import codecs
from dataclasses import dataclass
from functools import lru_cache
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union)
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import ChatCompletionContentPartTextParam
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict, TypeAdapter
from pydantic import ConfigDict
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
......@@ -51,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
......@@ -69,21 +78,217 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
same role.
"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
role: str
content: str
class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Optional[str]
"""The contents of the message"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
name: Optional[str]
"""The name of the function to call"""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ModalityStr = Literal["image", "audio"]
_T = TypeVar("_T")
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
super().__init__()
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = []
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index)
def _placeholder_str(self, modality: ModalityStr,
current_count: int) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
hf_config = self._model_config.hf_config
model_type = hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._items.append(item)
return self._placeholder_str(modality, current_count)
@abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser":
raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items:
items = await asyncio.gather(*self._items)
return self._combine(items)
return None
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
class BaseMultiModalContentParser(ABC):
def __init__(self) -> None:
super().__init__()
# multimodal placeholder_string : count
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
def _add_placeholder(self, placeholder: Optional[str]):
if placeholder:
self._placeholder_counts[placeholder] += 1
def mm_placeholder_counts(self) -> Dict[str, int]:
return dict(self._placeholder_counts)
@abstractmethod
def parse_image(self, image_url: str) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None:
super().__init__()
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]]
self._tracker = tracker
def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__()
self._tracker = tracker
def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
def load_chat_template(
......@@ -112,152 +317,150 @@ def load_chat_template(
return resolved_chat_template
@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders: List[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
if placeholder_counts[placeholder] < 0:
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
# NOTE: For now we assume all model architectures use the same
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
mm_parser = mm_tracker.create_parser()
for part in parts:
part_type = part["type"]
if part_type == "text":
text = _TextParser.validate_python(part)["text"]
text = _TextParser(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
image_url = _ImageParser.validate_python(part)["image_url"]
image_url = _ImageParser(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
mm_parser.parse_image(image_url["url"])
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
if mm_futures:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)
return [ConversationMessage(role=role, content=text_prompt)]
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[], mm_futures=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
return _parse_chat_message_content_parts(
result = _parse_chat_message_content_parts(
role,
content, # type: ignore
model_config,
tokenizer,
mm_tracker,
)
for result_msg in result:
if role == 'assistant':
parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg:
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
if "tool_call_id" in parsed_msg:
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
return result
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
parse_result = _parse_chat_message_content(msg, model_config,
tokenizer)
sub_messages = _parse_chat_message_content(msg, mm_tracker)
conversation.extend(parse_result.messages)
mm_futures.extend(parse_result.mm_futures)
conversation.extend(sub_messages)
return conversation, mm_futures
return conversation, mm_tracker.all_mm_data()
def parse_chat_messages_futures(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
conversation.extend(sub_messages)
return conversation, mm_tracker.all_mm_data()
def apply_chat_template(
......@@ -267,19 +470,31 @@ def apply_chat_template(
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
) -> Union[str, List[int]]:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in conversation:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for i in range(len(message["tool_calls"])):
args: str = message["tool_calls"][i]["function"]["arguments"]
parsed_args: Dict = json.loads(args)
message["tool_calls"][i]["function"]["arguments"] = parsed_args
prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)
return prompt
......@@ -23,7 +23,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
......@@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
**kwargs,
) -> None:
'''
......@@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
......@@ -356,15 +358,18 @@ class LLM:
add_generation_prompt: bool = True,
) -> List[RequestOutput]:
"""
Generates responses for chat messages.
Generate responses for a chat conversation.
Converts the messages to prompts using the tokenizer and calls
the :meth:`generate` method to generate the responses.
The chat conversation is converted into a text prompt using the
tokenizer and calls the :meth:`generate` method to generate the
responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of messages to generate responses for. Each
message is a list of dictionaries with 'role' and 'content'
keys.
messages: A single conversation represented as a list of messages.
Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
......@@ -385,18 +390,28 @@ class LLM:
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)
conversation, mm_data = parse_chat_messages(messages, model_config,
tokenizer)
prompts = apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversations,
conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt)
add_generation_prompt=add_generation_prompt,
)
inputs: PromptInputs
if is_list_of(prompt, int):
inputs = TokensPrompt(prompt_token_ids=prompt)
else:
inputs = TextPrompt(prompt=prompt)
if mm_data is not None:
inputs["multi_modal_data"] = mm_data
return self.generate(
prompts,
sampling_params,
inputs,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
......@@ -603,7 +618,6 @@ class LLM:
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
......@@ -678,6 +692,10 @@ class LLM:
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
......@@ -700,6 +718,10 @@ class LLM:
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
......
......@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: str) -> bool:
quantization: Optional[str]) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
......@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
......@@ -112,14 +105,37 @@ async def build_async_engine_client(
# Backend itself still global for the silly lil' health handler
global async_engine_client
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
async_engine_client = engine # type: ignore[assignment]
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code,
args.quantization)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization)
or disable_frontend_multiprocessing):
engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client
try:
yield engine_client
finally:
engine_client.shutdown_background_loop()
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
......@@ -148,7 +164,6 @@ async def build_async_engine_client(
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn")
......@@ -174,7 +189,7 @@ async def build_async_engine_client(
yield None
return
yield async_engine_client
yield rpc_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
......@@ -218,7 +233,7 @@ def mount_metrics(app: FastAPI):
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
......@@ -268,11 +283,14 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
......@@ -407,7 +425,8 @@ async def init_app(
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
model_config,
......
......@@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
default=False,
help=
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes"],
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
......
......@@ -5,8 +5,9 @@ from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
......@@ -35,6 +36,26 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
tool_calls: Optional[List[dict]]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
......@@ -85,9 +106,19 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0
class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: Optional[str] = None
# schema is the field in openai but that causes conflicts with pydantic so
# instead use json_schema with an alias
json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
strict: Optional[bool] = None
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
# type must be "json_schema", "json_object" or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StreamOptions(OpenAIBaseModel):
......@@ -135,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
# NOTE this will be ignored by VLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
......@@ -318,6 +352,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
if isinstance(data, ValueError):
raise data
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
......@@ -329,21 +366,61 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
def check_tool_usage(cls, data):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if "tool_choice" not in data and "tools" in data:
data["tool_choice"] = "auto"
# if "tool_choice" is specified -- validation
if "tool_choice" in data:
# ensure that if "tool choice" is specified, tools are present
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
"`tool_choice` must either be a named tool or \"auto\". "
"`tool_choice=\"none\" is not supported.")
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
if isinstance(data["tool_choice"], dict):
valid_tool = False
specified_function = data["tool_choice"]["function"]
if not specified_function:
raise ValueError(
"Incorrectly formatted `tool_choice`. Should be like "
"`{\"type\": \"function\","
" \"function\": {\"name\": \"my_function\"}}`")
specified_function_name = specified_function["name"]
if not specified_function_name:
raise ValueError(
"Incorrectly formatted `tool_choice`. Should be like "
"`{\"type\": \"function\", "
"\"function\": {\"name\": \"my_function\"}}`")
for tool in data["tools"]:
if tool["function"]["name"] == specified_function_name:
valid_tool = True
break
if not valid_tool:
raise ValueError(
"The tool specified in `tool_choice` does not match any"
" of the specified `tools`")
return data
......@@ -403,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel):
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
description="If specified, the output will follow the JSON schema.",
)
guided_regex: Optional[str] = Field(
default=None,
......@@ -623,9 +700,41 @@ class ToolCall(OpenAIBaseModel):
function: FunctionCall
class DeltaFunctionCall(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int
function: Optional[DeltaFunctionCall] = None
# the initial delta that gets sent once a new tool call is started;
class InitialDeltaToolCall(DeltaToolCall):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int
class ExtractedToolCallInformation(BaseModel):
# indicate if tools were called
tools_called: bool
# extracted tool calls
tool_calls: List[ToolCall]
# content - per OpenAI spec, content AND tool calls can be returned rarely
# But some models will do this intentionally
content: Optional[str] = None
class ChatMessage(OpenAIBaseModel):
role: str
content: str
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
......@@ -647,7 +756,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None
# per OpenAI spec this is the default
finish_reason: Optional[str] = "stop"
# not part of the OpenAI spec but included in vLLM for legacy reasons
stop_reason: Optional[Union[int, str]] = None
......@@ -664,7 +775,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
tool_calls: List[DeltaToolCall] = Field(default_factory=list)
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
......
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
......@@ -101,6 +104,7 @@ class AsyncEngineRPCClient:
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
assert isinstance(socket_limit, int)
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
......@@ -114,18 +118,21 @@ class AsyncEngineRPCClient:
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
......@@ -135,20 +142,11 @@ class AsyncEngineRPCClient:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from, socket_to):
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events = await poller.poll()
events = dict(events)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])
frames = await socket_from.recv_multipart(copy=False)
await socket_to.send_multipart(frames, copy=False)
async def setup(self):
"""Setup the client before it starts sending server requests."""
......@@ -179,7 +177,7 @@ class AsyncEngineRPCClient:
self.context.destroy()
@contextmanager
def to_proxy_socket(self):
def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
......@@ -207,7 +205,8 @@ class AsyncEngineRPCClient:
with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)
# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
......@@ -215,7 +214,9 @@ class AsyncEngineRPCClient:
f"{self._data_timeout} ms")
# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
data = pickle.loads(frame.buffer)
if isinstance(data, Exception):
# Re-raise exceptions returned by the server
......@@ -233,23 +234,23 @@ class AsyncEngineRPCClient:
return data
async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[zmq.asyncio.Socket] = None):
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[Socket] = None):
"""Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE):
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
await socket.send_multipart([cloudpickle.dumps(request)])
await socket.send_multipart((cloudpickle.dumps(request), ))
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
return cloudpickle.loads(await socket.recv())
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
return pickle.loads(frame.buffer)
# Make a new socket connection.
if socket is None:
......@@ -385,21 +386,20 @@ class AsyncEngineRPCClient:
try:
with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
await socket.send_multipart((cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)), ))
# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
message = await socket.recv(copy=False)
assert isinstance(message, Frame)
request_output = pickle.loads(message.buffer)
if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
......@@ -423,9 +423,7 @@ class AsyncEngineRPCClient:
if not finished and not self._errored:
await self.abort(request_id)
async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
async def check_health(self, socket: Optional[Socket] = None) -> None:
"""Raise if unhealthy"""
await self._send_one_way_rpc_request(
......@@ -450,4 +448,4 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")
\ No newline at end of file
error_message="RPCRequest STOP_PROFILE failed.")
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union
......@@ -7,6 +8,8 @@ import uvloop
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
......@@ -35,7 +38,7 @@ class AsyncEngineRPCServer:
self.context = zmq.asyncio.Context()
# Init socket.
self.socket = self.context.socket(zmq.constants.DEALER)
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)
......@@ -63,30 +66,31 @@ class AsyncEngineRPCServer:
else:
raise ValueError("Unknown Config Request: %s", request)
await self.socket.send_multipart(
[identity, cloudpickle.dumps(config)])
await self.socket.send_multipart((identity, pickle.dumps(config)),
copy=False)
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
(identity, pickle.dumps(tracing_flag)))
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
......@@ -96,7 +100,7 @@ class AsyncEngineRPCServer:
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
await self.socket.send_multipart((identity, pickle.dumps(result)))
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
......@@ -110,45 +114,47 @@ class AsyncEngineRPCServer:
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
(identity, pickle.dumps(request_output)), copy=False)
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")
await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))
async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")
await self.socket.send_multipart([
await self.socket.send_multipart((
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
message: Frame) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
request = cloudpickle.loads(message.buffer)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
......@@ -189,7 +195,7 @@ class AsyncEngineRPCServer:
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
identity, message = await self.socket.recv_multipart(copy=False)
# Process the request async.
task = asyncio.create_task(
......
......@@ -3,10 +3,11 @@ from io import StringIO
from typing import Awaitable, Callable, List
import aiohttp
from prometheus_client import start_http_server
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
......@@ -16,13 +17,10 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def parse_args():
parser = FlexibleArgumentParser(
......@@ -59,6 +57,24 @@ def parse_args():
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument("--enable-metrics",
action="store_true",
help="Enable Prometheus metrics")
parser.add_argument(
"--url",
type=str,
default="0.0.0.0",
help="URL to the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
return parser.parse_args()
......@@ -184,7 +200,15 @@ async def main(args):
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
logger.info("args: %s", args)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(main(args))
import asyncio
import json
import time
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Union
......@@ -11,22 +13,25 @@ from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
load_chat_template,
parse_chat_messages)
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
FunctionCall, ToolCall, UsageInfo)
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
MistralToolParser,
ToolParser)
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
......@@ -38,19 +43,19 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
def __init__(self,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None):
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
......@@ -60,10 +65,27 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role
# If this is None we use the tokenizer's default chat template
self.use_tool_use_model_template = False
self.chat_template = load_chat_template(chat_template)
# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
if self.enable_auto_tools:
logger.info(
"\"auto\" tool choice has been enabled please note that while"
" the parallel_tool_calls client option is preset for "
"compatibility reasons, it will be ignored.")
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
else:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
async def create_chat_completion(
self,
request: ChatCompletionRequest,
......@@ -76,11 +98,10 @@ class OpenAIServingChat(OpenAIServing):
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
try:
......@@ -93,7 +114,7 @@ class OpenAIServingChat(OpenAIServing):
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
conversation, mm_futures = parse_chat_messages(
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
tool_dicts = None if request.tools is None else [
......@@ -113,30 +134,47 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
mm_data: Optional[MultiModalDataDict] = None
try:
if len(mm_futures):
# since we support only single mm data currently
assert len(
mm_futures
) == 1, "Multiple 'image_url' input is currently not supported."
mm_data = await mm_futures[0]
mm_data = await mm_data_future
except Exception as e:
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
# "auto" tools requires --enable-auto-tool-choice
# and --tool-call-parser
if request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None):
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{random_uuid()}"
try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
assert prompt_inputs is not None
sampling_params = request.to_sampling_params(
tokenizer,
......@@ -184,6 +222,7 @@ class OpenAIServingChat(OpenAIServing):
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
......@@ -216,6 +255,9 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
try:
async for res in result_generator:
# We need to do it here, because if there are exceptions in
......@@ -225,6 +267,9 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
......@@ -237,14 +282,18 @@ class OpenAIServingChat(OpenAIServing):
created=created_time,
choices=[choice_data],
model=model_name)
# if usage should be included
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
chunk.usage = None
......@@ -254,7 +303,7 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content = ""
last_msg_content: Optional[str] = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
......@@ -295,6 +344,7 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
......@@ -317,20 +367,50 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
delta_message: Optional[DeltaMessage] = None
if request.tool_choice and type(
request.tool_choice
) is ChatCompletionNamedToolChoiceParam:
# handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is
ChatCompletionNamedToolChoiceParam):
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text))
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request)
and tool_parser):
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i],
current_text=output.text,
delta_text=delta_text,
previous_token_ids= \
output.token_ids[
:-1 * len(delta_token_ids)
],
current_token_ids=output.token_ids,
delta_token_ids=delta_token_ids
)
)
# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
# wasn't ready to send a token, then
# get the next token without streaming a chunk
if delta_message is None:
continue
if output.finish_reason is None:
# Send token-by-token response for each request.n
......@@ -345,6 +425,8 @@ class OpenAIServingChat(OpenAIServing):
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
......@@ -362,14 +444,55 @@ class OpenAIServingChat(OpenAIServing):
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# if the model is finished generating
else:
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
if tool_parser:
index = len(
tool_parser.prev_tool_call_arr) - 1 if len(
tool_parser.prev_tool_call_arr) > 0 else 0
else:
index = 0
if self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output) and tool_parser:
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))
# get what we've streamed so for for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
# check to see if there's anything left to stream
remaining_call = expected_call.replace(
actual_call, "", 1)
# set that as a delta message
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(index=index,
function=DeltaFunctionCall(
arguments=remaining_call).
model_dump(exclude_none=True))
])
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason,
finish_reason=output.finish_reason
if not (tool_parser
and len(tool_parser.prev_tool_call_arr))
else "tool_calls",
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
......@@ -395,6 +518,8 @@ class OpenAIServingChat(OpenAIServing):
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
......@@ -416,6 +541,7 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
......@@ -460,8 +586,21 @@ class OpenAIServingChat(OpenAIServing):
else:
logprobs = None
if request.tool_choice and type(
# by default, tools are not used.
tools_called = False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if not (self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role, content=output.text)
# if the request uses tools and specified a tool choice
elif request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
message = ChatMessage(
role=role,
content="",
......@@ -470,14 +609,47 @@ class OpenAIServingChat(OpenAIServing):
name=request.tool_choice.function.name,
arguments=output.text))
])
tools_called = True
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)
# handle when there are tools and tool choice is auto
elif request.tools and (
request.tool_choice == "auto"
or request.tool_choice is None) and self.enable_auto_tools \
and self.tool_parser:
tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(output.text)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
content=tool_call_info.content,
tool_calls=tool_call_info.tool_calls)
else:
# FOR NOW make it a chat message; we will have to detect
# the type to make it later.
message = ChatMessage(role=role, content=output.text)
# undetermined case that is still important to handle
else:
logger.error(
"Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat "
"completion.")
message = ChatMessage(role=role, content=output.text)
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=message,
logprobs=logprobs,
finish_reason=output.finish_reason,
finish_reason="tool_calls" if tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)
choices.append(choice_data)
......@@ -485,10 +657,11 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"]
last_msg_content = conversation[-1]["content"] or ""
for choice in choices:
full_message = last_msg_content + choice.message.content
full_message = last_msg_content + (choice.message.content
or "")
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
......@@ -571,3 +744,38 @@ class OpenAIServingChat(OpenAIServing):
))
return ChatCompletionLogProbs(content=logprobs_content)
def _should_stream_with_auto_tool_parsing(self,
request: ChatCompletionRequest):
"""
Utility function to check if streamed tokens should go through the tool
call parser that was configured.
We only want to do this IF user-provided tools are set, a tool parser
is configured, "auto" tool choice is enabled, and the request's tool
choice field indicates that "auto" tool choice should be used.
"""
return (request.tools and self.tool_parser and self.enable_auto_tools
and request.tool_choice in ['auto', None])
def _should_check_for_unstreamed_tool_arg_tokens(
self,
delta_message: Optional[DeltaMessage],
output: CompletionOutput,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
This is only applicable when auto tool parsing is enabled, the delta
is a tool call with arguments.
"""
# yapf: disable
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment