Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
...@@ -2,19 +2,9 @@ ...@@ -2,19 +2,9 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json from typing import Sequence
import os
from typing import Dict, Optional, Sequence
import torch import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
from .parallel_state import get_cpu_world_group, get_local_rank
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator): def ensure_divisibility(numerator, denominator):
...@@ -56,81 +46,3 @@ def split_tensor_along_last_dim( ...@@ -56,81 +46,3 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list) return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list return tensor_list
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
dev_i = f"cuda:{idx_a}"
dev_j = f"cuda:{idx_b}"
a = torch.randn(5, device=dev_i) + 123.0
b = a.to(dev_j)
c = b.to(dev_i)
return torch.all(a == c).cpu().item()
# why do we need this cache?
# 1. we can have runtime checks for P2P access, where every process checks
# P2P access to all other GPUs. Unfortunately, the test might cost many
# (world_size * world_size) cuda context, and reduce the memory available
# for the model. see https://github.com/vllm-project/vllm/issues/3821
# 2. alternatively, we can have a p2p map that is generated by the master
# process and broadcasted to all other processes. This still requires
# #world_size of cuda context, belonging to the master process, on each GPU.
# 3. we can have a cache file, that records the p2p access status. The first
# time the master process checks the p2p access, it will generate the cache
# file, at the cost of #world_size of cuda context. Later on, all processes
# can read the cache file to check the p2p access status without any cost of
# additional cuda context.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(i: int, j: int) -> bool:
"""Check if GPU i can access GPU j."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{i}->{j}"]
is_distributed = dist.is_initialized()
num_dev = torch.cuda.device_count()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
path = os.path.expanduser(
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
if (not is_distributed or get_local_rank() == 0) \
and (not os.path.exists(path)):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache for in %s", path)
cache = {}
for _i in range(num_dev):
for _j in range(num_dev):
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
_i, _j) and _can_actually_p2p(_i, _j)
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
cpu_world_group = get_cpu_world_group()
dist.barrier(cpu_world_group)
logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{i}->{j}"]
import argparse import argparse
import dataclasses import dataclasses
import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
...@@ -34,11 +35,13 @@ class EngineArgs: ...@@ -34,11 +35,13 @@ class EngineArgs:
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
block_size: int = 16 block_size: int = 16
enable_prefix_caching: bool = False enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = False use_v2_block_manager: bool = False
swap_space: int = 4 # GiB swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
...@@ -48,6 +51,7 @@ class EngineArgs: ...@@ -48,6 +51,7 @@ class EngineArgs:
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
code_revision: Optional[str] = None code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: bool = False
...@@ -62,6 +66,7 @@ class EngineArgs: ...@@ -62,6 +66,7 @@ class EngineArgs:
max_lora_rank: int = 16 max_lora_rank: int = 16
fully_sharded_loras: bool = False fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256 lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype = 'auto' lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
device: str = 'auto' device: str = 'auto'
...@@ -83,6 +88,7 @@ class EngineArgs: ...@@ -83,6 +88,7 @@ class EngineArgs:
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None
...@@ -166,8 +172,8 @@ class EngineArgs: ...@@ -166,8 +172,8 @@ class EngineArgs:
'* "dummy" will initialize the weights with random values, ' '* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n' 'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from ' '* "tensorizer" will load the weights using tensorizer from '
'CoreWeave which assumes tensorizer_uri is set to the location of ' 'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'the serialized weights.') 'section for more information.\n')
parser.add_argument( parser.add_argument(
'--dtype', '--dtype',
type=str, type=str,
...@@ -186,12 +192,11 @@ class EngineArgs: ...@@ -186,12 +192,11 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--kv-cache-dtype', '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8'], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype, default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'supported for common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=nullable_str, type=nullable_str,
...@@ -220,10 +225,17 @@ class EngineArgs: ...@@ -220,10 +225,17 @@ class EngineArgs:
' Can be overridden per request via guided_decoding_backend' ' Can be overridden per request via guided_decoding_backend'
' parameter.') ' parameter.')
# Parallel arguments # Parallel arguments
parser.add_argument('--worker-use-ray', parser.add_argument(
action='store_true', '--distributed-executor-backend',
help='Use Ray for distributed serving, will be ' choices=['ray', 'mp'],
'automatically set when using more than 1 GPU.') default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--worker-use-ray',
action='store_true',
help='Deprecated, use --distributed-executor-backend=ray.')
parser.add_argument('--pipeline-parallel-size', parser.add_argument('--pipeline-parallel-size',
'-pp', '-pp',
type=int, type=int,
...@@ -256,6 +268,10 @@ class EngineArgs: ...@@ -256,6 +268,10 @@ class EngineArgs:
parser.add_argument('--enable-prefix-caching', parser.add_argument('--enable-prefix-caching',
action='store_true', action='store_true',
help='Enables automatic prefix caching.') help='Enables automatic prefix caching.')
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size')
parser.add_argument('--use-v2-block-manager', parser.add_argument('--use-v2-block-manager',
action='store_true', action='store_true',
help='Use BlockSpaceMangerV2.') help='Use BlockSpaceMangerV2.')
...@@ -320,6 +336,11 @@ class EngineArgs: ...@@ -320,6 +336,11 @@ class EngineArgs:
'None, we assume the model weights are not ' 'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data ' 'quantized and use `dtype` to determine the data '
'type of the weights.') 'type of the weights.')
parser.add_argument('--rope-scaling',
default=None,
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}')
parser.add_argument('--enforce-eager', parser.add_argument('--enforce-eager',
action='store_true', action='store_true',
help='Always use eager-mode PyTorch. If False, ' help='Always use eager-mode PyTorch. If False, '
...@@ -331,9 +352,9 @@ class EngineArgs: ...@@ -331,9 +352,9 @@ class EngineArgs:
help='Maximum context length covered by CUDA ' help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. ' 'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead' '(DEPRECATED. Use --max-seq-len-to-capture instead'
')') ')')
parser.add_argument('--max-seq_len-to-capture', parser.add_argument('--max-seq-len-to-capture',
type=int, type=int,
default=EngineArgs.max_seq_len_to_capture, default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA ' help='Maximum sequence length covered by CUDA '
...@@ -388,6 +409,17 @@ class EngineArgs: ...@@ -388,6 +409,17 @@ class EngineArgs:
choices=['auto', 'float16', 'bfloat16', 'float32'], choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to ' help=('Data type for LoRA. If auto, will default to '
'base model dtype.')) 'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=nullable_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'))
parser.add_argument( parser.add_argument(
'--max-cpu-loras', '--max-cpu-loras',
type=int, type=int,
...@@ -467,6 +499,13 @@ class EngineArgs: ...@@ -467,6 +499,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip ' 'draft model. Sequences over this length will skip '
'speculation.') 'speculation.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument( parser.add_argument(
'--ngram-prompt-lookup-max', '--ngram-prompt-lookup-max',
type=int, type=int,
...@@ -508,7 +547,7 @@ class EngineArgs: ...@@ -508,7 +547,7 @@ class EngineArgs:
return parser return parser
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def from_cli_args(cls, args: argparse.Namespace):
# Get the list of attributes of this dataclass. # Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments. # Set the attributes from the parsed arguments.
...@@ -520,10 +559,11 @@ class EngineArgs: ...@@ -520,10 +559,11 @@ class EngineArgs:
model_config = ModelConfig( model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode, self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.dtype, self.seed, self.revision, self.trust_remote_code, self.dtype, self.seed, self.revision,
self.code_revision, self.tokenizer_revision, self.max_model_len, self.code_revision, self.rope_scaling, self.tokenizer_revision,
self.quantization, self.quantization_param_path, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture, self.quantization_param_path, self.enforce_eager,
self.max_seq_len_to_capture, self.max_logprobs, self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.max_logprobs, self.disable_sliding_window,
self.skip_tokenizer_init, self.served_model_name) self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
...@@ -532,14 +572,18 @@ class EngineArgs: ...@@ -532,14 +572,18 @@ class EngineArgs:
model_config.get_sliding_window(), model_config.get_sliding_window(),
self.enable_prefix_caching) self.enable_prefix_caching)
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
self.pipeline_parallel_size, self.tensor_parallel_size, self.pipeline_parallel_size,
self.worker_use_ray, self.max_parallel_loading_workers, self.tensor_parallel_size,
self.worker_use_ray,
self.max_parallel_loading_workers,
self.disable_custom_all_reduce, self.disable_custom_all_reduce,
TokenizerPoolConfig.create_config( TokenizerPoolConfig.create_config(
self.tokenizer_pool_size, self.tokenizer_pool_size,
self.tokenizer_pool_type, self.tokenizer_pool_type,
self.tokenizer_pool_extra_config, self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight) ),
self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend)
speculative_config = SpeculativeConfig.maybe_create_spec_config( speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config, target_model_config=model_config,
...@@ -547,6 +591,8 @@ class EngineArgs: ...@@ -547,6 +591,8 @@ class EngineArgs:
target_dtype=self.dtype, target_dtype=self.dtype,
speculative_model=self.speculative_model, speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens, num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len, speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager, use_v2_block_manager=self.use_v2_block_manager,
...@@ -564,12 +610,14 @@ class EngineArgs: ...@@ -564,12 +610,14 @@ class EngineArgs:
speculative_config.num_lookahead_slots), speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor, delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
) )
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras, fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size, lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype, lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None) if self.enable_lora else None
...@@ -599,6 +647,13 @@ class EngineArgs: ...@@ -599,6 +647,13 @@ class EngineArgs:
decoding_config = DecodingConfig( decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend) guided_decoding_backend=self.guided_decoding_backend)
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled
and not scheduler_config.use_v2_block_manager):
raise ValueError(
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")
return EngineConfig(model_config=model_config, return EngineConfig(model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
......
import asyncio import asyncio
import time import time
from functools import partial from functools import partial
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Optional, Set, Tuple, Type, Union) Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -12,11 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs ...@@ -12,11 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -47,15 +49,16 @@ def _raise_exception_on_finish( ...@@ -47,15 +49,16 @@ def _raise_exception_on_finish(
class AsyncStream: class AsyncStream:
"""A stream of RequestOutputs for a request that can be """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
iterated over asynchronously.""" that can be iterated over asynchronously."""
def __init__(self, request_id: str) -> None: def __init__(self, request_id: str) -> None:
self.request_id = request_id self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue() self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None: def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished: if self._finished:
return return
self._queue.put_nowait(item) self._queue.put_nowait(item)
...@@ -71,7 +74,7 @@ class AsyncStream: ...@@ -71,7 +74,7 @@ class AsyncStream:
def __aiter__(self): def __aiter__(self):
return self return self
async def __anext__(self) -> RequestOutput: async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
result = await self._queue.get() result = await self._queue.get()
if isinstance(result, Exception): if isinstance(result, Exception):
raise result raise result
...@@ -108,7 +111,8 @@ class RequestTracker: ...@@ -108,7 +111,8 @@ class RequestTracker:
self.abort_request(rid) self.abort_request(rid)
def process_request_output(self, def process_request_output(self,
request_output: RequestOutput, request_output: Union[RequestOutput,
EmbeddingRequestOutput],
*, *,
verbose: bool = False) -> None: verbose: bool = False) -> None:
"""Process a request output from the engine.""" """Process a request output from the engine."""
...@@ -196,7 +200,8 @@ class RequestTracker: ...@@ -196,7 +200,8 @@ class RequestTracker:
class _AsyncLLMEngine(LLMEngine): class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods.""" """Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]: async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible. The workers are ran asynchronously if possible.
...@@ -230,66 +235,77 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -230,66 +235,77 @@ class _AsyncLLMEngine(LLMEngine):
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs return request_outputs
async def encode_request_async( async def process_model_inputs_async(
self, self,
request_id: str, # pylint: disable=unused-argument request_id: str,
prompt: Optional[str], inputs: PromptInputs,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
): ) -> LLMInputs:
if prompt_token_ids is None: if isinstance(inputs, str):
assert prompt is not None inputs = {"prompt": inputs}
prompt_token_ids = await self.tokenizer.encode_async(
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = await tokenizer.encode_async(
request_id=request_id, request_id=request_id,
prompt=prompt, prompt=inputs["prompt"],
lora_request=lora_request) lora_request=lora_request)
return prompt_token_ids else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
self._add_processed_request(
request_id=request_id, request_id=request_id,
prompt=prompt, processed_inputs=processed_inputs,
prompt_token_ids=prompt_token_ids, params=params,
lora_request=lora_request) arrival_time=arrival_time,
lora_request=lora_request,
return self.add_request(request_id, )
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
self.model_executor.check_health() self.model_executor.check_health()
class AsyncLLMEngine: class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine. """An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`. This class is used to wrap the :class:`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The :class:`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
Args: Args:
worker_use_ray: Whether to use Ray for model workers. Required for worker_use_ray: Whether to use Ray for model workers. Required for
...@@ -303,8 +319,8 @@ class AsyncLLMEngine: ...@@ -303,8 +319,8 @@ class AsyncLLMEngine:
being printed in log. being printed in log.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
*args: Arguments for LLMEngine. *args: Arguments for :class:`LLMEngine`.
*kwargs: Arguments for LLMEngine. **kwargs: Arguments for :class:`LLMEngine`.
""" """
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
...@@ -327,7 +343,7 @@ class AsyncLLMEngine: ...@@ -327,7 +343,7 @@ class AsyncLLMEngine:
# We need to keep a reference to unshielded # We need to keep a reference to unshielded
# task as well to prevent it from being garbage # task as well to prevent it from being garbage
# collected # collected
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None self._background_loop_unshielded: Optional[asyncio.Task] = None
self.start_engine_loop = start_engine_loop self.start_engine_loop = start_engine_loop
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
...@@ -344,27 +360,31 @@ class AsyncLLMEngine: ...@@ -344,27 +360,31 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
assert not engine_config.parallel_config.worker_use_ray, ( assert distributed_executor_backend is None, (
"Ray is not supported with the CPU backend.") "Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray: elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config) initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else: else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync executor_class = GPUExecutorAsync
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
engine_config.parallel_config.worker_use_ray, distributed_executor_backend == "ray",
engine_args.engine_use_ray, engine_args.engine_use_ray,
**engine_config.to_dict(), **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
...@@ -510,27 +530,31 @@ class AsyncLLMEngine: ...@@ -510,27 +530,31 @@ class AsyncLLMEngine:
async def add_request( async def add_request(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
shortened_prompt = prompt if isinstance(inputs, str):
shortened_token_ids = prompt_token_ids shortened_prompt = inputs
if self.max_log_len is not None: shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None: if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len] shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None: if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self. shortened_token_ids = shortened_token_ids[:max_log_len]
max_log_len]
logger.info( logger.info(
"Received request %s: prompt: %r, " "Received request %s: prompt: %r, "
"sampling_params: %s, prompt_token_ids: %s, " "params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, "lora_request: %s.", request_id, shortened_prompt, params,
sampling_params, shortened_token_ids, lora_request) shortened_token_ids, lora_request)
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
...@@ -546,39 +570,33 @@ class AsyncLLMEngine: ...@@ -546,39 +570,33 @@ class AsyncLLMEngine:
arrival_time = time.time() arrival_time = time.time()
if self.engine_use_ray: if self.engine_use_ray:
prompt_token_ids = await ( processed_inputs = await self.engine.process_model_inputs_async \
self.engine.encode_request_async.remote( # type: ignore .remote( # type: ignore
request_id=request_id, request_id=request_id,
prompt=prompt, inputs=inputs,
prompt_token_ids=prompt_token_ids, lora_request=lora_request)
lora_request=lora_request))
else: else:
prompt_token_ids = await self.engine.encode_request_async( processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id, request_id=request_id,
prompt=prompt, inputs=inputs,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request) lora_request=lora_request)
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
prompt=prompt, inputs=processed_inputs,
sampling_params=sampling_params, params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data,
) )
return stream return stream
async def generate( async def generate(
self, self,
prompt: Optional[str], inputs: PromptInputs,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -587,18 +605,16 @@ class AsyncLLMEngine: ...@@ -587,18 +605,16 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
prompt: The prompt string. Can be None if prompt_token_ids is inputs: The inputs to the LLM. See
provided. :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine for the The output `RequestOutput` objects from the LLMEngine
request. for the request.
Details: Details:
- If the engine is not running, start the background loop, - If the engine is not running, start the background loop,
...@@ -643,25 +659,112 @@ class AsyncLLMEngine: ...@@ -643,25 +659,112 @@ class AsyncLLMEngine:
>>> # Process and return the final output >>> # Process and return the final output
>>> ... >>> ...
""" """
# Preprocess the request. async for output in self._process_request(
arrival_time = time.time()
try:
stream = await self.add_request(
request_id, request_id,
prompt, inputs,
sampling_params, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data, ):
) yield LLMEngine.validate_output(output, RequestOutput)
async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>> example_input["input"],
>>> PoolingParams(),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
async for output in self._process_request(
request_id,
inputs,
pooling_params,
lora_request=lora_request,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def _process_request(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time = time.time()
stream = await self.add_request(
request_id,
inputs,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
try:
async for request_output in stream: async for request_output in stream:
yield request_output yield request_output
except (Exception, asyncio.CancelledError) as e: except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id) self._abort(request_id)
raise e raise e
......
import time import time
from typing import Iterable, List, Optional, Type, Union from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
...@@ -18,12 +21,16 @@ from vllm.engine.output_processor.stop_checker import StopChecker ...@@ -18,12 +21,16 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
...@@ -47,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): ...@@ -47,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return {} return {}
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
...@@ -57,11 +67,11 @@ class LLMEngine: ...@@ -57,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the iteration-level scheduling and efficient memory management to maximize the
serving throughput. serving throughput.
The `LLM` class wraps this class for offline batched inference and the The :class:`~vllm.LLM` class wraps this class for offline batched inference
`AsyncLLMEngine` class wraps this class for online serving. and the :class:`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the The config arguments are derived from :class:`~vllm.EngineArgs`. (See
comprehensive list of arguments, see `EngineArgs`. :ref:`engine_args`)
Args: Args:
model_config: The configuration related to the LLM model. model_config: The configuration related to the LLM model.
...@@ -78,9 +88,60 @@ class LLMEngine: ...@@ -78,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection usage_context: Specified entry point, used for usage info collection.
""" """
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
"""A flag to toggle whether to validate the type of request output."""
@classmethod
@contextmanager
def enable_output_validation(cls):
cls.DO_VALIDATE_OUTPUT = True
yield
cls.DO_VALIDATE_OUTPUT = False
@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT
if ((TYPE_CHECKING or do_validate)
and not isinstance(output, output_type)):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT
outputs_: List[_O]
if TYPE_CHECKING or do_validate:
outputs_ = []
for output in outputs:
if not isinstance(output, output_type):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
outputs_.append(output)
else:
outputs_ = outputs
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -101,10 +162,11 @@ class LLMEngine: ...@@ -101,10 +162,11 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, " "model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " "rope_scaling=%r, tokenizer_revision=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " "disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)", "decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__, vllm.__version__,
...@@ -114,6 +176,7 @@ class LLMEngine: ...@@ -114,6 +176,7 @@ class LLMEngine:
model_config.skip_tokenizer_init, model_config.skip_tokenizer_init,
model_config.tokenizer_mode, model_config.tokenizer_mode,
model_config.revision, model_config.revision,
model_config.rope_scaling,
model_config.tokenizer_revision, model_config.tokenizer_revision,
model_config.trust_remote_code, model_config.trust_remote_code,
model_config.dtype, model_config.dtype,
...@@ -146,12 +209,11 @@ class LLMEngine: ...@@ -146,12 +209,11 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup self.tokenizer = self._init_tokenizer()
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
else: else:
self.detokenizer = None
self.tokenizer = None self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
...@@ -169,7 +231,8 @@ class LLMEngine: ...@@ -169,7 +231,8 @@ class LLMEngine:
load_config=load_config, load_config=load_config,
) )
self._initialize_kv_caches() if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled(): if is_usage_stats_enabled():
...@@ -270,6 +333,8 @@ class LLMEngine: ...@@ -270,6 +333,8 @@ class LLMEngine:
"""Creates an LLM engine from the engine arguments.""" """Creates an LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
...@@ -278,13 +343,15 @@ class LLMEngine: ...@@ -278,13 +343,15 @@ class LLMEngine:
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray: elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config) initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
executor_class = MultiprocessingGPUExecutor
else: else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor executor_class = GPUExecutor
...@@ -308,14 +375,26 @@ class LLMEngine: ...@@ -308,14 +375,26 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None) return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self, def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer": sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request) return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs): def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict( init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer, tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config), enable_lora=bool(self.lora_config),
...@@ -325,8 +404,9 @@ class LLMEngine: ...@@ -325,8 +404,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision) revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs) init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs) return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
...@@ -336,29 +416,85 @@ class LLMEngine: ...@@ -336,29 +416,85 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
def encode_request( def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _add_processed_request(
self,
request_id: str,
processed_inputs: LLMInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def process_model_inputs(
self, self,
request_id: str, # pylint: disable=unused-argument request_id: str,
prompt: Optional[str], inputs: PromptInputs,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
): ) -> LLMInputs:
if prompt_token_ids is None: if isinstance(inputs, str):
assert prompt is not None inputs = {"prompt": inputs}
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt, if "prompt_token_ids" not in inputs:
lora_request=lora_request) tokenizer = self.get_tokenizer_group("prompts must be None if "
return prompt_token_ids "skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -368,14 +504,14 @@ class LLMEngine: ...@@ -368,14 +504,14 @@ class LLMEngine:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is inputs: The inputs to the LLM. See
provided. :class:`~vllm.inputs.PromptInputs`
sampling_params: The sampling parameters for text generation. for more details about the format of each input.
prompt_token_ids: The token IDs of the prompt. If None, we params: Parameters for sampling or pooling.
use the tokenizer to convert the prompts to token IDs. :class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
multi_modal_data: Multi modal data per request.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
...@@ -404,6 +540,30 @@ class LLMEngine: ...@@ -404,6 +540,30 @@ class LLMEngine:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if arrival_time is None:
arrival_time = time.time()
processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or ( and sampling_params.logprobs > max_logprobs) or (
...@@ -411,26 +571,6 @@ class LLMEngine: ...@@ -411,26 +571,6 @@ class LLMEngine:
and sampling_params.prompt_logprobs > max_logprobs): and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than " raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.") f"{max_logprobs} logprobs.")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
...@@ -443,11 +583,32 @@ class LLMEngine: ...@@ -443,11 +583,32 @@ class LLMEngine:
self.generation_config_fields) self.generation_config_fields)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id=request_id,
arrival_time, lora_request, multi_modal_data) seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request)
# Add the sequence group to the scheduler. return seq_group
self.scheduler.add_seq_group(seq_group)
def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID. """Aborts a request(s) with the given ID.
...@@ -484,13 +645,25 @@ class LLMEngine: ...@@ -484,13 +645,25 @@ class LLMEngine:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return self.scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _process_model_outputs( def _process_model_outputs(
self, self,
output: List[SamplerOutput], output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup], scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]: ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Apply the model output to the sequences in the scheduled seq groups. """Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client. Returns RequestOutputs that can be returned to the client.
...@@ -501,7 +674,7 @@ class LLMEngine: ...@@ -501,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of # Organize outputs by [sequence group][step] instead of
# [step][sequence group]. # [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group( output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip( for scheduled_seq_group, outputs, seq_group_meta in zip(
...@@ -510,6 +683,9 @@ class LLMEngine: ...@@ -510,6 +683,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
self.output_processor.process_prompt_logprob(seq_group, outputs) self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample: if seq_group_meta.do_sample:
...@@ -519,18 +695,19 @@ class LLMEngine: ...@@ -519,18 +695,19 @@ class LLMEngine:
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups: for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
for seq_group in ignored_seq_groups: for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
return request_outputs return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png .. figure:: https://i.imgur.com/sv2HssD.png
...@@ -570,7 +747,7 @@ class LLMEngine: ...@@ -570,7 +747,7 @@ class LLMEngine:
>>> while True: >>> while True:
>>> if example_inputs: >>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0) >>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params) >>> engine.add_request(str(req_id),prompt,sampling_params)
>>> >>>
>>> # continue the request processing >>> # continue the request processing
>>> request_outputs = engine.step() >>> request_outputs = engine.step()
...@@ -604,6 +781,14 @@ class LLMEngine: ...@@ -604,6 +781,14 @@ class LLMEngine:
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()
return request_outputs return request_outputs
def do_log_stats( def do_log_stats(
...@@ -637,12 +822,15 @@ class LLMEngine: ...@@ -637,12 +822,15 @@ class LLMEngine:
# KV Cache Usage in % # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() gpu_cache_usage_sys = 0.
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0. cpu_cache_usage_sys = 0.
if num_total_cpu > 0: if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
) )
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
...@@ -652,6 +840,8 @@ class LLMEngine: ...@@ -652,6 +840,8 @@ class LLMEngine:
num_generation_tokens_iter = 0 num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = [] time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = [] time_per_output_tokens_iter: List[float] = []
num_preemption_iter = (0 if scheduler_outputs is None else
scheduler_outputs.preempted)
# Request stats # Request stats
# Latency # Latency
...@@ -716,8 +906,10 @@ class LLMEngine: ...@@ -716,8 +906,10 @@ class LLMEngine:
seq.get_output_len() seq.get_output_len()
for seq in seq_group.get_finished_seqs() for seq in seq_group.get_finished_seqs()
]) ])
best_of_requests.append(seq_group.sampling_params.best_of) if seq_group.sampling_params is not None:
n_requests.append(seq_group.sampling_params.n) best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([ finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status) SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs() for seq in seq_group.get_finished_seqs()
...@@ -743,7 +935,6 @@ class LLMEngine: ...@@ -743,7 +935,6 @@ class LLMEngine:
return Stats( return Stats(
now=now, now=now,
# System stats # System stats
# Scheduler State # Scheduler State
num_running_sys=num_running_sys, num_running_sys=num_running_sys,
...@@ -759,6 +950,7 @@ class LLMEngine: ...@@ -759,6 +950,7 @@ class LLMEngine:
time_to_first_tokens_iter=time_to_first_tokens_iter, time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter, time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics, spec_decode_metrics=spec_decode_metrics,
num_preemption_iter=num_preemption_iter,
# Request stats # Request stats
# Latency # Latency
......
...@@ -61,6 +61,10 @@ class Metrics: ...@@ -61,6 +61,10 @@ class Metrics:
labelnames=labelnames) labelnames=labelnames)
# Iteration stats # Iteration stats
self.counter_num_preemption = Counter(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = Counter( self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
...@@ -181,6 +185,7 @@ class Stats: ...@@ -181,6 +185,7 @@ class Stats:
num_generation_tokens_iter: int num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float] time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float] time_per_output_tokens_iter: List[float]
num_preemption_iter: int
# Request stats (should have _requests suffix) # Request stats (should have _requests suffix)
# Latency # Latency
...@@ -244,6 +249,8 @@ class StatLogger: ...@@ -244,6 +249,8 @@ class StatLogger:
stats.cpu_cache_usage_sys) stats.cpu_cache_usage_sys)
# Iteration level data # Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
self._log_counter(self.metrics.counter_prompt_tokens, self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter) stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens, self._log_counter(self.metrics.counter_generation_tokens,
...@@ -336,7 +343,7 @@ class StatLogger: ...@@ -336,7 +343,7 @@ class StatLogger:
"Avg generation throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, " "Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, " "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%", "CPU KV cache usage: %.1f%%.",
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
stats.num_running_sys, stats.num_running_sys,
......
...@@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
new_char_count = self.detokenizer.decode_sequence_inplace( new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params) seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence( self.stop_checker.maybe_stop_sequence(
seq, seq,
new_char_count=new_char_count, new_char_count=new_char_count,
sampling_params=sampling_params) sampling_params=sampling_params,
)
if seq.is_finished(): if seq.is_finished():
break break
......
...@@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq, seq_group.sampling_params) seq, seq_group.sampling_params)
else: else:
new_char_count = 0 new_char_count = 0
self.stop_checker.maybe_stop_sequence(seq, new_char_count, self.stop_checker.maybe_stop_sequence(
seq_group.sampling_params) seq,
new_char_count,
seq_group.sampling_params,
lora_req=seq_group.lora_request,
)
# Non-beam search case # Non-beam search case
if not seq_group.sampling_params.use_beam_search: if not seq_group.sampling_params.use_beam_search:
......
...@@ -2,6 +2,7 @@ from typing import Callable, Optional ...@@ -2,6 +2,7 @@ from typing import Callable, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus from vllm.sequence import Sequence, SequenceStatus
...@@ -16,11 +17,23 @@ class StopChecker: ...@@ -16,11 +17,23 @@ class StopChecker:
def __init__(self, max_model_len: int, def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]): PreTrainedTokenizer]):
self.max_model_len = max_model_len # Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq self.get_tokenizer_for_seq = get_tokenizer_for_seq
def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
sampling_params: SamplingParams) -> None: if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
else:
return self._max_model_len
def maybe_stop_sequence(
self,
seq: Sequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences. """Stop the finished sequences.
new_char_count is the number of chars added to the new_char_count is the number of chars added to the
...@@ -35,6 +48,11 @@ class StopChecker: ...@@ -35,6 +48,11 @@ class StopChecker:
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos) if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id): and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
...@@ -59,7 +77,7 @@ class StopChecker: ...@@ -59,7 +77,7 @@ class StopChecker:
return return
# Check if the sequence has reached max_model_len. # Check if the sequence has reached max_model_len.
if seq.get_len() > self.max_model_len: if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return return
......
from typing import List from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import SamplerOutput, SequenceGroupOutput from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group( def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput], outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]: num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by """Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step]. [step][sequence group] into [sequence group][step].
""" """
output_by_sequence_group: List[List[SamplerOutput]] = [ output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups) [] for _ in range(num_seq_groups)
] ]
for step in sampler_outputs: for step in outputs:
for i, sequence_group_output in enumerate(step): for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output) output_by_sequence_group[i].append(sequence_group_output)
......
from typing import List, Optional, Union from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter from vllm.utils import Counter, deprecate_kwargs
logger = init_logger(__name__)
class LLM: class LLM:
...@@ -23,10 +30,6 @@ class LLM: ...@@ -23,10 +30,6 @@ class LLM:
this class generates texts from the model, using an intelligent batching this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management. mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args: Args:
model: The name or path of a HuggingFace Transformers model. model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer. tokenizer: The name or path of a HuggingFace Transformers tokenizer.
...@@ -75,8 +78,26 @@ class LLM: ...@@ -75,8 +78,26 @@ class LLM:
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode. to eager mode.
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
""" """
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
def __init__( def __init__(
self, self,
model: str, model: str,
...@@ -134,126 +155,415 @@ class LLM: ...@@ -134,126 +155,415 @@ class LLM:
) -> None: ) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer self.llm_engine.tokenizer.tokenizer = tokenizer
@overload # LEGACY: single (prompt + optional token ids)
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def generate(
self,
prompts: List[str],
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None, List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
prompts: A list of prompts to generate completions for. inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt. prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns: Returns:
A list of `RequestOutput` objects containing the generated A list of `RequestOutput` objects containing the
completions in the same order as the input prompts. generated completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
if prompts is not None: Note:
num_requests = len(prompts) Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else: else:
assert prompt_token_ids is not None inputs = cast(
num_requests = len(prompt_token_ids) Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
elif isinstance(sampling_params, self._validate_and_add_requests(
list) and len(sampling_params) != num_requests: inputs=inputs,
raise ValueError("The lengths of prompts and sampling_params " params=sampling_params,
lora_request=lora_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def encode(
self,
prompts: List[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
multi_modal_data: Optional[MultiModalData],
):
# skip_tokenizer_init is now checked in engine
if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]
num_requests = None
if prompts is not None:
num_requests = len(prompts)
if prompt_token_ids is not None:
if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item)
return inputs
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
for i in range(num_requests): for i, request_inputs in enumerate(inputs):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request( self._add_request(
prompt, request_inputs,
sampling_params[i] params[i] if isinstance(params, Sequence) else params,
if isinstance(sampling_params, list) else sampling_params,
token_ids,
lora_request=lora_request, lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
) )
return self._run_engine(use_tqdm)
def _add_request( def _add_request(
self, self,
prompt: Optional[str], inputs: PromptInputs,
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(request_id,
prompt, inputs,
sampling_params, params,
prompt_token_ids, lora_request=lora_request)
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests() num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, pbar = tqdm(
desc="Processed prompts", total=num_requests,
dynamic_ncols=True) desc="Processed prompts",
dynamic_ncols=True,
postfix=f"Generation Speed: {0:.2f} toks/s",
)
# Run the engine. # Run the engine.
outputs: List[RequestOutput] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
for output in step_outputs: for output in step_outputs:
if output.finished: if output.finished:
outputs.append(output) outputs.append(output)
if use_tqdm: if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
total_toks += sum(
len(stp.token_ids) for stp in output.outputs)
spd = total_toks / pbar.format_dict["elapsed"]
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
pbar.update(1) pbar.update(1)
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID. # Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
return outputs
\ No newline at end of file
...@@ -4,7 +4,7 @@ import inspect ...@@ -4,7 +4,7 @@ import inspect
import re import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Set from typing import Optional, Set
import fastapi import fastapi
import uvicorn import uvicorn
...@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine ...@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
CompletionRequest, ErrorResponse) CompletionRequest,
EmbeddingRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -32,9 +34,11 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds ...@@ -32,9 +34,11 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
logger = init_logger(__name__) logger = init_logger(__name__)
_running_tasks: Set[asyncio.Task[Any]] = set() _running_tasks: Set[asyncio.Task] = set()
@asynccontextmanager @asynccontextmanager
...@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
...@@ -139,6 +154,8 @@ if __name__ == "__main__": ...@@ -139,6 +154,8 @@ if __name__ == "__main__":
@app.middleware("http") @app.middleware("http")
async def authentication(request: Request, call_next): async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path root_path = "" if args.root_path is None else args.root_path
if request.method == "OPTIONS":
return await call_next(request)
if not request.url.path.startswith(f"{root_path}/v1"): if not request.url.path.startswith(f"{root_path}/v1"):
return await call_next(request) return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token: if request.headers.get("Authorization") != "Bearer " + token:
...@@ -164,16 +181,34 @@ if __name__ == "__main__": ...@@ -164,16 +181,34 @@ if __name__ == "__main__":
served_model_names = args.served_model_name served_model_names = args.served_model_name
else: else:
served_model_names = [args.model] served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args( engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER) engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
openai_serving_chat = OpenAIServingChat(engine, served_model_names,
event_loop: Optional[asyncio.AbstractEventLoop]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config = event_loop.run_until_complete(engine.get_model_config())
else:
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role, args.response_role,
args.lora_modules, args.lora_modules,
args.chat_template) args.chat_template)
openai_serving_completion = OpenAIServingCompletion( openai_serving_completion = OpenAIServingCompletion(
engine, served_model_names, args.lora_modules) engine, model_config, served_model_names, args.lora_modules)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
app.root_path = args.root_path app.root_path = args.root_path
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,
......
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time import time
from typing import Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
import openai.types.chat
import torch import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated # pydantic needs the TypedDict from typing_extensions
from typing_extensions import Annotated, Required, TypedDict
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
type: Required[str]
"""The type of the content part."""
ChatCompletionContentPartParam = Union[
openai.types.chat.ChatCompletionContentPartParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam = Union[
openai.types.chat.ChatCompletionMessageParam,
CustomChatCompletionMessageParam]
class OpenAIBaseModel(BaseModel): class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields # OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
...@@ -74,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -74,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None top_logprobs: Optional[int] = 0
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
n: Optional[int] = 1 n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
...@@ -157,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -157,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
def to_sampling_params(self) -> SamplingParams: def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs: # We now allow logprobs being true without top_logrobs.
raise ValueError("Top logprobs must be set when logprobs is.")
logits_processors = None logits_processors = None
if self.logit_bias: if self.logit_bias:
...@@ -216,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -216,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice').") "('guided_json', 'guided_regex' or 'guided_choice').")
return data return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif not 0 <= data["top_logprobs"] <= 20:
raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].")
return data
class CompletionRequest(OpenAIBaseModel): class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
...@@ -362,8 +409,35 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -362,8 +409,35 @@ class CompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice').") "('guided_json', 'guided_regex' or 'guided_choice').")
return data return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
raise ValueError(("if passed, `logprobs` must be a value",
" in the interval [0, 5]."))
return data
class LogProbs(OpenAIBaseModel): class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
dimensions: Optional[int] = None
user: Optional[str] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
...@@ -373,7 +447,7 @@ class LogProbs(OpenAIBaseModel): ...@@ -373,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
class CompletionResponseChoice(OpenAIBaseModel): class CompletionResponseChoice(OpenAIBaseModel):
index: int index: int
text: str text: str
logprobs: Optional[LogProbs] = None logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field( stop_reason: Optional[Union[int, str]] = Field(
default=None, default=None,
...@@ -396,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel): ...@@ -396,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int index: int
text: str text: str
logprobs: Optional[LogProbs] = None logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field( stop_reason: Optional[Union[int, str]] = Field(
default=None, default=None,
...@@ -416,16 +490,45 @@ class CompletionStreamResponse(OpenAIBaseModel): ...@@ -416,16 +490,45 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None) usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(BaseModel):
index: int
object: str = "embedding"
embedding: List[float]
class EmbeddingResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[EmbeddingResponseData]
usage: UsageInfo
class ChatMessage(OpenAIBaseModel): class ChatMessage(OpenAIBaseModel):
role: str role: str
content: str content: str
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None
class ChatCompletionResponseChoice(OpenAIBaseModel): class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int index: int
message: ChatMessage message: ChatMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None stop_reason: Optional[Union[int, str]] = None
...@@ -446,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel): ...@@ -446,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel): class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int index: int
delta: DeltaMessage delta: DeltaMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None stop_reason: Optional[Union[int, str]] = None
...@@ -458,3 +561,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): ...@@ -458,3 +561,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
model: str model: str
choices: List[ChatCompletionResponseStreamChoice] choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None) usage: Optional[UsageInfo] = Field(default=None)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameteters of the request.
body: Union[ChatCompletionRequest, ]
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: Optional[ChatCompletionResponse]
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]
import argparse
import asyncio
import sys
from io import StringIO
import aiohttp
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
ChatCompletionResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
logger = init_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help=
"The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.")
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, "r") as f:
return f.read()
async def write_file(path_or_url: str, data: str) -> None:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.put(path_or_url, data=data.encode("utf-8")):
pass
else:
# We should make this async, but as long as this is always run as a
# standalone program, blocking the event loop won't effect performance
# in this particular case.
with open(path_or_url, "w") as f:
f.write(data)
async def run_request(chat_serving: OpenAIServingChat,
request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request)
if isinstance(chat_response, ChatCompletionResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=chat_response,
error=None,
)
else:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=None,
error=chat_response,
)
return batch_output
async def main(args):
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config()
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
)
# Submit all requests in the file to the engine "concurrently".
response_futures = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request))
responses = await asyncio.gather(*response_futures)
output_buffer = StringIO()
for response in responses:
print(response.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
await write_file(args.output_file, output_buffer.read().strip())
# Temporary workaround for https://github.com/vllm-project/vllm/issues/4789
sys.exit(0)
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM API server version %s", vllm.__version__)
logger.info("args: %s", args)
asyncio.run(main(args))
import asyncio
import codecs import codecs
import time import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, from dataclasses import dataclass
Optional, Tuple, TypedDict, Union, final) from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
Optional)
from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final
from fastapi import Request from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam, from openai.types.chat import ChatCompletionContentPartTextParam
ChatCompletionRole)
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionContentPartParam, ChatCompletionLogProb,
ChatCompletionLogProbs, ChatCompletionLogProbsContent,
ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo) UsageInfo)
...@@ -20,6 +24,7 @@ from vllm.logger import init_logger ...@@ -20,6 +24,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,46 +36,96 @@ class ConversationMessage(TypedDict): ...@@ -31,46 +36,96 @@ class ConversationMessage(TypedDict):
content: str content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
class OpenAIServingChat(OpenAIServing): class OpenAIServingChat(OpenAIServing):
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
response_role: str, response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None, lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None): chat_template: Optional[str] = None):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules, lora_modules=lora_modules)
await_post_init=self._load_chat_template(
chat_template=chat_template))
self.response_role = response_role self.response_role = response_role
self._load_chat_template(chat_template)
def _parse_chat_message_content( def _load_chat_template(self, chat_template: Optional[str]):
self, tokenizer = self.tokenizer
role: ChatCompletionRole,
content: Optional[Union[str, if chat_template is not None:
Iterable[ChatCompletionContentPartParam]]], try:
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]: with open(chat_template, "r") as f:
if content is None: tokenizer.chat_template = f.read()
return [], [] except OSError as e:
if isinstance(content, str): JINJA_CHARS = "{}\n"
return [ConversationMessage(role=role, content=content)], [] if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning(
"No chat template provided. Chat API will not work.")
def _parse_chat_message_content_parts(
self,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
texts: List[str] = [] texts: List[str] = []
for _, part in enumerate(content):
if part["type"] == "text": for _, part in enumerate(parts):
text = part["text"] part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text) texts.append(text)
else: else:
raise NotImplementedError(f"Unknown part type: {part['type']}") raise NotImplementedError(f"Unknown part type: {part_type}")
messages = [ConversationMessage(role=role, content="\n".join(texts))]
return [ConversationMessage(role=role, content="\n".join(texts))], [] return ChatMessageParseResult(messages=messages)
def _parse_chat_message_content(
self,
message: ChatCompletionMessageParam,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages)
return self._parse_chat_message_content_parts(role, content)
async def create_chat_completion( async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, AsyncGenerator[str, None], ) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]: ChatCompletionResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
...@@ -89,11 +144,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -89,11 +144,10 @@ class OpenAIServingChat(OpenAIServing):
try: try:
conversation: List[ConversationMessage] = [] conversation: List[ConversationMessage] = []
for m in request.messages: for msg in request.messages:
messages, _ = self._parse_chat_message_content( parsed_msg = self._parse_chat_message_content(msg)
m["role"], m["content"])
conversation.extend(messages) conversation.extend(parsed_msg.messages)
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
...@@ -108,7 +162,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -108,7 +162,7 @@ class OpenAIServingChat(OpenAIServing):
try: try:
# Tokenize/detokenize depending on prompt format (string/token list) # Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize( prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt) request, prompt=prompt, add_special_tokens=False)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config() decoding_config = await self.engine.get_decoding_config()
...@@ -126,9 +180,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -126,9 +180,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params, result_generator = self.engine.generate(
request_id, prompt_ids, {
lora_request) "prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
...@@ -227,11 +287,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -227,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens[i]:] if output.logprobs else None previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs: if request.logprobs:
logprobs = self._create_logprobs( logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids, token_ids=delta_token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.top_logprobs,
initial_text_offset=len(previous_texts[i]),
) )
else: else:
logprobs = None logprobs = None
...@@ -289,7 +348,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -289,7 +348,7 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def chat_completion_full_generator( async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request, self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str, result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage] conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]: ) -> Union[ErrorResponse, ChatCompletionResponse]:
...@@ -299,7 +358,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -299,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
final_res: Optional[RequestOutput] = None final_res: Optional[RequestOutput] = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await self.engine.abort(request_id) await self.engine.abort(request_id)
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
...@@ -314,10 +373,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -314,10 +373,10 @@ class OpenAIServingChat(OpenAIServing):
top_logprobs = output.logprobs top_logprobs = output.logprobs
if request.logprobs: if request.logprobs:
logprobs = self._create_logprobs( logprobs = self._create_chat_logprobs(
token_ids=token_ids, token_ids=token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.top_logprobs,
) )
else: else:
logprobs = None logprobs = None
...@@ -327,8 +386,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -327,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
message=ChatMessage(role=role, content=output.text), message=ChatMessage(role=role, content=output.text),
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
stop_reason=output.stop_reason, stop_reason=output.stop_reason)
)
choices.append(choice_data) choices.append(choice_data)
if request.echo: if request.echo:
...@@ -359,34 +417,50 @@ class OpenAIServingChat(OpenAIServing): ...@@ -359,34 +417,50 @@ class OpenAIServingChat(OpenAIServing):
return response return response
async def _load_chat_template(self, chat_template: Optional[str]): def _get_top_logprobs(
while self.tokenizer is None: self, logprobs: Dict[int, Logprob],
# Give the parent class time to load the tokenizer top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
await asyncio.sleep(0.1) return [
tokenizer = self.tokenizer ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
if chat_template is not None: logprob=max(p[1].logprob, -9999.0),
try: bytes=list(
with open(chat_template, "r") as f: self._get_decoded_token(p[1],
tokenizer.chat_template = f.read() p[0]).encode("utf-8",
except OSError as e: errors="replace")))
JINJA_CHARS = "{}\n" for i, p in enumerate(logprobs.items())
if not any(c in chat_template for c in JINJA_CHARS): if top_logprobs and i < top_logprobs
msg = (f"The supplied chat template ({chat_template}) " ]
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}") def _create_chat_logprobs(
raise ValueError(msg) from e self,
token_ids: GenericSequence[int],
# If opening a file fails, set chat template to be args to top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
# ensure we decode so our escape are interpreted correctly num_output_top_logprobs: Optional[int] = None,
tokenizer.chat_template = codecs.decode( ) -> ChatCompletionLogProbs:
chat_template, "unicode_escape") """Create OpenAI-style logprobs."""
logger.info("Using supplied chat template:\n%s", logprobs_content = []
tokenizer.chat_template)
elif tokenizer.chat_template is not None: for i, token_id in enumerate(token_ids):
logger.info("Using default chat template:\n%s", step_top_logprobs = top_logprobs[i]
tokenizer.chat_template) if step_top_logprobs is None:
else: logprobs_content.append(
logger.warning( ChatCompletionLogProbsContent(
"No chat template provided. Chat API will not work.") token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token,
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
return ChatCompletionLogProbs(content=logprobs_content)
import time import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple) Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest, # yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest,
CompletionResponse, CompletionResponse,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
LogProbs, UsageInfo) UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,7 +31,7 @@ logger = init_logger(__name__) ...@@ -24,7 +31,7 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int] TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]] TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[ TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]: def parse_prompt_format(prompt) -> Tuple[bool, list]:
...@@ -52,11 +59,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: ...@@ -52,11 +59,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
class OpenAIServingCompletion(OpenAIServing): class OpenAIServingCompletion(OpenAIServing):
def __init__(self, def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
engine: AsyncLLMEngine,
served_model_names: List[str], served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None): lora_modules: Optional[List[LoRAModulePath]]):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules)
...@@ -118,12 +125,17 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -118,12 +125,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens) truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats prompt_ids, prompt_text = prompt_formats
generators.append( generator = self.engine.generate(
self.engine.generate(prompt_text, {
sampling_params, "prompt": prompt_text,
f"{request_id}-{i}", "prompt_token_ids": prompt_ids
prompt_token_ids=prompt_ids, },
lora_request=lora_request)) sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
)
generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -229,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -229,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
logprobs = self._create_logprobs( logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids, token_ids=delta_token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
...@@ -311,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -311,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert top_logprobs is not None, ( assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs " "top_logprobs must be provided when logprobs "
"is requested") "is requested")
logprobs = self._create_logprobs( logprobs = self._create_completion_logprobs(
token_ids=token_ids, token_ids=token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
...@@ -345,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -345,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
choices=choices, choices=choices,
usage=usage, usage=usage,
) )
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: List[int] = []
out_token_logprobs: List[Optional[float]] = []
out_tokens: List[str] = []
out_top_logprobs: List[Optional[Dict[str, float]]] = []
last_token_len = 0
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
})
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
import time
from typing import AsyncIterator, List, Optional, Tuple
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> EmbeddingResponse:
data = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
embedding_data = EmbeddingResponseData(
index=idx, embedding=final_res.outputs.embedding)
data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
class OpenAIServingEmbedding(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str]):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# Return error for unsupported features.
if request.encoding_format == "base64":
return self.create_error_response(
"base64 encoding is not currently supported")
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
try:
prompt_is_tokens, prompts = parse_prompt_format(request.input)
pooling_params = request.to_pooling_params()
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt)
prompt_ids, prompt_text = prompt_formats
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params,
f"{request_id}-{i}",
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def _check_embedding_mode(self, embedding_mode: bool):
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")
import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse, CompletionRequest,
LogProbs, ModelCard, ModelList, EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission) ModelPermission)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -29,13 +29,24 @@ class LoRAModulePath: ...@@ -29,13 +29,24 @@ class LoRAModulePath:
class OpenAIServing: class OpenAIServing:
def __init__(self, def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
engine: AsyncLLMEngine,
served_model_names: List[str], served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]]):
await_post_init: Optional[Awaitable[Any]] = None): super().__init__()
self.engine = engine self.engine = engine
self.max_model_len = model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
truncation_side="left")
self.served_model_names = served_model_names self.served_model_names = served_model_names
if lora_modules is None: if lora_modules is None:
self.lora_requests = [] self.lora_requests = []
else: else:
...@@ -47,38 +58,6 @@ class OpenAIServing: ...@@ -47,38 +58,6 @@ class OpenAIServing:
) for i, lora in enumerate(lora_modules, start=1) ) for i, lora in enumerate(lora_modules, start=1)
] ]
self.max_model_len = 0
# Lazy initialized
self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init(await_post_init))
else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init(await_post_init))
async def _post_init(self, await_post_init):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
tokenizer_revision=engine_model_config.tokenizer_revision,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
if await_post_init is not None:
await await_post_init
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
...@@ -96,50 +75,6 @@ class OpenAIServing: ...@@ -96,50 +75,6 @@ class OpenAIServing:
model_cards.extend(lora_cards) model_cards.extend(lora_cards)
return ModelList(data=model_cards) return ModelList(data=model_cards)
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else:
token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
return logprobs
def create_error_response( def create_error_response(
self, self,
message: str, message: str,
...@@ -163,7 +98,8 @@ class OpenAIServing: ...@@ -163,7 +98,8 @@ class OpenAIServing:
return json_str return json_str
async def _check_model( async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest] self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
...@@ -175,7 +111,8 @@ class OpenAIServing: ...@@ -175,7 +111,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora( def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest] self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]: ) -> Optional[LoRARequest]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
...@@ -186,12 +123,14 @@ class OpenAIServing: ...@@ -186,12 +123,14 @@ class OpenAIServing:
raise ValueError(f"The model `{request.model}` does not exist.") raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize( def _validate_prompt_and_tokenize(
self, self,
request: Union[ChatCompletionRequest, CompletionRequest], request: Union[ChatCompletionRequest, CompletionRequest,
prompt: Optional[str] = None, EmbeddingRequest],
prompt_ids: Optional[List[int]] = None, prompt: Optional[str] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None prompt_ids: Optional[List[int]] = None,
) -> Tuple[List[int], str]: truncate_prompt_tokens: Optional[Annotated[int,
Field(ge=1)]] = None,
add_special_tokens: bool = True) -> Tuple[List[int], str]:
if not (prompt or prompt_ids): if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.") raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids): if (prompt and prompt_ids):
...@@ -199,10 +138,19 @@ class OpenAIServing: ...@@ -199,10 +138,19 @@ class OpenAIServing:
"Only one of prompt or prompt_ids should be provided.") "Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None: if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else { # When using OpenAIServingChat for chat completions, the
"truncation": True, # special tokens (e.g., BOS) have already been added by the
"max_length": truncate_prompt_tokens, # chat template. Therefore, we do not need to add them again.
# Set add_special_tokens to False to avoid adding the BOS tokens
# again.
tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens
} }
if truncate_prompt_tokens is not None:
tokenizer_kwargs.update({
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None: elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:] input_ids = prompt_ids[-truncate_prompt_tokens:]
...@@ -213,6 +161,16 @@ class OpenAIServing: ...@@ -213,6 +161,16 @@ class OpenAIServing:
prompt_ids) prompt_ids)
token_num = len(input_ids) token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, EmbeddingRequest):
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", )
return input_ids, input_text
if request.max_tokens is None: if request.max_tokens is None:
if token_num >= self.max_model_len: if token_num >= self.max_model_len:
raise ValueError( raise ValueError(
...@@ -232,3 +190,8 @@ class OpenAIServing: ...@@ -232,3 +190,8 @@ class OpenAIServing:
f"Please reduce the length of the messages or completion.", ) f"Please reduce the length of the messages or completion.", )
else: else:
return input_ids, input_text return input_ids, input_text
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)
...@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional ...@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_HOST_IP: str = "" VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_USE_MODELSCOPE: bool = False VLLM_USE_MODELSCOPE: bool = False
VLLM_INSTANCE_ID: Optional[str] = None VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
...@@ -21,6 +22,7 @@ if TYPE_CHECKING: ...@@ -21,6 +22,7 @@ if TYPE_CHECKING:
VLLM_DO_NOT_TRACK: bool = False VLLM_DO_NOT_TRACK: bool = False
VLLM_USAGE_SOURCE: str = "" VLLM_USAGE_SOURCE: str = ""
VLLM_CONFIGURE_LOGGING: int = 1 VLLM_CONFIGURE_LOGGING: int = 1
VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
...@@ -96,6 +98,12 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -96,6 +98,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
'VLLM_HOST_IP': 'VLLM_HOST_IP':
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
# used in distributed environment to manually set the communication port
# '0' is used to make mypy happy
'VLLM_PORT':
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
# If true, will load models from ModelScope instead of Hugging Face Hub. # If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers # note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE": "VLLM_USE_MODELSCOPE":
...@@ -145,7 +153,7 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -145,7 +153,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# S3 access information, used for tensorizer to load model from S3 # S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID": "S3_ACCESS_KEY_ID":
lambda: os.environ.get("S3_ACCESS_KEY", None), lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
"S3_SECRET_ACCESS_KEY": "S3_SECRET_ACCESS_KEY":
lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
"S3_ENDPOINT_URL": "S3_ENDPOINT_URL":
...@@ -171,6 +179,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -171,6 +179,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_CONFIG_PATH": "VLLM_LOGGING_CONFIG_PATH":
lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),
# this is used for configuring the default logging level
"VLLM_LOGGING_LEVEL":
lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"),
# Trace function calls # Trace function calls
# If set to 1, vllm will trace function calls # If set to 1, vllm will trace function calls
# Useful for debugging # Useful for debugging
......
import asyncio
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -13,6 +14,16 @@ logger = init_logger(__name__) ...@@ -13,6 +14,16 @@ logger = init_logger(__name__)
class DistributedGPUExecutor(GPUExecutor): class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations.""" """Abstract superclass of multi-GPU executor implementations."""
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
super().__init__(*args, **kwargs)
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks. """Determine the number of available KV blocks.
...@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks) num_cpu_blocks=num_cpu_blocks)
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: def execute_model(
all_outputs = self._run_workers("execute_model", self,
driver_args=args, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
driver_kwargs=kwargs) if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
return all_outputs[0] return self._driver_execute_model(execute_model_req)
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
...@@ -77,39 +103,95 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -77,39 +103,95 @@ class DistributedGPUExecutor(GPUExecutor):
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
return self._run_workers("list_loras") return self._run_workers("list_loras")
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)
@abstractmethod
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[Tuple[Any, ...]] = None, async_run_remote_workers_only: bool = False,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
raise NotImplementedError
@abstractmethod
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
raise NotImplementedError raise NotImplementedError
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
@abstractmethod @abstractmethod
async def _run_workers_async( async def _driver_execute_model_async(
self, self,
method: str, execute_model_req: Optional[ExecuteModelRequest] = None
*args, ) -> List[SamplerOutput]:
driver_args: Optional[Tuple[Any, ...]] = None, """Execute the model asynchronously in the driver worker.
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
raise NotImplementedError
async def execute_model_async(self, *args, Passing None will cause the driver to stop the model execution
**kwargs) -> List[SamplerOutput]: loop running in each of the remote workers.
all_outputs = await self._run_workers_async("execute_model", """
driver_args=args, raise NotImplementedError
driver_kwargs=kwargs)
# Only the driver worker returns the sampling results. @abstractmethod
return all_outputs[0] async def _start_worker_execution_loop(self):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise NotImplementedError
...@@ -74,6 +74,10 @@ class ExecutorBase(ABC): ...@@ -74,6 +74,10 @@ class ExecutorBase(ABC):
"""Executes at least one model step on the given sequences.""" """Executes at least one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return
@abstractmethod @abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError raise NotImplementedError
...@@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase): ...@@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase):
"""Executes one model step on the given sequences.""" """Executes one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
return
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an """Checks if the executor is healthy. If not, it should raise an
exception.""" exception."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment