"vllm/vscode:/vscode.git/clone" did not exist on "c64861d63c1a5362bfad443daf7a096f1bcfd1e4"
Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
......@@ -2,19 +2,9 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence
from typing import Sequence
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
from .parallel_state import get_cpu_world_group, get_local_rank
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator):
......@@ -56,81 +46,3 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
dev_i = f"cuda:{idx_a}"
dev_j = f"cuda:{idx_b}"
a = torch.randn(5, device=dev_i) + 123.0
b = a.to(dev_j)
c = b.to(dev_i)
return torch.all(a == c).cpu().item()
# why do we need this cache?
# 1. we can have runtime checks for P2P access, where every process checks
# P2P access to all other GPUs. Unfortunately, the test might cost many
# (world_size * world_size) cuda context, and reduce the memory available
# for the model. see https://github.com/vllm-project/vllm/issues/3821
# 2. alternatively, we can have a p2p map that is generated by the master
# process and broadcasted to all other processes. This still requires
# #world_size of cuda context, belonging to the master process, on each GPU.
# 3. we can have a cache file, that records the p2p access status. The first
# time the master process checks the p2p access, it will generate the cache
# file, at the cost of #world_size of cuda context. Later on, all processes
# can read the cache file to check the p2p access status without any cost of
# additional cuda context.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(i: int, j: int) -> bool:
"""Check if GPU i can access GPU j."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{i}->{j}"]
is_distributed = dist.is_initialized()
num_dev = torch.cuda.device_count()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
path = os.path.expanduser(
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
if (not is_distributed or get_local_rank() == 0) \
and (not os.path.exists(path)):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache for in %s", path)
cache = {}
for _i in range(num_dev):
for _j in range(num_dev):
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
_i, _j) and _can_actually_p2p(_i, _j)
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
cpu_world_group = get_cpu_world_group()
dist.barrier(cpu_world_group)
logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{i}->{j}"]
import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
......@@ -34,11 +35,13 @@ class EngineArgs:
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
......@@ -48,6 +51,7 @@ class EngineArgs:
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
......@@ -62,6 +66,7 @@ class EngineArgs:
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
......@@ -83,6 +88,7 @@ class EngineArgs:
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
......@@ -166,8 +172,8 @@ class EngineArgs:
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave which assumes tensorizer_uri is set to the location of '
'the serialized weights.')
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.\n')
parser.add_argument(
'--dtype',
type=str,
......@@ -186,12 +192,11 @@ class EngineArgs:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
......@@ -220,10 +225,17 @@ class EngineArgs:
' Can be overridden per request via guided_decoding_backend'
' parameter.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--worker-use-ray',
action='store_true',
help='Deprecated, use --distributed-executor-backend=ray.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
......@@ -256,6 +268,10 @@ class EngineArgs:
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching.')
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2.')
......@@ -320,6 +336,11 @@ class EngineArgs:
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument('--rope-scaling',
default=None,
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
......@@ -331,9 +352,9 @@ class EngineArgs:
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
'(DEPRECATED. Use --max-seq-len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
parser.add_argument('--max-seq-len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
......@@ -388,6 +409,17 @@ class EngineArgs:
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=nullable_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
......@@ -467,6 +499,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
......@@ -508,7 +547,7 @@ class EngineArgs:
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
def from_cli_args(cls, args: argparse.Namespace):
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
......@@ -520,10 +559,11 @@ class EngineArgs:
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.dtype, self.seed, self.revision,
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_seq_len_to_capture, self.max_logprobs,
self.code_revision, self.rope_scaling, self.tokenizer_revision,
self.max_model_len, self.quantization,
self.quantization_param_path, self.enforce_eager,
self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.max_logprobs, self.disable_sliding_window,
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
......@@ -532,14 +572,18 @@ class EngineArgs:
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(
self.pipeline_parallel_size, self.tensor_parallel_size,
self.worker_use_ray, self.max_parallel_loading_workers,
self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,
self.max_parallel_loading_workers,
self.disable_custom_all_reduce,
TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
),
self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend)
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
......@@ -547,6 +591,8 @@ class EngineArgs:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
......@@ -564,12 +610,14 @@ class EngineArgs:
speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
......@@ -599,6 +647,13 @@ class EngineArgs:
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled
and not scheduler_config.use_v2_block_manager):
raise ValueError(
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
......
import asyncio
import time
from functools import partial
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union)
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
......@@ -12,11 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
......@@ -47,15 +49,16 @@ def _raise_exception_on_finish(
class AsyncStream:
"""A stream of RequestOutputs for a request that can be
iterated over asynchronously."""
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
......@@ -71,7 +74,7 @@ class AsyncStream:
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
result = await self._queue.get()
if isinstance(result, Exception):
raise result
......@@ -108,7 +111,8 @@ class RequestTracker:
self.abort_request(rid)
def process_request_output(self,
request_output: RequestOutput,
request_output: Union[RequestOutput,
EmbeddingRequestOutput],
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
......@@ -196,7 +200,8 @@ class RequestTracker:
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]:
async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
......@@ -230,66 +235,77 @@ class _AsyncLLMEngine(LLMEngine):
# Log stats.
self.do_log_stats(scheduler_outputs, output)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs
async def encode_request_async(
async def process_model_inputs_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
prompt=inputs["prompt"],
lora_request=lora_request)
return prompt_token_ids
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
async def check_health_async(self) -> None:
self.model_executor.check_health()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
"""An asynchronous wrapper for :class:`LLMEngine`.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The :class:`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
......@@ -303,8 +319,8 @@ class AsyncLLMEngine:
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for LLMEngine.
*kwargs: Arguments for LLMEngine.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
......@@ -327,7 +343,7 @@ class AsyncLLMEngine:
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self._background_loop_unshielded: Optional[asyncio.Task] = None
self.start_engine_loop = start_engine_loop
self._errored_with: Optional[BaseException] = None
......@@ -344,27 +360,31 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend.")
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
distributed_executor_backend == "ray",
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
......@@ -510,27 +530,31 @@ class AsyncLLMEngine:
async def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"sampling_params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt,
sampling_params, shortened_token_ids, lora_request)
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
......@@ -546,39 +570,33 @@ class AsyncLLMEngine:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
processed_inputs = await self.engine.process_model_inputs_async \
.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
inputs=inputs,
lora_request=lora_request)
else:
prompt_token_ids = await self.engine.encode_request_async(
processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
inputs=inputs,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
return stream
async def generate(
self,
prompt: Optional[str],
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
......@@ -587,18 +605,16 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
request.
The output `RequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
......@@ -643,25 +659,112 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
# Preprocess the request.
arrival_time = time.time()
try:
stream = await self.add_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
):
yield LLMEngine.validate_output(output, RequestOutput)
async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>> example_input["input"],
>>> PoolingParams(),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
async for output in self._process_request(
request_id,
inputs,
pooling_params,
lora_request=lora_request,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def _process_request(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time = time.time()
stream = await self.add_request(
request_id,
inputs,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
try:
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id)
raise e
......
import time
from typing import Iterable, List, Optional, Type, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer
......@@ -18,12 +21,16 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata,
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
......@@ -47,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return {}
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
......@@ -57,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
The :class:`~vllm.LLM` class wraps this class for offline batched inference
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
The config arguments are derived from :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Args:
model_config: The configuration related to the LLM model.
......@@ -78,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
usage_context: Specified entry point, used for usage info collection.
"""
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
"""A flag to toggle whether to validate the type of request output."""
@classmethod
@contextmanager
def enable_output_validation(cls):
cls.DO_VALIDATE_OUTPUT = True
yield
cls.DO_VALIDATE_OUTPUT = False
@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT
if ((TYPE_CHECKING or do_validate)
and not isinstance(output, output_type)):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT
outputs_: List[_O]
if TYPE_CHECKING or do_validate:
outputs_ = []
for output in outputs:
if not isinstance(output, output_type):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
outputs_.append(output)
else:
outputs_ = outputs
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
def __init__(
self,
model_config: ModelConfig,
......@@ -101,10 +162,11 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"rope_scaling=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__,
......@@ -114,6 +176,7 @@ class LLMEngine:
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.rope_scaling,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
......@@ -146,12 +209,11 @@ class LLMEngine:
self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
......@@ -169,7 +231,8 @@ class LLMEngine:
load_config=load_config,
)
self._initialize_kv_caches()
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
......@@ -270,6 +333,8 @@ class LLMEngine:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
......@@ -278,13 +343,15 @@ class LLMEngine:
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
executor_class = MultiprocessingGPUExecutor
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
......@@ -308,14 +375,26 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs):
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
......@@ -325,8 +404,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -336,29 +416,85 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def encode_request(
def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _add_processed_request(
self,
request_id: str,
processed_inputs: LLMInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def process_model_inputs(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
"""Add a request to the engine's request pool.
......@@ -368,14 +504,14 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
multi_modal_data: Multi modal data per request.
Details:
- Set arrival_time to the current time if it is None.
......@@ -404,6 +540,30 @@ class LLMEngine:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
......@@ -411,26 +571,6 @@ class LLMEngine:
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
......@@ -443,11 +583,32 @@ class LLMEngine:
self.generation_config_fields)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, multi_modal_data)
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
return seq_group
def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
......@@ -484,13 +645,25 @@ class LLMEngine:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _process_model_outputs(
self,
output: List[SamplerOutput],
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
......@@ -501,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
......@@ -510,6 +683,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
......@@ -519,18 +695,19 @@ class LLMEngine:
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
return request_outputs
def step(self) -> List[RequestOutput]:
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
......@@ -570,7 +747,7 @@ class LLMEngine:
>>> while True:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params)
>>> engine.add_request(str(req_id),prompt,sampling_params)
>>>
>>> # continue the request processing
>>> request_outputs = engine.step()
......@@ -604,6 +781,14 @@ class LLMEngine:
# Log stats.
self.do_log_stats(scheduler_outputs, output)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()
return request_outputs
def do_log_stats(
......@@ -637,12 +822,15 @@ class LLMEngine:
# KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
gpu_cache_usage_sys = 0.
if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0.
if num_total_cpu > 0:
if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
......@@ -652,6 +840,8 @@ class LLMEngine:
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
num_preemption_iter = (0 if scheduler_outputs is None else
scheduler_outputs.preempted)
# Request stats
# Latency
......@@ -716,8 +906,10 @@ class LLMEngine:
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
])
best_of_requests.append(seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
if seq_group.sampling_params is not None:
best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs()
......@@ -743,7 +935,6 @@ class LLMEngine:
return Stats(
now=now,
# System stats
# Scheduler State
num_running_sys=num_running_sys,
......@@ -759,6 +950,7 @@ class LLMEngine:
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
num_preemption_iter=num_preemption_iter,
# Request stats
# Latency
......
......@@ -61,6 +61,10 @@ class Metrics:
labelnames=labelnames)
# Iteration stats
self.counter_num_preemption = Counter(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
......@@ -181,6 +185,7 @@ class Stats:
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
num_preemption_iter: int
# Request stats (should have _requests suffix)
# Latency
......@@ -244,6 +249,8 @@ class StatLogger:
stats.cpu_cache_usage_sys)
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
......@@ -336,7 +343,7 @@ class StatLogger:
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%",
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
......
......@@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params)
sampling_params=sampling_params,
)
if seq.is_finished():
break
......
......@@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq, seq_group.sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(seq, new_char_count,
seq_group.sampling_params)
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
seq_group.sampling_params,
lora_req=seq_group.lora_request,
)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
......
......@@ -2,6 +2,7 @@ from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
......@@ -16,11 +17,23 @@ class StopChecker:
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
self.max_model_len = max_model_len
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
else:
return self._max_model_len
def maybe_stop_sequence(
self,
seq: Sequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
......@@ -35,6 +48,11 @@ class StopChecker:
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
......@@ -59,7 +77,7 @@ class StopChecker:
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.max_model_len:
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
......
from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import SamplerOutput, SequenceGroupOutput
from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput],
outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group: List[List[SamplerOutput]] = [
output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for step in outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
......
from typing import List, Optional, Union
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
from vllm.utils import Counter, deprecate_kwargs
logger = init_logger(__name__)
class LLM:
......@@ -23,10 +30,6 @@ class LLM:
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
......@@ -75,8 +78,26 @@ class LLM:
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
def __init__(
self,
model: str,
......@@ -134,126 +155,415 @@ class LLM:
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer
@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def generate(
self,
prompts: List[str],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
if prompts is not None:
num_requests = len(prompts)
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
elif isinstance(sampling_params,
list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and sampling_params "
self._validate_and_add_requests(
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def encode(
self,
prompts: List[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
multi_modal_data: Optional[MultiModalData],
):
# skip_tokenizer_init is now checked in engine
if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]
num_requests = None
if prompts is not None:
num_requests = len(prompts)
if prompt_token_ids is not None:
if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item)
return inputs
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
for i, request_inputs in enumerate(inputs):
self._add_request(
prompt,
sampling_params[i]
if isinstance(sampling_params, list) else sampling_params,
token_ids,
request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
)
return self._run_engine(use_tqdm)
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
inputs,
params,
lora_request=lora_request)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests,
desc="Processed prompts",
dynamic_ncols=True)
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=f"Generation Speed: {0:.2f} toks/s",
)
# Run the engine.
outputs: List[RequestOutput] = []
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
total_toks += sum(
len(stp.token_ids) for stp in output.outputs)
spd = total_toks / pbar.format_dict["elapsed"]
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
\ No newline at end of file
return sorted(outputs, key=lambda x: int(x.request_id))
......@@ -4,7 +4,7 @@ import inspect
import re
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any, Set
from typing import Optional, Set
import fastapi
import uvicorn
......@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse)
CompletionRequest,
EmbeddingRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
......@@ -32,9 +34,11 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
logger = init_logger(__name__)
_running_tasks: Set[asyncio.Task[Any]] = set()
_running_tasks: Set[asyncio.Task] = set()
@asynccontextmanager
......@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
args = parse_args()
......@@ -139,6 +154,8 @@ if __name__ == "__main__":
@app.middleware("http")
async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path
if request.method == "OPTIONS":
return await call_next(request)
if not request.url.path.startswith(f"{root_path}/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
......@@ -164,16 +181,34 @@ if __name__ == "__main__":
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
openai_serving_chat = OpenAIServingChat(engine, served_model_names,
event_loop: Optional[asyncio.AbstractEventLoop]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config = event_loop.run_until_complete(engine.get_model_config())
else:
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model_names, args.lora_modules)
engine, model_config, served_model_names, args.lora_modules)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
......
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
import openai.types.chat
import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Annotated, Required, TypedDict
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
type: Required[str]
"""The type of the content part."""
ChatCompletionContentPartParam = Union[
openai.types.chat.ChatCompletionContentPartParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam = Union[
openai.types.chat.ChatCompletionMessageParam,
CustomChatCompletionMessageParam]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
......@@ -74,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
top_logprobs: Optional[int] = 0
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
......@@ -157,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
# We now allow logprobs being true without top_logrobs.
logits_processors = None
if self.logit_bias:
......@@ -216,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif not 0 <= data["top_logprobs"] <= 20:
raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......@@ -362,8 +409,35 @@ class CompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
raise ValueError(("if passed, `logprobs` must be a value",
" in the interval [0, 5]."))
return data
class LogProbs(OpenAIBaseModel):
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
dimensions: Optional[int] = None
user: Optional[str] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
......@@ -373,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
......@@ -396,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
......@@ -416,16 +490,45 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(BaseModel):
index: int
object: str = "embedding"
embedding: List[float]
class EmbeddingResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[EmbeddingResponseData]
usage: UsageInfo
class ChatMessage(OpenAIBaseModel):
role: str
content: str
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None
......@@ -446,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None
......@@ -458,3 +561,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameteters of the request.
body: Union[ChatCompletionRequest, ]
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: Optional[ChatCompletionResponse]
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]
import argparse
import asyncio
import sys
from io import StringIO
import aiohttp
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
ChatCompletionResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
logger = init_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help=
"The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.")
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, "r") as f:
return f.read()
async def write_file(path_or_url: str, data: str) -> None:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.put(path_or_url, data=data.encode("utf-8")):
pass
else:
# We should make this async, but as long as this is always run as a
# standalone program, blocking the event loop won't effect performance
# in this particular case.
with open(path_or_url, "w") as f:
f.write(data)
async def run_request(chat_serving: OpenAIServingChat,
request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request)
if isinstance(chat_response, ChatCompletionResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=chat_response,
error=None,
)
else:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=None,
error=chat_response,
)
return batch_output
async def main(args):
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config()
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
)
# Submit all requests in the file to the engine "concurrently".
response_futures = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request))
responses = await asyncio.gather(*response_futures)
output_buffer = StringIO()
for response in responses:
print(response.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
await write_file(args.output_file, output_buffer.read().strip())
# Temporary workaround for https://github.com/vllm-project/vllm/issues/4789
sys.exit(0)
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM API server version %s", vllm.__version__)
logger.info("args: %s", args)
asyncio.run(main(args))
import asyncio
import codecs
import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
from dataclasses import dataclass
from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
Optional)
from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final
from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)
from openai.types.chat import ChatCompletionContentPartTextParam
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionContentPartParam, ChatCompletionLogProb,
ChatCompletionLogProbs, ChatCompletionLogProbsContent,
ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
......@@ -20,6 +24,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import random_uuid
logger = init_logger(__name__)
......@@ -31,46 +36,96 @@ class ConversationMessage(TypedDict):
content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
await_post_init=self._load_chat_template(
chat_template=chat_template))
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)
def _parse_chat_message_content(
self,
role: ChatCompletionRole,
content: Optional[Union[str,
Iterable[ChatCompletionContentPartParam]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
if content is None:
return [], []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []
def _load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer
if chat_template is not None:
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning(
"No chat template provided. Chat API will not work.")
def _parse_chat_message_content_parts(
self,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
texts: List[str] = []
for _, part in enumerate(content):
if part["type"] == "text":
text = part["text"]
for _, part in enumerate(parts):
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part['type']}")
raise NotImplementedError(f"Unknown part type: {part_type}")
messages = [ConversationMessage(role=role, content="\n".join(texts))]
return [ConversationMessage(role=role, content="\n".join(texts))], []
return ChatMessageParseResult(messages=messages)
def _parse_chat_message_content(
self,
message: ChatCompletionMessageParam,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages)
return self._parse_chat_message_content_parts(role, content)
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
......@@ -89,11 +144,10 @@ class OpenAIServingChat(OpenAIServing):
try:
conversation: List[ConversationMessage] = []
for m in request.messages:
messages, _ = self._parse_chat_message_content(
m["role"], m["content"])
for msg in request.messages:
parsed_msg = self._parse_chat_message_content(msg)
conversation.extend(messages)
conversation.extend(parsed_msg.messages)
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
......@@ -108,7 +162,7 @@ class OpenAIServingChat(OpenAIServing):
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt)
request, prompt=prompt, add_special_tokens=False)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
......@@ -126,9 +180,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, prompt_ids,
lora_request)
result_generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
......@@ -227,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs:
logprobs = self._create_logprobs(
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None
......@@ -289,7 +348,7 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]:
......@@ -299,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
......@@ -314,10 +373,10 @@ class OpenAIServingChat(OpenAIServing):
top_logprobs = output.logprobs
if request.logprobs:
logprobs = self._create_logprobs(
logprobs = self._create_chat_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None
......@@ -327,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
stop_reason=output.stop_reason)
choices.append(choice_data)
if request.echo:
......@@ -359,34 +417,50 @@ class OpenAIServingChat(OpenAIServing):
return response
async def _load_chat_template(self, chat_template: Optional[str]):
while self.tokenizer is None:
# Give the parent class time to load the tokenizer
await asyncio.sleep(0.1)
tokenizer = self.tokenizer
if chat_template is not None:
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning(
"No chat template provided. Chat API will not work.")
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token,
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
return ChatCompletionLogProbs(content=logprobs_content)
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
......@@ -24,7 +31,7 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
......@@ -52,11 +59,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None):
lora_modules: Optional[List[LoRAModulePath]]):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
......@@ -118,12 +125,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -229,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None
if request.logprobs is not None:
logprobs = self._create_logprobs(
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
......@@ -311,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs(
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
......@@ -345,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
choices=choices,
usage=usage,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: List[int] = []
out_token_logprobs: List[Optional[float]] = []
out_tokens: List[str] = []
out_top_logprobs: List[Optional[Dict[str, float]]] = []
last_token_len = 0
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
})
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
import time
from typing import AsyncIterator, List, Optional, Tuple
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> EmbeddingResponse:
data = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
embedding_data = EmbeddingResponseData(
index=idx, embedding=final_res.outputs.embedding)
data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
class OpenAIServingEmbedding(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str]):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# Return error for unsupported features.
if request.encoding_format == "base64":
return self.create_error_response(
"base64 encoding is not currently supported")
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
try:
prompt_is_tokens, prompts = parse_prompt_format(request.input)
pooling_params = request.to_pooling_params()
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt)
prompt_ids, prompt_text = prompt_formats
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params,
f"{request_id}-{i}",
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def _check_embedding_mode(self, embedding_mode: bool):
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")
import asyncio
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse,
LogProbs, ModelCard, ModelList,
CompletionRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -29,13 +29,24 @@ class LoRAModulePath:
class OpenAIServing:
def __init__(self,
engine: AsyncLLMEngine,
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]],
await_post_init: Optional[Awaitable[Any]] = None):
lora_modules: Optional[List[LoRAModulePath]]):
super().__init__()
self.engine = engine
self.max_model_len = model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
truncation_side="left")
self.served_model_names = served_model_names
if lora_modules is None:
self.lora_requests = []
else:
......@@ -47,38 +58,6 @@ class OpenAIServing:
) for i, lora in enumerate(lora_modules, start=1)
]
self.max_model_len = 0
# Lazy initialized
self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init(await_post_init))
else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init(await_post_init))
async def _post_init(self, await_post_init):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
tokenizer_revision=engine_model_config.tokenizer_revision,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
if await_post_init is not None:
await await_post_init
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
......@@ -96,50 +75,6 @@ class OpenAIServing:
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else:
token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
return logprobs
def create_error_response(
self,
message: str,
......@@ -163,7 +98,8 @@ class OpenAIServing:
return json_str
async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
......@@ -175,7 +111,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None
......@@ -186,12 +123,14 @@ class OpenAIServing:
raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> Tuple[List[int], str]:
self,
request: Union[ChatCompletionRequest, CompletionRequest,
EmbeddingRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
Field(ge=1)]] = None,
add_special_tokens: bool = True) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
......@@ -199,10 +138,19 @@ class OpenAIServing:
"Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
# When using OpenAIServingChat for chat completions, the
# special tokens (e.g., BOS) have already been added by the
# chat template. Therefore, we do not need to add them again.
# Set add_special_tokens to False to avoid adding the BOS tokens
# again.
tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens
}
if truncate_prompt_tokens is not None:
tokenizer_kwargs.update({
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
......@@ -213,6 +161,16 @@ class OpenAIServing:
prompt_ids)
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, EmbeddingRequest):
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", )
return input_ids, input_text
if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
......@@ -232,3 +190,8 @@ class OpenAIServing:
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, input_text
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)
......@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_USE_MODELSCOPE: bool = False
VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None
......@@ -21,6 +22,7 @@ if TYPE_CHECKING:
VLLM_DO_NOT_TRACK: bool = False
VLLM_USAGE_SOURCE: str = ""
VLLM_CONFIGURE_LOGGING: int = 1
VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
......@@ -96,6 +98,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
'VLLM_HOST_IP':
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
# used in distributed environment to manually set the communication port
# '0' is used to make mypy happy
'VLLM_PORT':
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE":
......@@ -145,7 +153,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID":
lambda: os.environ.get("S3_ACCESS_KEY", None),
lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
"S3_SECRET_ACCESS_KEY":
lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
"S3_ENDPOINT_URL":
......@@ -171,6 +179,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_LOGGING_CONFIG_PATH":
lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),
# this is used for configuring the default logging level
"VLLM_LOGGING_LEVEL":
lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"),
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
......
import asyncio
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
logger = init_logger(__name__)
......@@ -13,6 +14,16 @@ logger = init_logger(__name__)
class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations."""
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
super().__init__(*args, **kwargs)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
......@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
all_outputs = self._run_workers("execute_model",
driver_args=args,
driver_kwargs=kwargs)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
return all_outputs[0]
return self._driver_execute_model(execute_model_req)
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
......@@ -77,39 +103,95 @@ class DistributedGPUExecutor(GPUExecutor):
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)
@abstractmethod
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError
@abstractmethod
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
raise NotImplementedError
@abstractmethod
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
raise NotImplementedError
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())
# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)
async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return
await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
@abstractmethod
async def _run_workers_async(
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
raise NotImplementedError
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
async def execute_model_async(self, *args,
**kwargs) -> List[SamplerOutput]:
all_outputs = await self._run_workers_async("execute_model",
driver_args=args,
driver_kwargs=kwargs)
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
raise NotImplementedError
# Only the driver worker returns the sampling results.
return all_outputs[0]
@abstractmethod
async def _start_worker_execution_loop(self):
"""Run execution loop on all workers. It guarantees all workers run
the loop or None of them is running the loop. Loop can be stopped by
`stop_remote_worker_execution_loop`.
The API is idempotent (guarantee only 1 loop run at any moment)."""
raise NotImplementedError
......@@ -74,6 +74,10 @@ class ExecutorBase(ABC):
"""Executes at least one model step on the given sequences."""
raise NotImplementedError
def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
......@@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase):
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
return
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment