Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
...@@ -2,8 +2,8 @@ import asyncio ...@@ -2,8 +2,8 @@ import asyncio
import time import time
import weakref import weakref
from functools import partial from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
Mapping, Optional, Set, Tuple, Type, Union) List, Mapping, Optional, Set, Tuple, Type, Union, overload)
from weakref import ReferenceType from weakref import ReferenceType
import vllm.envs as envs import vllm.envs as envs
...@@ -14,12 +14,15 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -14,12 +14,15 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -28,7 +31,7 @@ from vllm.sampling_params import SamplingParams ...@@ -28,7 +31,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind from vllm.utils import deprecate_kwargs, weak_bind
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
...@@ -363,11 +366,18 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -363,11 +366,18 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[ self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState() virtual_engine] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
ctx.append_output(outputs=outputs, ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs, scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc, is_async=allow_async_output_proc,
is_last_step=True) is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc: if outputs and allow_async_output_proc:
assert len( assert len(
...@@ -402,31 +412,86 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -402,31 +412,86 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
@overload # DEPRECATED
async def add_request_async(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@overload
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request_async(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
"""Async version of :meth:`add_request`.""" """Async version of :meth:`add_request`."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
preprocessed_inputs = await self.input_preprocessor.preprocess_async( preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs) processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
...@@ -435,6 +500,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -435,6 +500,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority,
) )
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
...@@ -443,7 +509,37 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -443,7 +509,37 @@ class _AsyncLLMEngine(LLMEngine):
self.model_executor.check_health() self.model_executor.check_health()
class AsyncLLMEngine: async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
return sampling_params
logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)
# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None
return sampling_params
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for :class:`LLMEngine`. """An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to make it This class is used to wrap the :class:`LLMEngine` class to make it
...@@ -774,16 +870,58 @@ class AsyncLLMEngine: ...@@ -774,16 +870,58 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way # This method does not need to be async, but kept that way
# for backwards compatibility. # for backwards compatibility.
async def add_request( @overload # DEPRECATED
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@overload
def add_request(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
self.start_background_loop() self.start_background_loop()
...@@ -794,26 +932,34 @@ class AsyncLLMEngine: ...@@ -794,26 +932,34 @@ class AsyncLLMEngine:
"error that caused the background loop to stop " "error that caused the background loop to stop "
"(AsyncEngineDeadError).") "(AsyncEngineDeadError).")
if (priority != 0
and not self.engine.scheduler_config.policy == "priority"):
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
verbose=self.log_requests, verbose=self.log_requests,
inputs=inputs, prompt=prompt,
params=params, params=params,
arrival_time=arrival_time or time.time(), arrival_time=arrival_time or time.time(),
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return stream.generator() return stream.generator()
async def generate( async def generate(
self, self,
inputs: PromptInputs, prompt: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -822,8 +968,7 @@ class AsyncLLMEngine: ...@@ -822,8 +968,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
...@@ -831,6 +976,8 @@ class AsyncLLMEngine: ...@@ -831,6 +976,8 @@ class AsyncLLMEngine:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use prompt_adapter_request: Prompt Adapter request to use
for generation, if any. for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine The output `RequestOutput` objects from the LLMEngine
...@@ -881,21 +1028,23 @@ class AsyncLLMEngine: ...@@ -881,21 +1028,23 @@ class AsyncLLMEngine:
""" """
async for output in await self.add_request( async for output in await self.add_request(
request_id, request_id,
inputs, prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority,
): ):
yield LLMEngine.validate_output(output, RequestOutput) yield LLMEngine.validate_output(output, RequestOutput)
async def encode( async def encode(
self, self,
inputs: PromptInputs, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
...@@ -904,13 +1053,14 @@ class AsyncLLMEngine: ...@@ -904,13 +1053,14 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
...@@ -959,10 +1109,11 @@ class AsyncLLMEngine: ...@@ -959,10 +1109,11 @@ class AsyncLLMEngine:
""" """
async for output in await self.add_request( async for output in await self.add_request(
request_id, request_id,
inputs, prompt,
pooling_params, pooling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority,
): ):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput) yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
......
...@@ -6,7 +6,7 @@ from functools import partial ...@@ -6,7 +6,7 @@ from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional) Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, Union from typing import Set, Type, Union, cast, overload
import torch import torch
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -25,14 +25,17 @@ from vllm.engine.output_processor.interfaces import ( ...@@ -25,14 +25,17 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
InputRegistry, LLMInputs, PromptInputs) EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
...@@ -41,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -41,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
...@@ -51,7 +54,7 @@ from vllm.transformers_utils.tokenizer_group import ( ...@@ -51,7 +54,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device, weak_bind from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -90,6 +93,12 @@ class OutputData(NamedTuple): ...@@ -90,6 +93,12 @@ class OutputData(NamedTuple):
scheduler_outputs: SchedulerOutputs scheduler_outputs: SchedulerOutputs
is_async: bool is_async: bool
is_last_step: bool is_last_step: bool
# Indicates if this output is from the first step of the
# multi-step. When multi-step is disabled, this is always
# set to True.
# is_first_step_output is invalid when `outputs` has
# outputs from multiple steps.
is_first_step_output: Optional[bool]
skip: List[int] skip: List[int]
...@@ -108,13 +117,15 @@ class SchedulerContext: ...@@ -108,13 +117,15 @@ class SchedulerContext:
def append_output(self, outputs: List[SamplerOutput], def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool, scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool): is_last_step: bool,
is_first_step_output: Optional[bool]):
self.output_queue.append( self.output_queue.append(
OutputData(outputs=outputs, OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs, scheduler_outputs=scheduler_outputs,
is_async=is_async, is_async=is_async,
is_last_step=is_last_step, is_last_step=is_last_step,
is_first_step_output=is_first_step_output,
skip=[])) skip=[]))
...@@ -177,7 +188,7 @@ class LLMEngine: ...@@ -177,7 +188,7 @@ class LLMEngine:
raise TypeError(f"Expected output of type {output_type}, " raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}") f"but found type {type(output)}")
return output return cast(_O, output)
@classmethod @classmethod
def validate_outputs( def validate_outputs(
...@@ -236,10 +247,11 @@ class LLMEngine: ...@@ -236,10 +247,11 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, " "enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, " "decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, " "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"enable_prefix_caching=%s, use_async_output_proc=%s, " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)", "use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
VLLM_VERSION, VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
...@@ -268,8 +280,8 @@ class LLMEngine: ...@@ -268,8 +280,8 @@ class LLMEngine:
observability_config, observability_config,
model_config.seed, model_config.seed,
model_config.served_model_name, model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps, scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs, scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
model_config.use_async_output_proc, model_config.use_async_output_proc,
...@@ -277,9 +289,6 @@ class LLMEngine: ...@@ -277,9 +289,6 @@ class LLMEngine:
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
...@@ -625,7 +634,7 @@ class LLMEngine: ...@@ -625,7 +634,7 @@ class LLMEngine:
def _add_processed_request( def _add_processed_request(
self, self,
request_id: str, request_id: str,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
...@@ -689,16 +698,51 @@ class LLMEngine: ...@@ -689,16 +698,51 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
@overload # DEPRECATED
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@overload
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -708,8 +752,7 @@ class LLMEngine: ...@@ -708,8 +752,7 @@ class LLMEngine:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
params: Parameters for sampling or pooling. params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation. :class:`~vllm.SamplingParams` for text generation.
...@@ -723,7 +766,7 @@ class LLMEngine: ...@@ -723,7 +766,7 @@ class LLMEngine:
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None. - Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects. - Create `n` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object - Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`. from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler. - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
...@@ -744,11 +787,15 @@ class LLMEngine: ...@@ -744,11 +787,15 @@ class LLMEngine:
>>> # continue the request processing >>> # continue the request processing
>>> ... >>> ...
""" """
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if priority > 0 and not self.scheduler_config.policy == "priority": if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but " raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.") "Priority scheduling is not enabled.")
...@@ -756,13 +803,20 @@ class LLMEngine: ...@@ -756,13 +803,20 @@ class LLMEngine:
arrival_time = time.time() arrival_time = time.time()
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs) processed_inputs = self.input_processor(preprocessed_inputs)
# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
...@@ -795,6 +849,9 @@ class LLMEngine: ...@@ -795,6 +849,9 @@ class LLMEngine:
raise ValueError(f"Cannot request more than " raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.") f"{max_logprobs} logprobs.")
sampling_params = self._build_logits_processors(
sampling_params, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
...@@ -911,6 +968,45 @@ class LLMEngine: ...@@ -911,6 +968,45 @@ class LLMEngine:
return return
def _update_num_computed_tokens_for_multi_step_prefill(
self, seq_group: SequenceGroup,
seq_group_meta: SequenceGroupMetadata,
is_first_step_output: Optional[bool]):
"""
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.
seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
in multi-step are submitted in a single burst.
"""
assert self.scheduler_config.is_multi_step
if not seq_group_meta.is_prompt:
# num_computed_token updates for multi-step decodes happen after
# the tokens are appended to the sequence.
return
do_update: bool = False
if self.scheduler_config.chunked_prefill_enabled:
# In multi-step + chunked-prefill case, the prompt sequences
# that are scheduled are fully processed in the first step.
do_update = is_first_step_output is None or is_first_step_output
else:
# Normal multi-step decoding case. In this case prompt-sequences
# are actually single-stepped. Always update in this case.
assert seq_group.state.num_steps == 1
do_update = True
if do_update:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)
def _process_model_outputs(self, def _process_model_outputs(self,
ctx: SchedulerContext, ctx: SchedulerContext,
request_id: Optional[str] = None) -> None: request_id: Optional[str] = None) -> None:
...@@ -919,8 +1015,8 @@ class LLMEngine: ...@@ -919,8 +1015,8 @@ class LLMEngine:
ctx: The virtual engine context to work on ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed request_id: If provided, then only this request is going to be processed
""" """
now = time.time() now = time.time()
if len(ctx.output_queue) == 0: if len(ctx.output_queue) == 0:
...@@ -931,20 +1027,28 @@ class LLMEngine: ...@@ -931,20 +1027,28 @@ class LLMEngine:
# When we process only one request, no pop is required # When we process only one request, no pop is required
# (since later we will process all of the rest) # (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async, (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue[0] is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
else: else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async, (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue.popleft() is_last_step, is_first_step_output,
skip) = ctx.output_queue.popleft()
# Sanity check # Sanity check
assert len(seq_group_metadata_list) == len( assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
# Organize outputs by [step][sequence group] instead of has_multiple_outputs: bool = len(outputs) > 1
# [sequence group][step]. outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if len(outputs) > 1: if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
outputs_by_sequence_group = create_output_by_sequence_group( outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list)) outputs, num_seq_groups=len(seq_group_metadata_list))
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
else: else:
outputs_by_sequence_group = outputs outputs_by_sequence_group = outputs
...@@ -974,20 +1078,26 @@ class LLMEngine: ...@@ -974,20 +1078,26 @@ class LLMEngine:
seq_group_meta = seq_group_metadata_list[i] seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group seq_group: SequenceGroup = scheduled_seq_group.seq_group
if seq_group.is_finished(): if seq_group.is_finished():
finished_before.append(i) finished_before.append(i)
continue continue
if len(outputs) > 1: output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i] output = outputs_by_sequence_group[i]
else: else:
output = [outputs_by_sequence_group[0][i]] output = [outputs_by_sequence_group[0][i]]
if not is_async: if not is_async:
seq_group.update_num_computed_tokens( if self.scheduler_config.is_multi_step:
scheduled_seq_group.token_chunk_size) # Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size or 0)
if outputs: if outputs:
for o in outputs: for o in outputs:
...@@ -995,13 +1105,13 @@ class LLMEngine: ...@@ -995,13 +1105,13 @@ class LLMEngine:
and seq_group.metrics is not None): and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None: if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += ( seq_group.metrics.model_forward_time += (
o.model_forward_time) o.model_forward_time or 0)
else: else:
seq_group.metrics.model_forward_time = ( seq_group.metrics.model_forward_time = (
o.model_forward_time) o.model_forward_time)
if seq_group.metrics.model_execute_time is not None: if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += ( seq_group.metrics.model_execute_time += (
o.model_execute_time) o.model_execute_time or 0)
else: else:
seq_group.metrics.model_execute_time = ( seq_group.metrics.model_execute_time = (
o.model_execute_time) o.model_execute_time)
...@@ -1121,19 +1231,34 @@ class LLMEngine: ...@@ -1121,19 +1231,34 @@ class LLMEngine:
if seq_group.is_finished(): if seq_group.is_finished():
continue continue
seq_group.update_num_computed_tokens( if self.scheduler_config.is_multi_step:
seq_group_metadata.token_chunk_size) # Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
token_chunk_size = (seq_group_metadata.token_chunk_size
if seq_group_metadata.token_chunk_size
is not None else 0)
seq_group.update_num_computed_tokens(token_chunk_size)
if seq_group_metadata.do_sample: if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, ( assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample" "Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no " " (i.e sampling_params.n == 1)")
"sampling_params.best_of > 1)")
sample = sequence_group_outputs.samples[0] sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1 assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0] seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens(
) == 0
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_prefill_append:
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
...@@ -1286,12 +1411,19 @@ class LLMEngine: ...@@ -1286,12 +1411,19 @@ class LLMEngine:
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState() self.cached_scheduler_outputs[0] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue # Add results to the output_queue
ctx.append_output(outputs=outputs, ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs, scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc, is_async=allow_async_output_proc,
is_last_step=True) is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc: if outputs and allow_async_output_proc:
assert len(outputs) == 1, ( assert len(outputs) == 1, (
...@@ -1482,7 +1614,6 @@ class LLMEngine: ...@@ -1482,7 +1614,6 @@ class LLMEngine:
# Metadata # Metadata
num_prompt_tokens_requests: List[int] = [] num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = [] num_generation_tokens_requests: List[int] = []
best_of_requests: List[int] = []
n_requests: List[int] = [] n_requests: List[int] = []
finished_reason_requests: List[str] = [] finished_reason_requests: List[str] = []
...@@ -1553,8 +1684,6 @@ class LLMEngine: ...@@ -1553,8 +1684,6 @@ class LLMEngine:
for seq in seq_group.get_finished_seqs() for seq in seq_group.get_finished_seqs()
]) ])
if seq_group.sampling_params is not None: if seq_group.sampling_params is not None:
best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n) n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([ finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status) SequenceStatus.get_finished_reason(seq.status)
...@@ -1607,7 +1736,6 @@ class LLMEngine: ...@@ -1607,7 +1736,6 @@ class LLMEngine:
# Metadata # Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests, num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests, num_generation_tokens_requests=num_generation_tokens_requests,
best_of_requests=best_of_requests,
n_requests=n_requests, n_requests=n_requests,
finished_reason_requests=finished_reason_requests, finished_reason_requests=finished_reason_requests,
) )
...@@ -1694,8 +1822,6 @@ class LLMEngine: ...@@ -1694,8 +1822,6 @@ class LLMEngine:
seq_group.sampling_params.top_p) seq_group.sampling_params.top_p)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS, seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
seq_group.sampling_params.max_tokens) seq_group.sampling_params.max_tokens)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
seq_group.sampling_params.best_of)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N, seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
seq_group.sampling_params.n) seq_group.sampling_params.n)
seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
...@@ -1732,8 +1858,8 @@ class LLMEngine: ...@@ -1732,8 +1858,8 @@ class LLMEngine:
def is_embedding_model(self): def is_embedding_model(self):
return self.model_config.is_embedding_model return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[LLMInputs, def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderLLMInputs]): EncoderDecoderInputs]):
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # restricts the decoder prompt length
...@@ -1760,4 +1886,52 @@ class LLMEngine: ...@@ -1760,4 +1886,52 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can # TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them # check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens # max_batch_len = self.scheduler_config.max_num_batched_tokens
\ No newline at end of file
def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""
logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:
logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)
tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
logits_processors.append(processor)
# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
processors = get_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)
# Unset so these don't get passed down to the model
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None
if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)
return sampling_params
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Type, Union, cast
import numpy as np import numpy as np
import prometheus_client import prometheus_client
...@@ -134,12 +134,6 @@ class Metrics: ...@@ -134,12 +134,6 @@ class Metrics:
labelnames=labelnames, labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
) )
self.histogram_best_of_request = self._histogram_cls(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self._histogram_cls( self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n", name="vllm:request_params_n",
documentation="Histogram of the n request parameter.", documentation="Histogram of the n request parameter.",
...@@ -255,10 +249,11 @@ class _RayHistogramWrapper: ...@@ -255,10 +249,11 @@ class _RayHistogramWrapper:
labelnames: Optional[List[str]] = None, labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None): buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
self._histogram = ray_metrics.Histogram(name=name, self._histogram = ray_metrics.Histogram(name=name,
description=documentation, description=documentation,
tag_keys=labelnames_tuple, tag_keys=labelnames_tuple,
boundaries=buckets) boundaries=boundaries)
def labels(self, **labels): def labels(self, **labels):
self._histogram.set_default_tags(labels) self._histogram.set_default_tags(labels)
...@@ -273,9 +268,12 @@ class RayMetrics(Metrics): ...@@ -273,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library. Provides the same metrics as Metrics but uses Ray's util.metrics library.
""" """
_gauge_cls = _RayGaugeWrapper _gauge_cls: Type[prometheus_client.Gauge] = cast(
_counter_cls = _RayCounterWrapper Type[prometheus_client.Gauge], _RayGaugeWrapper)
_histogram_cls = _RayHistogramWrapper _counter_cls: Type[prometheus_client.Counter] = cast(
Type[prometheus_client.Counter], _RayCounterWrapper)
_histogram_cls: Type[prometheus_client.Histogram] = cast(
Type[prometheus_client.Histogram], _RayHistogramWrapper)
def __init__(self, labelnames: List[str], max_model_len: int): def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None: if ray_metrics is None:
...@@ -473,8 +471,6 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -473,8 +471,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.metrics.histogram_num_generation_tokens_request, self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests) stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
......
...@@ -49,7 +49,6 @@ class Stats: ...@@ -49,7 +49,6 @@ class Stats:
# Metadata # Metadata
num_prompt_tokens_requests: List[int] num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int] num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int] n_requests: List[int]
finished_reason_requests: List[str] finished_reason_requests: List[str]
......
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import List, Mapping, Optional, Union from typing import List, Mapping, Optional, Union, overload
from vllm import PoolingParams from vllm import PoolingParams
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
...@@ -23,12 +24,71 @@ class MQEngineDeadError(RuntimeError): ...@@ -23,12 +24,71 @@ class MQEngineDeadError(RuntimeError):
@dataclass @dataclass
class RPCProcessRequest: class RPCProcessRequest:
inputs: PromptInputs prompt: PromptType
params: Union[SamplingParams, PoolingParams] params: Union[SamplingParams, PoolingParams]
request_id: str request_id: str
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
@overload # DEPRECATED
def __init__(
self,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@overload
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def __init__(
self,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
if inputs is not None:
prompt = inputs
assert (prompt is not None and params is not None
and request_id is not None)
super().__init__()
self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
@dataclass @dataclass
......
...@@ -2,8 +2,8 @@ import asyncio ...@@ -2,8 +2,8 @@ import asyncio
import copy import copy
import pickle import pickle
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Union) Optional, Union, cast, overload)
import cloudpickle import cloudpickle
import zmq import zmq
...@@ -13,9 +13,12 @@ from zmq.asyncio import Socket ...@@ -13,9 +13,12 @@ from zmq.asyncio import Socket
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
...@@ -23,15 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -23,15 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCError, RPCProcessRequest, RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -47,7 +53,7 @@ class MQClientClosedError(Exception): ...@@ -47,7 +53,7 @@ class MQClientClosedError(Exception):
""" """
class MQLLMEngineClient: class MQLLMEngineClient(EngineClient):
"""A client wrapper for MQLLMEngine that conforms to the """A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol. EngineClient protocol.
...@@ -310,7 +316,7 @@ class MQLLMEngineClient: ...@@ -310,7 +316,7 @@ class MQLLMEngineClient:
or response != VLLM_RPC_SUCCESS_STR): or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message) raise ValueError(error_message)
async def get_tokenizer(self, lora_request: LoRARequest): async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request) return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
...@@ -338,8 +344,14 @@ class MQLLMEngineClient: ...@@ -338,8 +344,14 @@ class MQLLMEngineClient:
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket) request=RPCAbortRequest(request_id), socket=self.input_socket)
async def do_log_stats(self): async def do_log_stats(
"""Ignore do_log_stats (handled on MQLLMEngine polling)""" self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
"""
Ignore do_log_stats (handled on MQLLMEngine polling)
"""
pass pass
async def check_health(self): async def check_health(self):
...@@ -367,14 +379,48 @@ class MQLLMEngineClient: ...@@ -367,14 +379,48 @@ class MQLLMEngineClient:
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
return ENGINE_DEAD_ERROR(self._errored_with) return ENGINE_DEAD_ERROR(self._errored_with)
@overload # DEPRECATED
def generate( def generate(
self, self,
inputs: PromptInputs, *,
inputs: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@overload
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def generate(
self,
prompt: Optional[PromptType] = None,
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -383,8 +429,7 @@ class MQLLMEngineClient: ...@@ -383,8 +429,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
...@@ -392,18 +437,58 @@ class MQLLMEngineClient: ...@@ -392,18 +437,58 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use prompt_adapter_request: Prompt Adapter request to use
for generation, if any. for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
""" """
return self._process_request(inputs, sampling_params, request_id, if inputs is not None:
prompt = inputs
assert (prompt is not None and sampling_params is not None
and request_id is not None)
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
prompt_adapter_request) prompt_adapter_request, priority)
@overload # DEPRECATED
def encode(
self,
*,
inputs: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@overload
def encode( def encode(
self, self,
inputs: PromptInputs, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def encode(
self,
prompt: Optional[PromptType] = None,
pooling_params: Optional[PoolingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
...@@ -412,8 +497,7 @@ class MQLLMEngineClient: ...@@ -412,8 +497,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
...@@ -424,17 +508,29 @@ class MQLLMEngineClient: ...@@ -424,17 +508,29 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request. for the request.
""" """
return self._process_request(inputs, pooling_params, request_id, if inputs is not None:
lora_request, trace_headers) prompt = inputs
assert (prompt is not None and pooling_params is not None
and request_id is not None)
return cast(
AsyncGenerator[EmbeddingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
async def _process_request( async def _process_request(
self, self,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]: EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
...@@ -443,6 +539,20 @@ class MQLLMEngineClient: ...@@ -443,6 +539,20 @@ class MQLLMEngineClient:
if self._errored_with is not None: if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with) raise ENGINE_DEAD_ERROR(self._errored_with)
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
)
# 1) Create output queue for this requests. # 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput, queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue() BaseException]] = asyncio.Queue()
...@@ -462,12 +572,14 @@ class MQLLMEngineClient: ...@@ -462,12 +572,14 @@ class MQLLMEngineClient:
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCProcessRequest( RPCProcessRequest(
inputs=inputs, prompt=prompt,
params=params, params=params,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)) prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine. # 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes, parts = (request_bytes,
......
...@@ -73,11 +73,9 @@ class MQLLMEngine: ...@@ -73,11 +73,9 @@ class MQLLMEngine:
# For MQLLMEngine, we can use cached outputs, since each new request # For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees # output is immediately pickled and send over the socket, which frees
# the python object to be reused again. # the python object to be reused again.
use_cached_outputs = True kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args, self.engine = LLMEngine(*args, **kwargs)
**kwargs,
use_cached_outputs=use_cached_outputs)
self.log_requests = log_requests self.log_requests = log_requests
self.use_async_sockets = use_async_sockets self.use_async_sockets = use_async_sockets
...@@ -130,6 +128,9 @@ class MQLLMEngine: ...@@ -130,6 +128,9 @@ class MQLLMEngine:
def from_engine_args(cls, engine_args: AsyncEngineArgs, def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str): usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments.""" """Creates an MQLLMEngine from the engine arguments."""
# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
...@@ -278,11 +279,12 @@ class MQLLMEngine: ...@@ -278,11 +279,12 @@ class MQLLMEngine:
try: try:
self.engine.add_request( self.engine.add_request(
request_id=request_id, request_id=request_id,
inputs=request.inputs, prompt=request.prompt,
params=request.params, params=request.params,
lora_request=request.lora_request, lora_request=request.lora_request,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request) prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
if self.log_requests: if self.log_requests:
logger.info("Added request %s.", request.request_id) logger.info("Added request %s.", request.request_id)
......
import functools import functools
from typing import Callable, List from typing import Callable, List, cast
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
...@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import ( ...@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
SequenceGroupOutput, SequenceOutput, SequenceStatus) CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -57,11 +59,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -57,11 +59,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
""" """
for output in outputs: for output in outputs:
# Concatenate single-step prompt logprob processing results. # Concatenate single-step prompt logprob processing results.
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output) single_step_process_prompt_logprob(self, seq_group, output)
@staticmethod @staticmethod
@functools.lru_cache() @functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once(): def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
logger.warning( logger.warning(
"Prompt logprob is not supported by multi step workers. " "Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).") "(e.g., speculative decode uses multi step workers).")
...@@ -97,6 +102,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -97,6 +102,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
assert len(seqs) == 1, ( assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.") "Beam search not supported in multi-step decoding.")
seq = seqs[0] seq = seqs[0]
seq_id = seq.seq_id
# This method is defined in the more generic
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert all([
isinstance(output, CompletionSequenceGroupOutput)
for output in outputs
])
compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
assert all([
seq_id == output.samples[0].parent_seq_id
for output in compl_outputs
])
if is_async: if is_async:
# Async case: We process tokens one by one. Here, we know the token # Async case: We process tokens one by one. Here, we know the token
...@@ -108,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -108,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group, # Since there's only one sequence per sequence group,
# we can take the first sample. # we can take the first sample.
samples = [output.samples[0] for output in outputs] samples = [output.samples[0] for output in compl_outputs]
# entries in sample tokens may be invalid (eg. due to spec decode # entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens). # rejecting tokens).
...@@ -145,7 +163,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -145,7 +163,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids)) len(output_token_ids))
if remaining_tokens < 0: if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode # Truncate any tokens after EOS. This is required as spec decode
...@@ -159,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -159,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for i in range(len(output_token_ids)): for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id: if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1] output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break break
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new # Incrementally append tokens to the sequence, as if we had only one new
# token. # token.
for output_token_id, output_logprob in zip(output_token_ids, for output_token_id, output_logprob in zip(output_token_ids,
...@@ -171,6 +188,13 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -171,6 +188,13 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs=output_logprob, logprobs=output_logprob,
) )
if is_prefill_sampled_token:
is_prefill_sampled_token = False
else:
# Update num_computed_tokens iff the sampled token is not from
# a prefill step.
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params) self._process_decode_and_stop(seq, sampling_params)
if seq.is_finished(): if seq.is_finished():
......
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Tuple
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -6,9 +6,9 @@ from vllm.engine.output_processor.interfaces import ( ...@@ -6,9 +6,9 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceOutput, SequenceStatus) SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -17,7 +17,7 @@ logger = init_logger(__name__) ...@@ -17,7 +17,7 @@ logger = init_logger(__name__)
def single_step_process_prompt_logprob( def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: SequenceGroupOutput) -> None: output: CompletionSequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput` """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step. for a given step.
...@@ -107,13 +107,14 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -107,13 +107,14 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
""" """
assert len(outputs) == 1, ("Single step should only has 1 output.") assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0] output = outputs[0]
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output) single_step_process_prompt_logprob(self, seq_group, output)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput, outputs: SequenceGroupOutput,
is_async: bool) -> None: is_async: bool) -> None:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
if sampling_params.best_of == 1 and not sampling_params.use_beam_search: if sampling_params.n == 1:
# only have one output sample # only have one output sample
sample = outputs.samples[0] sample = outputs.samples[0]
# only have one sequence # only have one sequence
...@@ -142,7 +143,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -142,7 +143,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Process samples # Process samples
samples = outputs.samples samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = { parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: [] parent_seq.seq_id: []
for parent_seq in parent_seqs for parent_seq in parent_seqs
...@@ -197,106 +197,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -197,106 +197,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
lora_req=seq_group.lora_request, lora_req=seq_group.lora_request,
) )
# Non-beam search case
if not sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = sampling_params.best_of
length_penalty = sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
sampling_params.early_stopping, sampling_params,
best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group # For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished. # and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs: for seq, parent in child_seqs:
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
...@@ -305,61 +208,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -305,61 +208,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs: # NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
for scheduler in self.scheduler: for scheduler in self.scheduler:
scheduler.free_seq(seq) scheduler.free_seq(seq)
return
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=current_worst_seq.eos_token_id)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id))
return current_worst_score >= highest_attainable_score
...@@ -57,7 +57,7 @@ class StopChecker: ...@@ -57,7 +57,7 @@ class StopChecker:
# Check if a stop token was encountered. # Check if a stop token was encountered.
# This assumes a single token produced per step. # This assumes a single token produced per step.
last_token_id = seq.get_last_token_id() last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids: if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and ( if new_char_count and (
not sampling_params.include_stop_str_in_output): not sampling_params.include_stop_str_in_output):
# Remove last token # Remove last token
...@@ -92,7 +92,7 @@ class StopChecker: ...@@ -92,7 +92,7 @@ class StopChecker:
Returns the stop string if matched or else None. Returns the stop string if matched or else None.
""" """
if not new_char_count: if not new_char_count or not sampling_params.stop:
return None return None
for stop_str in sampling_params.stop: for stop_str in sampling_params.stop:
......
from typing import List from typing import List
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import cast
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import PoolerOutput, SequenceGroupOutput from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
def create_output_by_sequence_group( def create_output_by_sequence_group(
outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], outputs: GenericSequence[SamplerOutput],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]: num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by """Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step]. [step][sequence group] into [sequence group][step].
""" """
output_by_sequence_group: List[List[SequenceGroupOutput]] = [ output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
[] for _ in range(num_seq_groups) [] for _ in range(num_seq_groups)
] ]
for step in outputs: for step in outputs:
sequence_group_output: CompletionSequenceGroupOutput
for i, sequence_group_output in enumerate(step): for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output) output_by_sequence_group[i].append(sequence_group_output)
return output_by_sequence_group # Cast to the more generic type that CompletionSequenceGroupOutput
# inherits from.
return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, import asyncio
runtime_checkable) from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid
logger = init_logger(__name__)
@runtime_checkable
class EngineClient(Protocol): class EngineClient(ABC):
"""Protocol class for Clients to Engine""" """Protocol class for Clients to Engine"""
@property @property
@abstractmethod
def is_running(self) -> bool: def is_running(self) -> bool:
... ...
@property @property
@abstractmethod
def is_stopped(self) -> bool: def is_stopped(self) -> bool:
... ...
@property @property
@abstractmethod
def errored(self) -> bool: def errored(self) -> bool:
... ...
@property @property
@abstractmethod
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
... ...
@abstractmethod
def generate( def generate(
self, self,
inputs: PromptInputs, prompt: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request""" """Generate outputs for a request."""
... ...
async def beam_search(
self,
prompt: Union[PromptType, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)
beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
completed = []
for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]
tasks = []
request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
tasks.append(task)
output = await asyncio.gather(*tasks)
output = [x[0] for x in output]
new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
else:
new_beams.append(new_beam)
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_logprobs=None)
yield beam_search_output
@abstractmethod
def encode( def encode(
self, self,
inputs: PromptInputs, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.""" """Generate outputs for a request from an embedding model."""
... ...
@abstractmethod
async def abort(self, request_id: str) -> None: async def abort(self, request_id: str) -> None:
"""Abort a request. """Abort a request.
...@@ -63,14 +172,17 @@ class EngineClient(Protocol): ...@@ -63,14 +172,17 @@ class EngineClient(Protocol):
request_id: The unique id of the request. request_id: The unique id of the request.
""" """
@abstractmethod
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
... ...
@abstractmethod
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
... ...
"""Get the decoding configuration of the vLLM engine.""" """Get the decoding configuration of the vLLM engine."""
@abstractmethod
async def get_tokenizer( async def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -78,9 +190,11 @@ class EngineClient(Protocol): ...@@ -78,9 +190,11 @@ class EngineClient(Protocol):
"""Get the appropriate tokenizer for the request""" """Get the appropriate tokenizer for the request"""
... ...
@abstractmethod
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
... ...
@abstractmethod
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
...@@ -88,14 +202,17 @@ class EngineClient(Protocol): ...@@ -88,14 +202,17 @@ class EngineClient(Protocol):
) -> None: ) -> None:
... ...
@abstractmethod
async def check_health(self) -> None: async def check_health(self) -> None:
"""Raise if unhealthy""" """Raise if unhealthy"""
... ...
@abstractmethod
async def start_profile(self) -> None: async def start_profile(self) -> None:
"""Start profiling the engine""" """Start profiling the engine"""
... ...
@abstractmethod
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
"""Start profiling the engine""" """Start profiling the engine"""
... ...
...@@ -157,22 +157,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -157,22 +157,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"): if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer, return self._cached_token_str(self._tokenizer,
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"): if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
return "<image>" return "<image>"
if model_type == "mllama": if model_type == "mllama":
return "<|image|>" return "<|image|>"
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type == "ultravox":
return "<|reserved_special_token_0|>" return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "video": elif modality == "video":
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
else: else:
raise TypeError(f"Unknown modality: {modality}") raise TypeError(f"Unknown modality: {modality}")
...@@ -303,6 +305,28 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -303,6 +305,28 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
return
elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError(
"the supplied chat template path doesn't exist")
elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n"
if not any(c in chat_template
for c in JINJA_CHARS) and not Path(chat_template).exists():
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!")
else:
raise TypeError(
f"{type(chat_template)} is not a valid chat template type")
def load_chat_template( def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]: chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None: if chat_template is None:
...@@ -542,6 +566,14 @@ def apply_mistral_chat_template( ...@@ -542,6 +566,14 @@ def apply_mistral_chat_template(
if chat_template is not None: if chat_template is not None:
logger.warning( logger.warning(
"'chat_template' cannot be overridden for mistral tokenizer.") "'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
messages=messages, messages=messages,
......
import itertools import itertools
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload) Union, cast, overload)
from tqdm import tqdm from tqdm import tqdm
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template, apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
parse_chat_messages) parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor) GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -32,37 +34,6 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of ...@@ -32,37 +34,6 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None
@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]
class BeamSearchInstance:
def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []
class LLM: class LLM:
"""An LLM for generating texts from given prompts and sampling parameters. """An LLM for generating texts from given prompts and sampling parameters.
...@@ -179,15 +150,7 @@ class LLM: ...@@ -179,15 +150,7 @@ class LLM:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
removed_vision_keys = (
"image_token_id",
"image_feature_size",
"image_input_shape",
"image_input_type",
)
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -293,8 +256,8 @@ class LLM: ...@@ -293,8 +256,8 @@ class LLM:
@overload @overload
def generate( def generate(
self, self,
inputs: Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[PromptType, Sequence[PromptType]],
/, # We may enable `inputs` keyword after removing the old API /,
*, *,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
...@@ -304,14 +267,13 @@ class LLM: ...@@ -304,14 +267,13 @@ class LLM:
... ...
@deprecate_kwargs( @deprecate_kwargs(
"prompts",
"prompt_token_ids", "prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.", additional_message="Please use the 'prompts' parameter instead.",
) )
def generate( def generate(
self, self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None, Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
...@@ -330,7 +292,9 @@ class LLM: ...@@ -330,7 +292,9 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
inputs: A list of inputs to generate completions for. prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt.
...@@ -358,12 +322,13 @@ class LLM: ...@@ -358,12 +322,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).") "models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None: if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts), prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) )
else: else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict): if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1: if len(guided_options_request) > 1:
...@@ -378,7 +343,7 @@ class LLM: ...@@ -378,7 +343,7 @@ class LLM:
sampling_params = SamplingParams() sampling_params = SamplingParams()
self._validate_and_add_requests( self._validate_and_add_requests(
inputs=inputs, prompts=parsed_prompts,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
...@@ -391,9 +356,7 @@ class LLM: ...@@ -391,9 +356,7 @@ class LLM:
def beam_search( def beam_search(
self, self,
prompts: List[Union[str, List[int]]], prompts: List[Union[str, List[int]]],
beam_width: int, params: BeamSearchParams,
max_tokens: int,
ignore_eos: bool = False,
) -> List[BeamSearchOutput]: ) -> List[BeamSearchOutput]:
""" """
Generate sequences using beam search. Generate sequences using beam search.
...@@ -401,20 +364,30 @@ class LLM: ...@@ -401,20 +364,30 @@ class LLM:
Args: Args:
prompts: A list of prompts. Each prompt can be a string or a list prompts: A list of prompts. Each prompt can be a string or a list
of token IDs. of token IDs.
beam_width: The number of beams to keep at each step. params: The beam search parameters.
max_tokens: The max number of tokens to generate for each prompt.
TODO: how does beam search work together with length penalty, frequency TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.? penalty, and stopping criteria, etc.?
""" """
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step # generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation # following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams(logprobs=2 * beam_width, beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1, max_tokens=1,
temperature=0.0) temperature=temperature)
instances: List[BeamSearchInstance] = [] instances: List[BeamSearchInstance] = []
for prompt in prompts: for prompt in prompts:
...@@ -469,7 +442,7 @@ class LLM: ...@@ -469,7 +442,7 @@ class LLM:
else: else:
instance_new_beams.append(new_beam) instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams, sorted_beams = sorted(instance_new_beams,
key=lambda x: x.cum_logprob, key=sort_beams_key,
reverse=True) reverse=True)
instance.beams = sorted_beams[:beam_width] instance.beams = sorted_beams[:beam_width]
...@@ -477,7 +450,7 @@ class LLM: ...@@ -477,7 +450,7 @@ class LLM:
for instance in instances: for instance in instances:
instance.completed.extend(instance.beams) instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed, sorted_completed = sorted(instance.completed,
key=lambda x: x.cum_logprob, key=sort_beams_key,
reverse=True) reverse=True)
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
...@@ -497,7 +470,9 @@ class LLM: ...@@ -497,7 +470,9 @@ class LLM:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None, tools: Optional[List[Dict[str, Any]]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
""" """
Generate responses for a chat conversation. Generate responses for a chat conversation.
...@@ -524,6 +499,11 @@ class LLM: ...@@ -524,6 +499,11 @@ class LLM:
If not provided, the model's default chat template will be used. If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template add_generation_prompt: If True, adds a generation template
to each message. to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns: Returns:
A list of ``RequestOutput`` objects containing the generated A list of ``RequestOutput`` objects containing the generated
...@@ -534,10 +514,13 @@ class LLM: ...@@ -534,10 +514,13 @@ class LLM:
# Handle multi and single conversations # Handle multi and single conversations
if is_list_of(messages, list): if is_list_of(messages, list):
# messages is List[List[...]] # messages is List[List[...]]
list_of_messages = messages list_of_messages = cast(List[List[ChatCompletionMessageParam]],
messages)
else: else:
# messages is List[...] # messages is List[...]
list_of_messages = [messages] list_of_messages = [
cast(List[ChatCompletionMessageParam], messages)
]
prompts: List[Union[TokensPrompt, TextPrompt]] = [] prompts: List[Union[TokensPrompt, TextPrompt]] = []
...@@ -545,6 +528,9 @@ class LLM: ...@@ -545,6 +528,9 @@ class LLM:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config() model_config = self.llm_engine.get_model_config()
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data = parse_chat_messages( conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer) msgs, model_config, tokenizer)
...@@ -555,6 +541,7 @@ class LLM: ...@@ -555,6 +541,7 @@ class LLM:
messages=msgs, messages=msgs,
chat_template=chat_template, chat_template=chat_template,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools, tools=tools,
) )
else: else:
...@@ -563,6 +550,7 @@ class LLM: ...@@ -563,6 +550,7 @@ class LLM:
conversation=conversation, conversation=conversation,
chat_template=chat_template, chat_template=chat_template,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools, tools=tools,
) )
...@@ -575,6 +563,9 @@ class LLM: ...@@ -575,6 +563,9 @@ class LLM:
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs
prompts.append(prompt) prompts.append(prompt)
return self.generate( return self.generate(
...@@ -648,8 +639,8 @@ class LLM: ...@@ -648,8 +639,8 @@ class LLM:
@overload @overload
def encode( def encode(
self, self,
inputs: Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[PromptType, Sequence[PromptType]],
/, # We may enable `inputs` keyword after removing the old API /,
*, *,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
...@@ -659,14 +650,13 @@ class LLM: ...@@ -659,14 +650,13 @@ class LLM:
... ...
@deprecate_kwargs( @deprecate_kwargs(
"prompts",
"prompt_token_ids", "prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.", additional_message="Please use the 'prompts' parameter instead.",
) )
def encode( def encode(
self, self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None, Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
...@@ -682,9 +672,9 @@ class LLM: ...@@ -682,9 +672,9 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for prompts: The prompts to the LLM. You may pass a sequence of prompts
batch inference. See :class:`~vllm.inputs.PromptInputs` for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input. for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
...@@ -707,19 +697,20 @@ class LLM: ...@@ -707,19 +697,20 @@ class LLM:
) )
if prompt_token_ids is not None: if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts), prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) )
else: else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
self._validate_and_add_requests( self._validate_and_add_requests(
inputs=inputs, prompts=parsed_prompts,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
...@@ -763,9 +754,9 @@ class LLM: ...@@ -763,9 +754,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be " raise ValueError("Either prompts or prompt_token_ids must be "
"provided.") "provided.")
inputs: List[PromptInputs] = [] parsed_prompts: List[PromptType] = []
for i in range(num_requests): for i in range(num_requests):
item: PromptInputs item: PromptType
if prompts is not None: if prompts is not None:
item = TextPrompt(prompt=prompts[i]) item = TextPrompt(prompt=prompts[i])
...@@ -774,13 +765,13 @@ class LLM: ...@@ -774,13 +765,13 @@ class LLM:
else: else:
raise AssertionError raise AssertionError
inputs.append(item) parsed_prompts.append(item)
return inputs return parsed_prompts
def _validate_and_add_requests( def _validate_and_add_requests(
self, self,
inputs: Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
...@@ -788,11 +779,19 @@ class LLM: ...@@ -788,11 +779,19 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None, guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None, priority: Optional[List[int]] = None,
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
inputs = [inputs] prompts = [prompts]
num_requests = len(inputs) num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests: if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params " raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
...@@ -803,15 +802,15 @@ class LLM: ...@@ -803,15 +802,15 @@ class LLM:
for sp in params if isinstance(params, list) else (params, ): for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams): if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options) self._add_guided_params(sp, guided_options)
# We only care about the final output # We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, prompt in enumerate(prompts):
self._add_request( self._add_request(
request_inputs, prompt,
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
...@@ -821,7 +820,7 @@ class LLM: ...@@ -821,7 +820,7 @@ class LLM:
def _add_request( def _add_request(
self, self,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
...@@ -830,29 +829,32 @@ class LLM: ...@@ -830,29 +829,32 @@ class LLM:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request( self.llm_engine.add_request(
request_id, request_id,
inputs, prompt,
params, params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
) )
def _add_guided_processor( def _add_guided_params(
self, self,
params: SamplingParams, params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None): guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options: if guided_options is None:
if guided_options.guided_decoding_backend is None: return params
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = ( if params.guided_decoding is not None:
decoding_config.guided_decoding_backend) raise ValueError("Cannot set both guided_options_request and"
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa "params.guided_decoding.")
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer()) params.guided_decoding = GuidedDecodingParams(
if guided_logits_processor: json=guided_options.guided_json,
if params.logits_processors is None: regex=guided_options.guided_regex,
params.logits_processors = [] choice=guided_options.guided_choice,
params.logits_processors.append(guided_logits_processor) grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params return params
def _run_engine( def _run_engine(
......
...@@ -4,7 +4,7 @@ from vllm.logger import init_logger ...@@ -4,7 +4,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -21,7 +21,8 @@ class RequestLogger: ...@@ -21,7 +21,8 @@ class RequestLogger:
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: Optional[List[int]], prompt_token_ids: Optional[List[int]],
params: Optional[Union[SamplingParams, PoolingParams]], params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
......
...@@ -31,7 +31,8 @@ from vllm.engine.multiprocessing.engine import run_mp_engine ...@@ -31,7 +31,8 @@ from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...@@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding ...@@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import ( from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization) OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
...@@ -526,8 +528,20 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -526,8 +528,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args) logger.info("args: %s", args)
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
temp_socket.bind(("", args.port)) ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valide_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", args.port))
def signal_handler(*_) -> None: def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing # Interrupt server on sigterm while initializing
...@@ -541,8 +555,6 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -541,8 +555,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
model_config = await engine_client.get_model_config() model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args) init_app_state(engine_client, model_config, app.state, args)
temp_socket.close()
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
host=args.host, host=args.host,
...@@ -553,6 +565,7 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -553,6 +565,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile=args.ssl_certfile, ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs, ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs, ssl_cert_reqs=args.ssl_cert_reqs,
fd=sock.fileno(),
**uvicorn_kwargs, **uvicorn_kwargs,
) )
...@@ -567,5 +580,6 @@ if __name__ == "__main__": ...@@ -567,5 +580,6 @@ if __name__ == "__main__":
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser) parser = make_arg_parser(parser)
args = parser.parse_args() args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args)) uvloop.run(run_server(args))
...@@ -10,8 +10,10 @@ import ssl ...@@ -10,8 +10,10 @@ import ssl
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath) PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -190,16 +192,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -190,16 +192,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser" "Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use") "to specify which parser to use")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument( parser.add_argument(
"--tool-call-parser", "--tool-call-parser",
type=str, type=str,
choices=["mistral", "hermes"], metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None, default=None,
help= help=
"Select the tool call parser depending on the model that you're using." "Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API " " This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.") "format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len', parser.add_argument('--max-log-len',
...@@ -219,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -219,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
return parser return parser
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
def create_parser_for_docs() -> FlexibleArgumentParser: def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser( parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server") prog="-m vllm.entrypoints.openai.api_server")
......
...@@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator ...@@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
SamplingParams) RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
# torch is mocked during docs generation, # torch is mocked during docs generation,
...@@ -186,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -186,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p: float = 0.0 min_p: float = 0.0
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
length_penalty: float = 1.0 length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
ignore_eos: bool = False ignore_eos: bool = False
...@@ -211,6 +208,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -211,6 +208,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the " "This is a parameter used by chat template in tokenizer config of the "
"model."), "model."),
) )
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=False, default=False,
description=( description=(
...@@ -272,13 +278,33 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -272,13 +278,33 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
def to_sampling_params( def to_beam_search_params(self,
self, tokenizer: AnyTokenizer, default_max_tokens: int) -> BeamSearchParams:
guided_decode_logits_processor: Optional[LogitsProcessor], max_tokens = self.max_tokens
default_max_tokens: int) -> SamplingParams: if max_tokens is None:
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
...@@ -287,14 +313,19 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -287,14 +313,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo: if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logrobs. guided_json_object = None
logits_processors = get_logits_processors( if (self.response_format is not None
logit_bias=self.logit_bias, and self.response_format.type == "json_object"):
allowed_token_ids=None, guided_json_object = True
tokenizer=tokenizer,
) guided_decoding = GuidedDecodingParams.from_optional(
if guided_decode_logits_processor: json=self._get_guided_json_from_tool() or self.guided_json,
logits_processors.append(guided_decode_logits_processor) regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
...@@ -314,17 +345,32 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -314,17 +345,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=max_tokens, max_tokens=max_tokens,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
) guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None:
return None
# user has chosen to use a named tool
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = self.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in self.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters
return None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
...@@ -386,7 +432,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -386,7 +432,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# if "tool_choice" is not specified but tools are provided, # if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice # default to "auto" tool_choice
if "tool_choice" not in data and "tools" in data: if "tool_choice" not in data and data.get("tools"):
data["tool_choice"] = "auto" data["tool_choice"] = "auto"
# if "tool_choice" is specified -- validation # if "tool_choice" is specified -- validation
...@@ -431,6 +477,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -431,6 +477,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
" of the specified `tools`") " of the specified `tools`")
return data return data
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data
class CompletionRequest(OpenAIBaseModel): class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
...@@ -460,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -460,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel):
min_p: float = 0.0 min_p: float = 0.0
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
length_penalty: float = 1.0 length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
ignore_eos: bool = False ignore_eos: bool = False
...@@ -516,13 +570,33 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -516,13 +570,33 @@ class CompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-completion-extra-params # doc: end-completion-extra-params
def to_sampling_params( def to_beam_search_params(self,
self, tokenizer: AnyTokenizer, default_max_tokens: int) -> BeamSearchParams:
guided_decode_logits_processor: Optional[LogitsProcessor], max_tokens = self.max_tokens
default_max_tokens: int) -> SamplingParams: if max_tokens is None:
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
...@@ -533,13 +607,19 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -533,13 +607,19 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors( guided_json_object = None
logit_bias=self.logit_bias, if (self.response_format is not None
allowed_token_ids=self.allowed_token_ids, and self.response_format.type == "json_object"):
tokenizer=tokenizer, guided_json_object = True
)
if guided_decode_logits_processor: guided_decoding = GuidedDecodingParams.from_optional(
logits_processors.append(guided_decode_logits_processor) json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
...@@ -558,18 +638,16 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -558,18 +638,16 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1, max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
) guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
...@@ -619,12 +697,23 @@ class EmbeddingRequest(OpenAIBaseModel): ...@@ -619,12 +697,23 @@ class EmbeddingRequest(OpenAIBaseModel):
encoding_format: Literal["float", "base64"] = "float" encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None dimensions: Optional[int] = None
user: Optional[str] = None user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-embedding-pooling-params # doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params # doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-embedding-extra-params
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(additional_data=self.additional_data)
...@@ -862,8 +951,18 @@ class TokenizeChatRequest(OpenAIBaseModel): ...@@ -862,8 +951,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
messages: List[ChatCompletionMessageParam] messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True) add_generation_prompt: bool = Field(default=True)
continue_final_message: bool = Field(default=False)
add_special_tokens: bool = Field(default=False) add_special_tokens: bool = Field(default=False)
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
......
...@@ -29,12 +29,11 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath, ...@@ -29,12 +29,11 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
OpenAIServing, OpenAIServing,
PromptAdapterPath, PromptAdapterPath,
TextTokensPrompt) TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
MistralToolParser,
ToolParser)
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
...@@ -81,13 +80,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -81,13 +80,13 @@ class OpenAIServingChat(OpenAIServing):
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools: if self.enable_auto_tools:
if tool_parser == "mistral": try:
self.tool_parser = MistralToolParser self.tool_parser = ToolParserManager.get_tool_parser(
elif tool_parser == "hermes": tool_parser)
self.tool_parser = Hermes2ProToolParser except Exception as e:
else:
raise TypeError("Error: --enable-auto-tool-choice requires " raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser") f"tool_parser:'{tool_parser}' which has not "
"been registered") from e
async def create_chat_completion( async def create_chat_completion(
self, self,
...@@ -137,6 +136,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -137,6 +136,7 @@ class OpenAIServingChat(OpenAIServing):
messages=request.messages, messages=request.messages,
chat_template=request.chat_template or self.chat_template, chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts, tools=tool_dicts,
documents=request.documents, documents=request.documents,
**(request.chat_template_kwargs or {}), **(request.chat_template_kwargs or {}),
...@@ -147,18 +147,19 @@ class OpenAIServingChat(OpenAIServing): ...@@ -147,18 +147,19 @@ class OpenAIServingChat(OpenAIServing):
conversation=conversation, conversation=conversation,
chat_template=request.chat_template or self.chat_template, chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts, tools=tool_dicts,
documents=request.documents, documents=request.documents,
**(request.chat_template_kwargs or {}), **(request.chat_template_kwargs or {}),
) )
except Exception as e: except Exception as e:
logger.error("Error in applying chat template from request: %s", e) logger.exception("Error in applying chat template from request")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
try: try:
mm_data = await mm_data_future mm_data = await mm_data_future
except Exception as e: except Exception as e:
logger.error("Error in loading multi-modal data: %s", e) logger.exception("Error in loading multi-modal data")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
# validation for OpenAI tools # validation for OpenAI tools
...@@ -182,8 +183,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -182,8 +183,9 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: try:
guided_decode_logits_processor = ( if self.enable_auto_tools and self.tool_parser:
await self._guided_decode_logits_processor(request, tokenizer)) request = self.tool_parser(tokenizer).adjust_request(
request=request)
if isinstance(prompt, str): if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input( prompt_inputs = self._tokenize_prompt_input(
...@@ -202,11 +204,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -202,11 +204,15 @@ class OpenAIServingChat(OpenAIServing):
assert prompt_inputs is not None assert prompt_inputs is not None
sampling_params = request.to_sampling_params( sampling_params: Union[SamplingParams, BeamSearchParams]
tokenizer, default_max_tokens = self.max_model_len - len(
guided_decode_logits_processor, prompt_inputs["prompt_token_ids"])
default_max_tokens=self.max_model_len - if request.use_beam_search:
len(prompt_inputs["prompt_token_ids"])) sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
self._log_inputs(request_id, self._log_inputs(request_id,
prompt_inputs, prompt_inputs,
...@@ -228,14 +234,22 @@ class OpenAIServingChat(OpenAIServing): ...@@ -228,14 +234,22 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)): and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning() log_tracing_disabled_warning()
result_generator = self.engine_client.generate( if isinstance(sampling_params, BeamSearchParams):
engine_inputs, result_generator = self.engine_client.beam_search(
sampling_params, engine_inputs['prompt_token_ids'],
request_id, request_id,
lora_request=lora_request, sampling_params,
trace_headers=trace_headers, )
prompt_adapter_request=prompt_adapter_request, else:
) result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -281,12 +295,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -281,12 +295,8 @@ class OpenAIServingChat(OpenAIServing):
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0 num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name tool_choice_function_name = request.tool_choice.function.name
else: else:
...@@ -305,6 +315,29 @@ class OpenAIServingChat(OpenAIServing): ...@@ -305,6 +315,29 @@ class OpenAIServingChat(OpenAIServing):
else: else:
previous_texts, all_previous_token_ids = None, None previous_texts, all_previous_token_ids = None, None
# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
tool_parsers: List[Optional[ToolParser]] = [
self.tool_parser(tokenizer)
] * num_choices
else:
tool_parsers = [None] * num_choices
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
include_usage, include_continuous_usage = False, False
try: try:
async for res in result_generator: async for res in result_generator:
if res.prompt_token_ids is not None: if res.prompt_token_ids is not None:
...@@ -323,7 +356,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -323,7 +356,6 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes # NOTE num_choices defaults to 1 so this usually executes
# once per request # once per request
for i in range(num_choices): for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=DeltaMessage( delta=DeltaMessage(
...@@ -339,26 +371,19 @@ class OpenAIServingChat(OpenAIServing): ...@@ -339,26 +371,19 @@ class OpenAIServingChat(OpenAIServing):
choices=[choice_data], choices=[choice_data],
model=model_name) model=model_name)
# if usage should be included # if continuous usage stats are requested, add it
if (request.stream_options if include_continuous_usage:
and request.stream_options.include_usage): chunk.usage = UsageInfo(
# if continuous usage stats are requested, add it prompt_tokens=num_prompt_tokens,
if request.stream_options.continuous_usage_stats: completion_tokens=0,
usage = UsageInfo( total_tokens=num_prompt_tokens)
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Send response to echo the input portion of the # Send response to echo the input portion of the
# last message # last message
if request.echo: if request.echo or request.continue_final_message:
last_msg_content: str = "" last_msg_content: str = ""
if conversation and "content" in conversation[ if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role: -1] and conversation[-1].get("role") == role:
...@@ -379,17 +404,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -379,17 +404,11 @@ class OpenAIServingChat(OpenAIServing):
created=created_time, created=created_time,
choices=[choice_data], choices=[choice_data],
model=model_name) model=model_name)
if (request.stream_options and if include_continuous_usage:
request.stream_options.include_usage): chunk.usage = UsageInfo(
if (request.stream_options. prompt_tokens=num_prompt_tokens,
continuous_usage_stats): completion_tokens=0,
usage = UsageInfo( total_tokens=num_prompt_tokens)
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json( data = chunk.model_dump_json(
exclude_unset=True) exclude_unset=True)
...@@ -398,6 +417,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -398,6 +417,7 @@ class OpenAIServingChat(OpenAIServing):
for output in res.outputs: for output in res.outputs:
i = output.index i = output.index
tool_parser = tool_parsers[i]
if finish_reason_sent[i]: if finish_reason_sent[i]:
continue continue
...@@ -415,6 +435,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -415,6 +435,12 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None logprobs = None
delta_text = output.text delta_text = output.text
if not delta_text and not output.token_ids and \
not previous_num_tokens[i]:
# Chunked prefill case, don't return empty chunks
continue
delta_message: Optional[DeltaMessage] delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
...@@ -445,7 +471,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -445,7 +471,8 @@ class OpenAIServingChat(OpenAIServing):
delta_text=delta_text, delta_text=delta_text,
previous_token_ids=previous_token_ids, previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids, current_token_ids=current_token_ids,
delta_token_ids=output.token_ids)) delta_token_ids=output.token_ids,
request=request))
# update the previous values for the next iteration # update the previous values for the next iteration
previous_texts[i] = current_text previous_texts[i] = current_text
...@@ -467,36 +494,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -467,36 +494,11 @@ class OpenAIServingChat(OpenAIServing):
if output.finish_reason is None: if output.finish_reason is None:
# Send token-by-token response for each request.n # Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=delta_message, delta=delta_message,
logprobs=logprobs, logprobs=logprobs,
finish_reason=None) finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# if the model is finished generating # if the model is finished generating
else: else:
...@@ -504,10 +506,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -504,10 +506,12 @@ class OpenAIServingChat(OpenAIServing):
# any tokens that were generated but previously # any tokens that were generated but previously
# matched by partial json parsing # matched by partial json parsing
# only happens if we are NOT using guided decoding # only happens if we are NOT using guided decoding
auto_tools_called = False
if tool_parser: if tool_parser:
index = len( auto_tools_called = len(
tool_parser.prev_tool_call_arr) - 1 if len( tool_parser.prev_tool_call_arr) > 0
tool_parser.prev_tool_call_arr) > 0 else 0 index = len(tool_parser.prev_tool_call_arr
) - 1 if auto_tools_called else 0
else: else:
index = 0 index = 0
...@@ -542,38 +546,34 @@ class OpenAIServingChat(OpenAIServing): ...@@ -542,38 +546,34 @@ class OpenAIServingChat(OpenAIServing):
delta=delta_message, delta=delta_message,
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason finish_reason=output.finish_reason
if not (tool_parser if not auto_tools_called else "tool_calls",
and len(tool_parser.prev_tool_call_arr))
else "tool_calls",
stop_reason=output.stop_reason) stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if include_continuous_usage:
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# once the final token is handled, if stream_options.include_usage # once the final token is handled, if stream_options.include_usage
# is sent, send the usage # is sent, send the usage
if (request.stream_options if include_usage:
and request.stream_options.include_usage): completion_tokens = sum(previous_num_tokens)
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo( final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
...@@ -600,7 +600,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -600,7 +600,7 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e) logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished # Send the final done message after all response.n are finished
...@@ -646,8 +646,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -646,8 +646,10 @@ class OpenAIServingChat(OpenAIServing):
else: else:
logprobs = None logprobs = None
# by default, tools are not used. # In the OpenAI API the finish_reason is "tools_called"
tools_called = False # if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = False
# if auto tools are not enabled, and a named tool choice using # if auto tools are not enabled, and a named tool choice using
# outlines is not being used # outlines is not being used
...@@ -669,7 +671,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -669,7 +671,6 @@ class OpenAIServingChat(OpenAIServing):
name=request.tool_choice.function.name, name=request.tool_choice.function.name,
arguments=output.text)) arguments=output.text))
]) ])
tools_called = True
# if the request doesn't use tool choice # if the request doesn't use tool choice
# OR specifies to not use a tool # OR specifies to not use a tool
...@@ -683,9 +684,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -683,9 +684,18 @@ class OpenAIServingChat(OpenAIServing):
or request.tool_choice is None) and self.enable_auto_tools \ or request.tool_choice is None) and self.enable_auto_tools \
and self.tool_parser: and self.tool_parser:
tool_parser = self.tool_parser(tokenizer) try:
tool_call_info = tool_parser.extract_tool_calls(output.text) tool_parser = self.tool_parser(tokenizer)
tools_called = tool_call_info.tools_called except RuntimeError as e:
logger.exception("Error in tool parser creation.")
return self.create_error_response(str(e))
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called
if tool_call_info.tools_called: if tool_call_info.tools_called:
message = ChatMessage(role=role, message = ChatMessage(role=role,
content=tool_call_info.content, content=tool_call_info.content,
...@@ -708,12 +718,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -708,12 +718,12 @@ class OpenAIServingChat(OpenAIServing):
index=output.index, index=output.index,
message=message, message=message,
logprobs=logprobs, logprobs=logprobs,
finish_reason="tool_calls" if tools_called else finish_reason="tool_calls" if auto_tools_called else
output.finish_reason if output.finish_reason else "stop", output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason) stop_reason=output.stop_reason)
choices.append(choice_data) choices.append(choice_data)
if request.echo: if request.echo or request.continue_final_message:
last_msg_content = "" last_msg_content = ""
if conversation and "content" in conversation[-1] and conversation[ if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role: -1].get("role") == role:
...@@ -726,6 +736,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -726,6 +736,8 @@ class OpenAIServingChat(OpenAIServing):
assert final_res.prompt_token_ids is not None assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids) num_prompt_tokens = len(final_res.prompt_token_ids)
if final_res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
num_generated_tokens = sum( num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs) len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo( usage = UsageInfo(
......
...@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath, ...@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
PromptAdapterPath) PromptAdapterPath)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
...@@ -110,8 +111,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -110,8 +111,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list( prompts = list(
self._tokenize_prompt_input_or_inputs( self._tokenize_prompt_input_or_inputs(
request, request,
...@@ -122,11 +121,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -122,11 +121,15 @@ class OpenAIServingCompletion(OpenAIServing):
)) ))
for i, prompt_inputs in enumerate(prompts): for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params( sampling_params: Union[SamplingParams, BeamSearchParams]
tokenizer, default_max_tokens = self.max_model_len - len(
guided_decode_logits_processor, prompt_inputs["prompt_token_ids"])
default_max_tokens=self.max_model_len - if request.use_beam_search:
len(prompt_inputs["prompt_token_ids"])) sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
...@@ -145,14 +148,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -145,14 +148,25 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers): raw_request.headers):
log_tracing_disabled_warning() log_tracing_disabled_warning()
generator = self.engine_client.generate( if isinstance(sampling_params, BeamSearchParams):
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, generator = self.engine_client.beam_search(
sampling_params, prompt_inputs["prompt_token_ids"],
request_id_item, request_id_item,
lora_request=lora_request, sampling_params,
prompt_adapter_request=prompt_adapter_request, )
trace_headers=trace_headers, else:
) generator = self.engine_client.generate(
{
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator) generators.append(generator)
except ValueError as e: except ValueError as e:
...@@ -260,8 +274,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -260,8 +274,6 @@ class OpenAIServingCompletion(OpenAIServing):
for output in res.outputs: for output in res.outputs:
i = output.index + prompt_idx * num_choices i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
...@@ -293,6 +305,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -293,6 +305,11 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids = output.token_ids delta_token_ids = output.token_ids
out_logprobs = output.logprobs out_logprobs = output.logprobs
if not delta_text and not delta_token_ids \
and not previous_num_tokens[i]:
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None: if request.logprobs is not None:
assert out_logprobs is not None, ( assert out_logprobs is not None, (
"Did not output logprobs") "Did not output logprobs")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment