Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
......@@ -2,8 +2,8 @@ import asyncio
import time
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
List, Mapping, Optional, Set, Tuple, Type, Union, overload)
from weakref import ReferenceType
import vllm.envs as envs
......@@ -14,12 +14,15 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
......@@ -28,7 +31,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind
from vllm.utils import deprecate_kwargs, weak_bind
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
......@@ -363,11 +366,18 @@ class _AsyncLLMEngine(LLMEngine):
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)
is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc:
assert len(
......@@ -402,31 +412,86 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
@overload # DEPRECATED
async def add_request_async(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@overload
async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request_async(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Async version of :meth:`add_request`."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if arrival_time is None:
arrival_time = time.time()
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
......@@ -435,6 +500,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
)
async def check_health_async(self) -> None:
......@@ -443,7 +509,37 @@ class _AsyncLLMEngine(LLMEngine):
self.model_executor.check_health()
class AsyncLLMEngine:
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Convert ``sampling_params.guided_decoding`` into a logits processor.

    Only the ``guided_decoding`` field is consumed by this coroutine. When
    it is set, the processor is obtained by awaiting
    ``get_guided_decoding_logits_processor``, appended to
    ``sampling_params.logits_processors`` (created on demand), and
    ``guided_decoding`` is then reset to ``None``.

    Args:
        sampling_params: Request sampling parameters; mutated in place.
        tokenizer: Tokenizer forwarded to the processor factory.
        default_guided_backend: Backend stored on the guided-decoding params
            when they do not already carry a truthy ``backend``.

    Returns:
        The very same ``sampling_params`` object that was passed in.
    """
    guided_decoding = sampling_params.guided_decoding
    if guided_decoding is None:
        # Nothing to build; hand the params back untouched.
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    # A backend chosen on the request wins over the supplied default.
    selected_backend = guided_decoding.backend or default_guided_backend
    guided_decoding.backend = selected_backend

    built_processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if built_processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = [built_processor]
        else:
            sampling_params.logits_processors.append(built_processor)

    # The guided-decoding params are cleared whether or not a processor
    # was produced from them.
    sampling_params.guided_decoding = None
    return sampling_params
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to make it
......@@ -774,16 +870,58 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way
# for backwards compatibility.
async def add_request(
@overload # DEPRECATED
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@overload
def add_request(
self,
request_id: str,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
......@@ -794,26 +932,34 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if (priority != 0
and not self.engine.scheduler_config.policy == "priority"):
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
prompt=prompt,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return stream.generator()
async def generate(
self,
inputs: PromptInputs,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
......@@ -822,8 +968,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
......@@ -831,6 +976,8 @@ class AsyncLLMEngine:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `RequestOutput` objects from the LLMEngine
......@@ -881,21 +1028,23 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
inputs,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
async def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
......@@ -904,13 +1053,14 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
......@@ -959,10 +1109,11 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
inputs,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
......
......@@ -6,7 +6,7 @@ from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union
from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar
......@@ -25,14 +25,17 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
......@@ -41,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
......@@ -51,7 +54,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, weak_bind
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -90,6 +93,12 @@ class OutputData(NamedTuple):
scheduler_outputs: SchedulerOutputs
is_async: bool
is_last_step: bool
# Indicates if this output is from the first step of the
# multi-step. When multi-step is disabled, this is always
# set to True.
# is_first_step_output is invalid when `outputs` has
# outputs from multiple steps.
is_first_step_output: Optional[bool]
skip: List[int]
......@@ -108,13 +117,15 @@ class SchedulerContext:
def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool):
is_last_step: bool,
is_first_step_output: Optional[bool]):
self.output_queue.append(
OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=is_async,
is_last_step=is_last_step,
is_first_step_output=is_first_step_output,
skip=[]))
......@@ -177,7 +188,7 @@ class LLMEngine:
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
return cast(_O, output)
@classmethod
def validate_outputs(
......@@ -236,10 +247,11 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
......@@ -268,8 +280,8 @@ class LLMEngine:
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
......@@ -277,9 +289,6 @@ class LLMEngine:
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
......@@ -625,7 +634,7 @@ class LLMEngine:
def _add_processed_request(
self,
request_id: str,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
......@@ -689,16 +698,51 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
@overload # DEPRECATED
def add_request(
self,
request_id: str,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@overload
def add_request(
self,
request_id: str,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Add a request to the engine's request pool.
......@@ -708,8 +752,7 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
......@@ -723,7 +766,7 @@ class LLMEngine:
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create `n` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
......@@ -744,11 +787,15 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority > 0 and not self.scheduler_config.policy == "priority":
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
......@@ -756,13 +803,20 @@ class LLMEngine:
arrival_time = time.time()
preprocessed_inputs = self.input_preprocessor.preprocess(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)
# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
......@@ -795,6 +849,9 @@ class LLMEngine:
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
sampling_params = self._build_logits_processors(
sampling_params, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
......@@ -911,6 +968,45 @@ class LLMEngine:
return
def _update_num_computed_tokens_for_multi_step_prefill(
        self, seq_group: SequenceGroup,
        seq_group_meta: SequenceGroupMetadata,
        is_first_step_output: Optional[bool]):
    """Advance ``num_computed_tokens`` of a prompt group under multi-step.

    Must only be called when ``self.scheduler_config.is_multi_step`` is
    set (asserted). Groups whose metadata is not a prompt are left alone.

    Args:
        seq_group: SequenceGroup whose num_computed_tokens is updated.
        seq_group_meta: Metadata of ``seq_group``; supplies ``is_prompt``
            and the ``token_chunk_size`` used for the update.
        is_first_step_output: Whether the appended output token comes from
            the first step of the multi-step. ``None`` means the outputs of
            all steps were submitted in a single burst.
    """
    assert self.scheduler_config.is_multi_step

    if not seq_group_meta.is_prompt:
        # Decode sequences are handled elsewhere: their num_computed_tokens
        # update happens after the tokens are appended to the sequence.
        return

    if self.scheduler_config.chunked_prefill_enabled:
        # Multi-step + chunked-prefill: scheduled prompt sequences are
        # fully processed in the first step, so only a first-step output
        # (or a single-burst submission, signalled by None) triggers the
        # update.
        if is_first_step_output is not None and not is_first_step_output:
            return
    else:
        # Plain multi-step decoding: prompt sequences are effectively
        # single-stepped, so the update always applies.
        assert seq_group.state.num_steps == 1

    seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size)
def _process_model_outputs(self,
ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:
......@@ -919,8 +1015,8 @@ class LLMEngine:
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
"""
now = time.time()
if len(ctx.output_queue) == 0:
......@@ -931,20 +1027,28 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue[0]
is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, skip) = ctx.output_queue.popleft()
is_last_step, is_first_step_output,
skip) = ctx.output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if len(outputs) > 1:
has_multiple_outputs: bool = len(outputs) > 1
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list))
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
else:
outputs_by_sequence_group = outputs
......@@ -974,20 +1078,26 @@ class LLMEngine:
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group: SequenceGroup = scheduled_seq_group.seq_group
if seq_group.is_finished():
finished_before.append(i)
continue
if len(outputs) > 1:
output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
output = [outputs_by_sequence_group[0][i]]
if not is_async:
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size or 0)
if outputs:
for o in outputs:
......@@ -995,13 +1105,13 @@ class LLMEngine:
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time)
o.model_forward_time or 0)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time)
o.model_execute_time or 0)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
......@@ -1121,19 +1231,34 @@ class LLMEngine:
if seq_group.is_finished():
continue
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
token_chunk_size = (seq_group_metadata.token_chunk_size
if seq_group_metadata.token_chunk_size
is not None else 0)
seq_group.update_num_computed_tokens(token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)")
" (i.e sampling_params.n == 1)")
sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens(
) == 0
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_prefill_append:
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
......@@ -1286,12 +1411,19 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)
is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc:
assert len(outputs) == 1, (
......@@ -1482,7 +1614,6 @@ class LLMEngine:
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
best_of_requests: List[int] = []
n_requests: List[int] = []
finished_reason_requests: List[str] = []
......@@ -1553,8 +1684,6 @@ class LLMEngine:
for seq in seq_group.get_finished_seqs()
])
if seq_group.sampling_params is not None:
best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
......@@ -1607,7 +1736,6 @@ class LLMEngine:
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
best_of_requests=best_of_requests,
n_requests=n_requests,
finished_reason_requests=finished_reason_requests,
)
......@@ -1694,8 +1822,6 @@ class LLMEngine:
seq_group.sampling_params.top_p)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
seq_group.sampling_params.max_tokens)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
seq_group.sampling_params.best_of)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
seq_group.sampling_params.n)
seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
......@@ -1732,8 +1858,8 @@ class LLMEngine:
def is_embedding_model(self):
return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
......@@ -1760,4 +1886,52 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
\ No newline at end of file
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _build_logits_processors(
        self, sampling_params: SamplingParams,
        lora_request: Optional[LoRARequest]) -> SamplingParams:
    """Build logits processors from the declarative sampling fields.

    Consumes the ``guided_decoding``, ``logit_bias`` and
    ``allowed_token_ids`` fields of ``sampling_params``, turns them into
    logits processors, and stores those in ``logits_processors``.

    The caller's ``SamplingParams`` is not modified: when there is anything
    to build, the work is done on a shallow copy and that copy is returned
    with the consumed fields set to ``None``. This keeps a
    ``SamplingParams`` instance reusable across several requests.
    NOTE(review): the nested ``guided_decoding`` object is still shared
    with the caller and gets its ``backend`` defaulted in place.

    Args:
        sampling_params: Sampling parameters of the request.
        lora_request: Optional LoRA request, used to pick the tokenizer.

    Returns:
        ``sampling_params`` itself when no field needs processing,
        otherwise a shallow copy carrying the constructed processors.
    """
    import copy  # local import: only this method needs it

    guided_decoding = sampling_params.guided_decoding
    wants_token_constraints = bool(sampling_params.logit_bias
                                   or sampling_params.allowed_token_ids)
    if guided_decoding is None and not wants_token_constraints:
        # Nothing to translate; avoid the copy entirely.
        return sampling_params

    # Work on a copy so the fields unset below (and the processor list
    # rebound at the end) do not leak back into the caller's object.
    sampling_params = copy.copy(sampling_params)

    logits_processors = []

    if guided_decoding is not None:
        logger.debug(
            "Building guided decoding logits processor in "
            "LLMEngine. Params: %s", guided_decoding)

        tokenizer = self.get_tokenizer(lora_request=lora_request)
        # A backend set on the request wins over the engine-level default.
        guided_decoding.backend = guided_decoding.backend or \
            self.decoding_config.guided_decoding_backend

        processor = get_local_guided_decoding_logits_processor(
            guided_params=guided_decoding, tokenizer=tokenizer)
        if processor:
            logits_processors.append(processor)

        # Unset so this doesn't get passed down to the model
        sampling_params.guided_decoding = None

    if wants_token_constraints:
        tokenizer = self.get_tokenizer(lora_request=lora_request)

        processors = get_logits_processors(
            logit_bias=sampling_params.logit_bias,
            allowed_token_ids=sampling_params.allowed_token_ids,
            tokenizer=tokenizer)
        logits_processors.extend(processors)

        # Unset so these don't get passed down to the model
        sampling_params.logit_bias = None
        sampling_params.allowed_token_ids = None

    if logits_processors:
        existing = sampling_params.logits_processors
        if existing is None:
            sampling_params.logits_processors = logits_processors
        else:
            # Rebind instead of extending in place: the shallow copy still
            # shares the caller's list object.
            sampling_params.logits_processors = [
                *existing, *logits_processors
            ]

    return sampling_params
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union, cast
import numpy as np
import prometheus_client
......@@ -134,12 +134,6 @@ class Metrics:
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = self._histogram_cls(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
......@@ -255,10 +249,11 @@ class _RayHistogramWrapper:
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=buckets)
boundaries=boundaries)
def labels(self, **labels):
self._histogram.set_default_tags(labels)
......@@ -273,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls = _RayGaugeWrapper
_counter_cls = _RayCounterWrapper
_histogram_cls = _RayHistogramWrapper
_gauge_cls: Type[prometheus_client.Gauge] = cast(
Type[prometheus_client.Gauge], _RayGaugeWrapper)
_counter_cls: Type[prometheus_client.Counter] = cast(
Type[prometheus_client.Counter], _RayCounterWrapper)
_histogram_cls: Type[prometheus_client.Histogram] = cast(
Type[prometheus_client.Histogram], _RayHistogramWrapper)
def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
......@@ -473,8 +471,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
......
......@@ -49,7 +49,6 @@ class Stats:
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
......
from dataclasses import dataclass
from enum import Enum
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Union, overload
from vllm import PoolingParams
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS"
......@@ -23,12 +24,71 @@ class MQEngineDeadError(RuntimeError):
# Request payload describing one prompt to be processed: the prompt itself,
# its sampling or pooling parameters, a request id, and optional LoRA,
# tracing, prompt-adapter and priority settings.
#
# The class defines its own __init__, so @dataclass does not generate one;
# the field declarations below only drive the other generated methods.
# NOTE(review): both `inputs` and `prompt` are declared as fields, but the
# hand-written __init__ only ever assigns `prompt`. `inputs` looks like the
# pre-rename field — confirm whether it is meant to remain.
@dataclass
class RPCProcessRequest:
inputs: PromptInputs
prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
# Typing-only signature for the deprecated call style, where the prompt is
# passed through the keyword-only `inputs` argument.
@overload # DEPRECATED
def __init__(
self,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
# Typing-only signature for the current call style, where the prompt is
# passed as `prompt`.
@overload
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
# Runtime implementation behind both overloads.
# `deprecate_kwargs` is a project helper; presumably it warns when the
# `inputs` keyword is used — confirm against vllm.utils.
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def __init__(
self,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
# The deprecated keyword, when given, takes precedence over `prompt`.
if inputs is not None:
prompt = inputs
# prompt/params/request_id default to None only so the deprecated
# keyword-only form can omit `prompt`; all three are required in practice.
# Raises AssertionError if any is missing (asserts are skipped under -O).
assert (prompt is not None and params is not None
and request_id is not None)
super().__init__()
self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
@dataclass
......
......@@ -2,8 +2,8 @@ import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union)
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload)
import cloudpickle
import zmq
......@@ -13,9 +13,12 @@ from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
......@@ -23,15 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
logger = init_logger(__name__)
......@@ -47,7 +53,7 @@ class MQClientClosedError(Exception):
"""
class MQLLMEngineClient:
class MQLLMEngineClient(EngineClient):
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
......@@ -310,7 +316,7 @@ class MQLLMEngineClient:
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)
async def get_tokenizer(self, lora_request: LoRARequest):
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig:
......@@ -338,8 +344,14 @@ class MQLLMEngineClient:
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket)
async def do_log_stats(self):
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
"""
Ignore do_log_stats (handled on MQLLMEngine polling)
"""
pass
async def check_health(self):
......@@ -367,14 +379,48 @@ class MQLLMEngineClient:
def dead_error(self) -> BaseException:
return ENGINE_DEAD_ERROR(self._errored_with)
@overload # DEPRECATED
def generate(
self,
inputs: PromptInputs,
*,
inputs: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@overload
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def generate(
self,
prompt: Optional[PromptType] = None,
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
......@@ -383,8 +429,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
......@@ -392,18 +437,58 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
return self._process_request(inputs, sampling_params, request_id,
if inputs is not None:
prompt = inputs
assert (prompt is not None and sampling_params is not None
and request_id is not None)
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
prompt_adapter_request, priority)
@overload # DEPRECATED
def encode(
self,
*,
inputs: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@overload
def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def encode(
self,
prompt: Optional[PromptType] = None,
pooling_params: Optional[PoolingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
......@@ -412,8 +497,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
......@@ -424,17 +508,29 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return self._process_request(inputs, pooling_params, request_id,
lora_request, trace_headers)
if inputs is not None:
prompt = inputs
assert (prompt is not None and pooling_params is not None
and request_id is not None)
return cast(
AsyncGenerator[EmbeddingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))
async def _process_request(
self,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
......@@ -443,6 +539,20 @@ class MQLLMEngineClient:
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
)
# 1) Create output queue for this request.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
......@@ -462,12 +572,14 @@ class MQLLMEngineClient:
request_bytes = pickle.dumps(
RPCProcessRequest(
inputs=inputs,
prompt=prompt,
params=params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
......
......@@ -73,11 +73,9 @@ class MQLLMEngine:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
use_cached_outputs = True
kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args,
**kwargs,
use_cached_outputs=use_cached_outputs)
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
self.use_async_sockets = use_async_sockets
......@@ -130,6 +128,9 @@ class MQLLMEngine:
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
engine_config = engine_args.create_engine_config()
......@@ -278,11 +279,12 @@ class MQLLMEngine:
try:
self.engine.add_request(
request_id=request_id,
inputs=request.inputs,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request)
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
......
import functools
from typing import Callable, List
from typing import Callable, List, cast
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
......@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
......@@ -57,11 +59,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
# Emits the "prompt logprob unsupported" warning at most once per process:
# lru_cache on a zero-argument function caches the first call's result, so
# later calls return the cached None without running the body again.
@staticmethod
@functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once():
# Reminder: please update docs/source/serving/compatibility_matrix.rst
# if this feature combination becomes valid.
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
......@@ -97,6 +102,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
seq_id = seq.seq_id
# This method is defined in the more generic
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert all([
isinstance(output, CompletionSequenceGroupOutput)
for output in outputs
])
compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
assert all([
seq_id == output.samples[0].parent_seq_id
for output in compl_outputs
])
if is_async:
# Async case: We process tokens one by one. Here, we know the token
......@@ -108,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples = [output.samples[0] for output in outputs]
samples = [output.samples[0] for output in compl_outputs]
# entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens).
......@@ -145,7 +163,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode
......@@ -159,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id, output_logprob in zip(output_token_ids,
......@@ -171,6 +188,13 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs=output_logprob,
)
if is_prefill_sampled_token:
is_prefill_sampled_token = False
else:
# Update num_computed_tokens iff the sampled token is not from
# a prefill step.
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params)
if seq.is_finished():
......
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
......@@ -6,9 +6,9 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
......@@ -17,7 +17,7 @@ logger = init_logger(__name__)
def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: SequenceGroupOutput) -> None:
output: CompletionSequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
......@@ -107,13 +107,14 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""
assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0]
assert isinstance(output, CompletionSequenceGroupOutput)
single_step_process_prompt_logprob(self, seq_group, output)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.best_of == 1 and not sampling_params.use_beam_search:
if sampling_params.n == 1:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
......@@ -142,7 +143,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
......@@ -197,106 +197,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
lora_req=seq_group.lora_request,
)
# Non-beam search case
if not sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = sampling_params.best_of
length_penalty = sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
sampling_params.early_stopping, sampling_params,
best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
......@@ -305,61 +208,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
    self,
    early_stopping: Union[bool, str],
    sampling_params: SamplingParams,
    best_running_seq: Sequence,
    current_worst_seq: Sequence,
) -> bool:
    """Decide whether beam search can stop for this sequence group.

    Args:
        early_stopping: ``True``, ``False`` or the string ``"never"``.
        sampling_params: Sampling parameters; must have ``use_beam_search``
            set. ``length_penalty`` and ``max_tokens`` are read from it.
        best_running_seq: The highest-scoring sequence still running.
        current_worst_seq: The lowest-scoring finished sequence that is
            currently kept.

    Returns:
        ``True`` when ``early_stopping`` is ``True``; otherwise whether the
        worst kept score is at least the best score the running sequence is
        considered able to attain.

    Raises:
        AssertionError: If beam search is not enabled, or if
            ``early_stopping`` is neither a bool nor ``"never"``.
    """
    assert sampling_params.use_beam_search
    penalty = sampling_params.length_penalty

    if early_stopping is True:
        return True

    worst_kept_score = current_worst_seq.get_beam_search_score(
        length_penalty=penalty,
        eos_token_id=current_worst_seq.eos_token_id)

    # Only the "never" mode with a positive length penalty scores the running
    # sequence at a hypothetical length; every other case scores it at its
    # current length, i.e. without passing ``seq_len`` at all.
    length_override = {}
    if early_stopping is not False:
        assert early_stopping == "never"
        if penalty > 0.0:
            # A positive penalty favours longer sequences, so score the
            # running sequence at the longest length considered possible.
            length_override["seq_len"] = max(
                best_running_seq.get_prompt_len() +
                sampling_params.max_tokens,
                self.scheduler_config.max_model_len)

    best_attainable_score = best_running_seq.get_beam_search_score(
        length_penalty=penalty,
        eos_token_id=best_running_seq.eos_token_id,
        **length_override)

    return worst_kept_score >= best_attainable_score
return
......@@ -57,7 +57,7 @@ class StopChecker:
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
......@@ -92,7 +92,7 @@ class StopChecker:
Returns the stop string if matched or else None.
"""
if not new_char_count:
if not new_char_count or not sampling_params.stop:
return None
for stop_str in sampling_params.stop:
......
from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from typing import cast
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import PoolerOutput, SequenceGroupOutput
from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
def create_output_by_sequence_group(
        outputs: GenericSequence[SamplerOutput],
        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
    """Transpose per-step outputs into per-sequence-group outputs.

    Turns a 2d structure organized as [step][sequence group] into one
    organized as [sequence group][step].

    Args:
        outputs: One entry per step; iterating an entry yields that step's
            output for each sequence group, in sequence-group order.
        num_seq_groups: Number of sequence groups, i.e. the length of the
            returned outer list.

    Returns:
        A list with ``num_seq_groups`` inner lists; inner list ``i`` holds the
        outputs of sequence group ``i`` in step order.

    Raises:
        IndexError: If a step yields more than ``num_seq_groups`` outputs.
    """
    output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
        [] for _ in range(num_seq_groups)
    ]
    for step in outputs:
        sequence_group_output: CompletionSequenceGroupOutput
        for i, sequence_group_output in enumerate(step):
            output_by_sequence_group[i].append(sequence_group_output)

    # Cast to the more generic type that CompletionSequenceGroupOutput
    # inherits from. The target type is quoted so that it is only seen by the
    # type checker and never evaluated at runtime.
    return cast("List[List[SequenceGroupOutput]]", output_by_sequence_group)
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid
logger = init_logger(__name__)
@runtime_checkable
class EngineClient(Protocol):
class EngineClient(ABC):
"""Protocol class for Clients to Engine"""
# Abstract read-only property: every concrete engine client must override it.
# Presumably reports whether the engine is currently running — confirm
# against the implementing clients.
@property
@abstractmethod
def is_running(self) -> bool:
...
# Abstract read-only property: every concrete engine client must override it.
# Presumably reports whether the engine has been stopped — confirm against
# the implementing clients.
@property
@abstractmethod
def is_stopped(self) -> bool:
...
# Abstract read-only property: every concrete engine client must override it.
# Presumably reports whether the engine has hit an error — confirm against
# the implementing clients.
@property
@abstractmethod
def errored(self) -> bool:
...
# Abstract read-only property: every concrete engine client must override it
# and return an exception object (it is returned, not raised, here).
# Presumably the error describing why the engine died — confirm against the
# implementing clients.
@property
@abstractmethod
def dead_error(self) -> BaseException:
...
@abstractmethod
def generate(
self,
inputs: PromptInputs,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
"""Generate outputs for a request."""
...
async def beam_search(
    self,
    prompt: Union[PromptType, List[int]],
    request_id: str,
    params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
    """Run beam search on top of ``generate`` and yield one final result.

    For each of up to ``params.max_tokens`` steps, one single-token
    ``generate`` request is issued per live beam, asking for
    ``2 * beam_width`` logprobs. Every returned candidate token extends its
    beam. Extensions ending in the tokenizer's EOS token are moved to the
    completed set (unless ``params.ignore_eos``), and the best
    ``beam_width`` remaining extensions become the next step's beams. After
    the loop, the still-running beams are added to the completed set and
    the best ``beam_width`` completed beams are decoded and returned.

    Args:
        prompt: The prompt. A ``list`` is used as-is as token ids; any other
            value is passed to ``tokenizer.encode``.
        request_id: The caller's id, reported on the yielded
            ``RequestOutput``.
        params: Beam width, token budget, EOS handling, temperature and
            length penalty.

    Yields:
        A single finished ``RequestOutput`` holding up to ``beam_width``
        completions, best first.
    """
    beam_width = params.beam_width
    max_tokens = params.max_tokens
    ignore_eos = params.ignore_eos
    temperature = params.temperature
    length_penalty = params.length_penalty

    tokenizer = await self.get_tokenizer(lora_request=None)
    prompt_token_ids = (prompt if isinstance(prompt, list) else
                        tokenizer.encode(prompt))
    prompt_len = len(prompt_token_ids)

    sort_beams_key = create_sort_beams_key_function(
        tokenizer.eos_token_id, length_penalty)

    # Each step samples one token and asks for 2 * beam_width candidates.
    beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                        max_tokens=1,
                                        temperature=temperature)
    all_beams = [BeamSearchSequence(tokens=prompt_token_ids, cum_logprob=0)]
    completed = []

    for _ in range(max_tokens):
        prompts_batch = [
            TokensPrompt(prompt_token_ids=beam.tokens)
            for beam in all_beams
        ]

        # Internal id for this step's sub-requests. It is deliberately kept
        # in its own variable: reusing ``request_id`` here would overwrite
        # the caller's id, which the final RequestOutput must report.
        step_request_id = f"beam_search-{random_uuid()}"
        tasks = [
            asyncio.create_task(
                collect_from_async_generator(
                    self.generate(individual_prompt, beam_search_params,
                                  f"{step_request_id}-{i}")))
            for i, individual_prompt in enumerate(prompts_batch)
        ]

        output = await asyncio.gather(*tasks)

        # Keep only the first output collected from each sub-request.
        output = [x[0] for x in output]

        new_beams = []
        for current_beam, result in zip(all_beams, output):
            if result.outputs[0].logprobs is not None:
                # Candidate logprobs for the single generated position.
                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    new_beam = BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

                    if token_id == tokenizer.eos_token_id and \
                        not ignore_eos:
                        completed.append(new_beam)
                    else:
                        new_beams.append(new_beam)

        sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
        all_beams = sorted_beams[:beam_width]

    # Beams still running when the token budget is exhausted also compete.
    completed.extend(all_beams)
    sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
    best_beams = sorted_completed[:beam_width]

    for beam in best_beams:
        # Decode only the generated part, not the prompt tokens.
        beam.text = tokenizer.decode(beam.tokens[prompt_len:])

    beam_search_output = RequestOutput(
        request_id=request_id,
        prompt=prompt,
        outputs=[
            CompletionOutput(
                text=beam.text,
                cumulative_logprob=beam.cum_logprob,
                # NOTE(review): beam.tokens still includes the prompt
                # tokens, while `text` above excludes them — confirm
                # whether token_ids should be sliced the same way.
                token_ids=beam.tokens,
                index=i,
                # NOTE(review): a cumulative float is passed where
                # per-token logprobs would be expected — confirm.
                logprobs=beam.cum_logprob,
            ) for (i, beam) in enumerate(best_beams)
        ],
        finished=True,
        prompt_token_ids=prompt_token_ids,
        prompt_logprobs=None)

    yield beam_search_output
@abstractmethod
def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...
@abstractmethod
async def abort(self, request_id: str) -> None:
"""Abort a request.
......@@ -63,14 +172,17 @@ class EngineClient(Protocol):
request_id: The unique id of the request.
"""
@abstractmethod
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...
@abstractmethod
async def get_decoding_config(self) -> DecodingConfig:
    """Get the decoding configuration of the vLLM engine."""
    # The docstring must be the first statement of the body; placed after
    # the ``...`` it is only a discarded string expression and ``__doc__``
    # stays ``None``.
    ...
@abstractmethod
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
......@@ -78,9 +190,11 @@ class EngineClient(Protocol):
"""Get the appropriate tokenizer for the request"""
...
@abstractmethod
async def is_tracing_enabled(self) -> bool:
...
@abstractmethod
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
......@@ -88,14 +202,17 @@ class EngineClient(Protocol):
) -> None:
...
@abstractmethod
async def check_health(self) -> None:
"""Raise if unhealthy"""
...
@abstractmethod
async def start_profile(self) -> None:
"""Start profiling the engine"""
...
@abstractmethod
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
...
......@@ -157,22 +157,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
raise TypeError(f"Unknown model type: {model_type}")
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
raise TypeError(f"Unknown {modality} model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
......@@ -303,6 +305,28 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder(placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raise if the provided chat template appears invalid.

    Args:
        chat_template: ``None`` (nothing to validate), a ``Path`` to a
            template file, or a ``str`` holding either an inline template
            or the path of a template file.

    Raises:
        FileNotFoundError: ``chat_template`` is a ``Path`` that does not
            exist.
        ValueError: ``chat_template`` is a string containing none of
            ``{``, ``}`` or a newline (so it looks like a path) and no
            file exists at that location.
        TypeError: ``chat_template`` is not ``None``, a ``Path`` or a
            ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        # Keep the type check separate from the existence check: when both
        # were combined in one condition, an *existing* Path fell through
        # to the final TypeError even though it is a valid argument.
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        return

    if isinstance(chat_template, str):
        # A string without any of these characters is treated as a path
        # rather than as an inline template.
        JINJA_CHARS = "{}\n"
        looks_like_path = not any(c in chat_template for c in JINJA_CHARS)
        if looks_like_path and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type")
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
......@@ -542,6 +566,14 @@ def apply_mistral_chat_template(
if chat_template is not None:
logger.warning(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return tokenizer.apply_chat_template(
messages=messages,
......
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)
from tqdm import tqdm
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
......@@ -32,37 +34,6 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Stores the sequence's token ids together with its cumulative log
    probability. ``text`` is optional and is only populated when the
    sequence is about to be handed back to the user.
    """
    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    cum_logprob: float = 0.0
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result of a beam search.

    Holds the best beam search sequences; the list has as many entries
    as the beam width.
    """
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Beam-search state for one prompt: live beams plus completed ones."""

    def __init__(self, prompt_tokens: List[int]):
        # No sequence has completed yet.
        self.completed: List[BeamSearchSequence] = []
        # Start with a single beam that holds just the prompt tokens.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
......@@ -179,15 +150,7 @@ class LLM:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = (
"image_token_id",
"image_feature_size",
"image_input_shape",
"image_input_type",
)
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
......@@ -293,8 +256,8 @@ class LLM:
@overload
def generate(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
......@@ -304,14 +267,13 @@ class LLM:
...
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
......@@ -330,7 +292,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
......@@ -358,12 +322,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
......@@ -378,7 +343,7 @@ class LLM:
sampling_params = SamplingParams()
self._validate_and_add_requests(
inputs=inputs,
prompts=parsed_prompts,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
......@@ -391,9 +356,7 @@ class LLM:
def beam_search(
self,
prompts: List[Union[str, List[int]]],
beam_width: int,
max_tokens: int,
ignore_eos: bool = False,
params: BeamSearchParams,
) -> List[BeamSearchOutput]:
"""
Generate sequences using beam search.
......@@ -401,20 +364,30 @@ class LLM:
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=0.0)
temperature=temperature)
instances: List[BeamSearchInstance] = []
for prompt in prompts:
......@@ -469,7 +442,7 @@ class LLM:
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=lambda x: x.cum_logprob,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
......@@ -477,7 +450,7 @@ class LLM:
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed,
key=lambda x: x.cum_logprob,
key=sort_beams_key,
reverse=True)
best_beams = sorted_completed[:beam_width]
......@@ -497,7 +470,9 @@ class LLM:
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]:
"""
Generate responses for a chat conversation.
......@@ -524,6 +499,11 @@ class LLM:
If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of ``RequestOutput`` objects containing the generated
......@@ -534,10 +514,13 @@ class LLM:
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is List[List[...]]
list_of_messages = messages
list_of_messages = cast(List[List[ChatCompletionMessageParam]],
messages)
else:
# messages is List[...]
list_of_messages = [messages]
list_of_messages = [
cast(List[ChatCompletionMessageParam], messages)
]
prompts: List[Union[TokensPrompt, TextPrompt]] = []
......@@ -545,6 +528,9 @@ class LLM:
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer)
......@@ -555,6 +541,7 @@ class LLM:
messages=msgs,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
else:
......@@ -563,6 +550,7 @@ class LLM:
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
......@@ -575,6 +563,9 @@ class LLM:
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs
prompts.append(prompt)
return self.generate(
......@@ -648,8 +639,8 @@ class LLM:
@overload
def encode(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
......@@ -659,14 +650,13 @@ class LLM:
...
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
......@@ -682,9 +672,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
......@@ -707,19 +697,20 @@ class LLM:
)
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
inputs=inputs,
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
......@@ -763,9 +754,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
parsed_prompts: List[PromptType] = []
for i in range(num_requests):
item: PromptInputs
item: PromptType
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
......@@ -774,13 +765,13 @@ class LLM:
else:
raise AssertionError
inputs.append(item)
parsed_prompts.append(item)
return inputs
return parsed_prompts
def _validate_and_add_requests(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
......@@ -788,11 +779,19 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None:
if isinstance(inputs, (str, dict)):
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
prompts = [prompts]
num_requests = len(inputs)
num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
......@@ -803,15 +802,15 @@ class LLM:
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)
self._add_guided_params(sp, guided_options)
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
for i, prompt in enumerate(prompts):
self._add_request(
request_inputs,
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
......@@ -821,7 +820,7 @@ class LLM:
def _add_request(
self,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
......@@ -830,29 +829,32 @@ class LLM:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
inputs,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
def _add_guided_processor(
def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options:
if guided_options.guided_decoding_backend is None:
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = (
decoding_config.guided_decoding_backend)
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer())
if guided_logits_processor:
if params.logits_processors is None:
params.logits_processors = []
params.logits_processors.append(guided_logits_processor)
if guided_options is None:
return params
if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and"
"params.guided_decoding.")
params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params
def _run_engine(
......
......@@ -4,7 +4,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__)
......@@ -21,7 +21,8 @@ class RequestLogger:
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
params: Optional[Union[SamplingParams, PoolingParams]],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
......
......@@ -31,7 +31,8 @@ from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
......@@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
......@@ -526,8 +528,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valide_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", args.port))
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
......@@ -541,8 +555,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
temp_socket.close()
shutdown_task = await serve_http(
app,
host=args.host,
......@@ -553,6 +565,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
fd=sock.fileno(),
**uvicorn_kwargs,
)
......@@ -567,5 +580,6 @@ if __name__ == "__main__":
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))
......@@ -10,8 +10,10 @@ import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser
......@@ -190,16 +192,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes"],
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
......@@ -219,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
return parser
def validate_parsed_serve_args(args: argparse.Namespace):
    """Run cheap sanity checks on serve arguments before anything is loaded.

    Does nothing when ``args`` carries a ``subparser`` attribute whose
    value is not ``"serve"``. Otherwise validates ``args.chat_template``
    and raises ``TypeError`` when auto tool choice is enabled without a
    tool call parser.
    """
    # A missing ``subparser`` attribute is treated the same as "serve".
    if getattr(args, "subparser", "serve") != "serve":
        return

    # Raises if the chat template is likely invalid.
    validate_chat_template(args.chat_template)

    # Auto tool choice is only usable together with a tool call parser.
    auto_tool_choice = args.enable_auto_tool_choice
    if auto_tool_choice and not args.tool_call_parser:
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
......
......@@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
# torch is mocked during docs generation,
......@@ -186,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
......@@ -211,6 +208,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
......@@ -272,13 +278,33 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-completion-extra-params
def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
def to_beam_search_params(self,
                          default_max_tokens: int) -> BeamSearchParams:
    """Build the ``BeamSearchParams`` for this chat completion request.

    ``default_max_tokens`` is used when the request has no ``max_tokens``.
    The beam width is the request's ``n`` (1 when unset) and the
    temperature falls back to 0.0 when unset.
    """
    requested_max_tokens = self.max_tokens
    token_limit = (default_max_tokens
                   if requested_max_tokens is None else requested_max_tokens)

    beam_width = 1 if self.n is None else self.n
    temperature = 0.0 if self.temperature is None else self.temperature

    return BeamSearchParams(
        beam_width=beam_width,
        max_tokens=token_limit,
        ignore_eos=self.ignore_eos,
        temperature=temperature,
        length_penalty=self.length_penalty,
    )
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
......@@ -287,14 +313,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional(
n=self.n,
......@@ -314,17 +345,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None:
return None
# user has chosen to use a named tool
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = self.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in self.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters
return None
@model_validator(mode="before")
@classmethod
......@@ -386,7 +432,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if "tool_choice" not in data and "tools" in data:
if "tool_choice" not in data and data.get("tools"):
data["tool_choice"] = "auto"
# if "tool_choice" is specified -- validation
......@@ -431,6 +477,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
" of the specified `tools`")
return data
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
    """Reject input that sets both prompt-continuation flags.

    Raises ``ValueError`` when both ``continue_final_message`` and
    ``add_generation_prompt`` are truthy in ``data``; otherwise returns
    ``data`` unchanged.
    """
    if not data.get("continue_final_message"):
        return data
    if not data.get("add_generation_prompt"):
        return data
    raise ValueError("Cannot set both `continue_final_message` and "
                     "`add_generation_prompt` to True.")
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......@@ -460,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel):
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
......@@ -516,13 +570,33 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-completion-extra-params
def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
def to_beam_search_params(self,
                          default_max_tokens: int) -> BeamSearchParams:
    """Build the ``BeamSearchParams`` for this completion request.

    ``default_max_tokens`` is used when the request has no ``max_tokens``.
    The beam width is the request's ``n`` (1 when unset) and the
    temperature falls back to 0.0 when unset.
    """
    requested_max_tokens = self.max_tokens
    token_limit = (default_max_tokens
                   if requested_max_tokens is None else requested_max_tokens)

    beam_width = 1 if self.n is None else self.n
    temperature = 0.0 if self.temperature is None else self.temperature

    return BeamSearchParams(
        beam_width=beam_width,
        max_tokens=token_limit,
        ignore_eos=self.ignore_eos,
        temperature=temperature,
        length_penalty=self.length_penalty,
    )
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
......@@ -533,13 +607,19 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional(
n=self.n,
......@@ -558,18 +638,16 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
@model_validator(mode="before")
@classmethod
......@@ -619,12 +697,23 @@ class EmbeddingRequest(OpenAIBaseModel):
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-embedding-extra-params
def to_pooling_params(self):
    """Wrap this request's ``additional_data`` in a ``PoolingParams``."""
    extra_payload = self.additional_data
    return PoolingParams(additional_data=extra_payload)
......@@ -862,8 +951,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
continue_final_message: bool = Field(default=False)
add_special_tokens: bool = Field(default=False)
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
    """Ensure the two prompt-continuation flags are not both requested.

    Returns:
        ``data`` unchanged when validation passes.

    Raises:
        ValueError: if ``continue_final_message`` and
            ``add_generation_prompt`` are both truthy in ``data``.
    """
    # NOTE(review): as a ``mode="before"`` validator this only sees keys
    # supplied in ``data``; an omitted ``add_generation_prompt`` (declared
    # with default=True) is not treated as set here -- confirm intended.
    wants_continuation = data.get("continue_final_message")
    if wants_continuation and data.get("add_generation_prompt"):
        raise ValueError("Cannot set both `continue_final_message` and "
                         "`add_generation_prompt` to True.")
    return data
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
......
......@@ -29,12 +29,11 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
MistralToolParser,
ToolParser)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
......@@ -81,13 +80,13 @@ class OpenAIServingChat(OpenAIServing):
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
else:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e
async def create_chat_completion(
self,
......@@ -137,6 +136,7 @@ class OpenAIServingChat(OpenAIServing):
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
......@@ -147,18 +147,19 @@ class OpenAIServingChat(OpenAIServing):
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
logger.exception("Error in applying chat template from request")
return self.create_error_response(str(e))
try:
mm_data = await mm_data_future
except Exception as e:
logger.error("Error in loading multi-modal data: %s", e)
logger.exception("Error in loading multi-modal data")
return self.create_error_response(str(e))
# validation for OpenAI tools
......@@ -182,8 +183,9 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata
try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
......@@ -202,11 +204,15 @@ class OpenAIServingChat(OpenAIServing):
assert prompt_inputs is not None
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
self._log_inputs(request_id,
prompt_inputs,
......@@ -228,14 +234,22 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
if isinstance(sampling_params, BeamSearchParams):
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
)
else:
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -281,12 +295,8 @@ class OpenAIServingChat(OpenAIServing):
num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
......@@ -305,6 +315,29 @@ class OpenAIServingChat(OpenAIServing):
else:
previous_texts, all_previous_token_ids = None, None
# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
tool_parsers: List[Optional[ToolParser]] = [
self.tool_parser(tokenizer)
] * num_choices
else:
tool_parsers = [None] * num_choices
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
include_usage, include_continuous_usage = False, False
try:
async for res in result_generator:
if res.prompt_token_ids is not None:
......@@ -323,7 +356,6 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
......@@ -339,26 +371,19 @@ class OpenAIServingChat(OpenAIServing):
choices=[choice_data],
model=model_name)
# if usage should be included
if (request.stream_options
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
chunk.usage = None
# if continuous usage stats are requested, add it
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the
# last message
if request.echo:
if request.echo or request.continue_final_message:
last_msg_content: str = ""
if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role:
......@@ -379,17 +404,11 @@ class OpenAIServingChat(OpenAIServing):
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
data = chunk.model_dump_json(
exclude_unset=True)
......@@ -398,6 +417,7 @@ class OpenAIServingChat(OpenAIServing):
for output in res.outputs:
i = output.index
tool_parser = tool_parsers[i]
if finish_reason_sent[i]:
continue
......@@ -415,6 +435,12 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None
delta_text = output.text
if not delta_text and not output.token_ids and \
not previous_num_tokens[i]:
# Chunked prefill case, don't return empty chunks
continue
delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice
......@@ -445,7 +471,8 @@ class OpenAIServingChat(OpenAIServing):
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))
delta_token_ids=output.token_ids,
request=request))
# update the previous values for the next iteration
previous_texts[i] = current_text
......@@ -467,36 +494,11 @@ class OpenAIServingChat(OpenAIServing):
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# if the model is finished generating
else:
......@@ -504,10 +506,12 @@ class OpenAIServingChat(OpenAIServing):
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
auto_tools_called = False
if tool_parser:
index = len(
tool_parser.prev_tool_call_arr) - 1 if len(
tool_parser.prev_tool_call_arr) > 0 else 0
auto_tools_called = len(
tool_parser.prev_tool_call_arr) > 0
index = len(tool_parser.prev_tool_call_arr
) - 1 if auto_tools_called else 0
else:
index = 0
......@@ -542,38 +546,34 @@ class OpenAIServingChat(OpenAIServing):
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason
if not (tool_parser
and len(tool_parser.prev_tool_call_arr))
else "tool_calls",
if not auto_tools_called else "tool_calls",
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if include_continuous_usage:
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
if include_usage:
completion_tokens = sum(previous_num_tokens)
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
......@@ -600,7 +600,7 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e)
logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
......@@ -646,8 +646,10 @@ class OpenAIServingChat(OpenAIServing):
else:
logprobs = None
# by default, tools are not used.
tools_called = False
# In the OpenAI API the finish_reason is "tool_calls"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
......@@ -669,7 +671,6 @@ class OpenAIServingChat(OpenAIServing):
name=request.tool_choice.function.name,
arguments=output.text))
])
tools_called = True
# if the request doesn't use tool choice
# OR specifies to not use a tool
......@@ -683,9 +684,18 @@ class OpenAIServingChat(OpenAIServing):
or request.tool_choice is None) and self.enable_auto_tools \
and self.tool_parser:
tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(output.text)
tools_called = tool_call_info.tools_called
try:
tool_parser = self.tool_parser(tokenizer)
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
return self.create_error_response(str(e))
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
# In the OpenAI API the finish_reason is "tool_calls"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
content=tool_call_info.content,
......@@ -708,12 +718,12 @@ class OpenAIServingChat(OpenAIServing):
index=output.index,
message=message,
logprobs=logprobs,
finish_reason="tool_calls" if tools_called else
finish_reason="tool_calls" if auto_tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)
choices.append(choice_data)
if request.echo:
if request.echo or request.continue_final_message:
last_msg_content = ""
if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role:
......@@ -726,6 +736,8 @@ class OpenAIServingChat(OpenAIServing):
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
if final_res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
......
......@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
PromptAdapterPath)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
......@@ -110,8 +111,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
......@@ -122,11 +121,15 @@ class OpenAIServingCompletion(OpenAIServing):
))
for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
request_id_item = f"{request_id}-{i}"
......@@ -145,14 +148,25 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers):
log_tracing_disabled_warning()
generator = self.engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"],
request_id_item,
sampling_params,
)
else:
generator = self.engine_client.generate(
{
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
......@@ -260,8 +274,6 @@ class OpenAIServingCompletion(OpenAIServing):
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
......@@ -293,6 +305,11 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
if not delta_text and not delta_token_ids \
and not previous_num_tokens[i]:
# Chunked prefill case, don't return empty chunks
continue
if request.logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment