"vscode:/vscode.git/clone" did not exist on "07458a51ce8f31a2be0cc9da69d3e3ef6fb0f16d"
Commit 7c4f76e3 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.0

parents 2da0dd3e 51c31bc1
import copy
from collections import defaultdict
import os
import time import time
import pickle from typing import Iterable, List, Optional, Tuple, Type, Union
import importlib
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
from vllm.lora.request import LoRARequest from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig) import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader import get_architecture_class_name
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroupOutput, SequenceOutput, SequenceStatus) SequenceGroup, SequenceGroupOutput, SequenceOutput,
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, SequenceStatus)
TokenizerGroup) from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import (Counter, set_cuda_visible_devices, get_ip, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_open_port, get_distributed_init_method) get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
if ray: usage_message)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.utils import Counter
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
...@@ -68,9 +53,10 @@ class LLMEngine: ...@@ -68,9 +53,10 @@ class LLMEngine:
parallel_config: The configuration related to distributed execution. parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device. device_config: The configuration related to the device.
placement_group: Ray placement group for distributed execution. executor_class: The model executor class for managing distributed
Required for distributed execution. execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
""" """
def __init__( def __init__(
...@@ -81,11 +67,13 @@ class LLMEngine: ...@@ -81,11 +67,13 @@ class LLMEngine:
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
placement_group: Optional["PlacementGroup"], vision_language_config: Optional["VisionLanguageConfig"],
executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine with config: " f"Initializing an LLM engine (v{vllm.__version__}) with config: "
f"model={model_config.model!r}, " f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, " f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, " f"tokenizer_mode={model_config.tokenizer_mode}, "
...@@ -97,7 +85,8 @@ class LLMEngine: ...@@ -97,7 +85,8 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, "
f"quantization={model_config.quantization}, " f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, " f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, " f"kv_cache_dtype={cache_config.cache_dtype}, "
...@@ -108,6 +97,7 @@ class LLMEngine: ...@@ -108,6 +97,7 @@ class LLMEngine:
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
...@@ -115,22 +105,54 @@ class LLMEngine: ...@@ -115,22 +105,54 @@ class LLMEngine:
self._verify_args() self._verify_args()
self._init_tokenizer() self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. self.model_executor = executor_class(model_config, cache_config,
if self.parallel_config.worker_use_ray: parallel_config, scheduler_config,
# Disable Ray usage stats collection. device_config, lora_config,
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") vision_language_config)
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0" # If usage stat is enabled, collect relevant info.
self._init_workers_ray(placement_group) if is_usage_stats_enabled():
else: usage_message.report_usage(
self._init_workers() get_architecture_class_name(model_config),
usage_context,
# Profile the memory usage and initialize the cache. extra_kvs={
self._init_cache() # Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging. # Metric Logging.
...@@ -140,48 +162,56 @@ class LLMEngine: ...@@ -140,48 +162,56 @@ class LLMEngine:
labels=dict(model_name=model_config.model)) labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config) self.stat_logger.info("cache_config", self.cache_config)
self.forward_dag = None @classmethod
if USE_RAY_COMPILED_DAG: def from_engine_args(
self.forward_dag = self._compiled_ray_dag() cls,
engine_args: EngineArgs,
def get_tokenizer_for_seq(self, sequence: Sequence): usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
return self.tokenizer.get_lora_tokenizer(sequence.lora_request) ) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
device_config = engine_configs[4]
# Initialize the cluster and specify the executor class.
if device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif parallel_config.worker_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
else:
assert parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
def _dispatch_worker(self): # Create the LLM engine.
worker_module = DEVICE_TO_WORKER_MODULE_MAP[ engine = cls(
self.device_config.device_type] *engine_configs,
imported_worker = importlib.import_module(worker_module) executor_class=executor_class,
Worker = imported_worker.Worker log_stats=not engine_args.disable_log_stats,
return Worker usage_context=usage_context,
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
self.workers: List[Worker] = []
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
) )
self._run_workers("init_model") return engine
self._run_workers("load_model")
def __reduce__(self):
# This is to ensure that the LLMEngine is not referenced in
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs): def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict( init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config), enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs, max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None, max_input_length=None,
...@@ -189,126 +219,8 @@ class LLMEngine: ...@@ -189,126 +219,8 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision) revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs) init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup( self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
self.model_config.tokenizer, **init_kwargs) self.parallel_config.tokenizer_pool_config, **init_kwargs)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
num_gpus = self.cache_config.gpu_memory_utilization
else:
num_gpus = 1
self.driver_dummy_worker: RayWorkerVllm = None
self.workers: List[RayWorkerVllm] = []
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
else:
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
node_workers[driver_node_id].append(0)
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver.
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()
# Initialize torch distributed process group for the workers.
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
for rank, (worker, (node_id,
_)) in enumerate(zip(self.workers,
worker_node_and_gpu_ids),
start=1):
local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote(
lambda rank=rank, local_rank=local_rank: Worker(
model_config,
parallel_config,
scheduler_config,
device_config,
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
))
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
model_config,
parallel_config,
scheduler_config,
device_config,
driver_local_rank,
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
# don't use cupy for eager mode
self._run_workers("init_model",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
...@@ -318,81 +230,6 @@ class LLMEngine: ...@@ -318,81 +230,6 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameters.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_gpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self._run_workers("warm_up_model")
@classmethod
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
placement_group = initialize_cluster(parallel_config)
# Create the LLM engine.
engine = cls(*engine_configs,
placement_group,
log_stats=not engine_args.disable_log_stats)
return engine
def encode_request( def encode_request(
self, self,
request_id: str, # pylint: disable=unused-argument request_id: str, # pylint: disable=unused-argument
...@@ -415,7 +252,7 @@ class LLMEngine: ...@@ -415,7 +252,7 @@ class LLMEngine:
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -432,11 +269,7 @@ class LLMEngine: ...@@ -432,11 +269,7 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix multi_modal_data: Multi modal data per request.
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
...@@ -465,8 +298,15 @@ class LLMEngine: ...@@ -465,8 +298,15 @@ class LLMEngine:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None: if arrival_time is None:
arrival_time = time.monotonic() arrival_time = time.time()
prompt_token_ids = self.encode_request( prompt_token_ids = self.encode_request(
request_id=request_id, request_id=request_id,
prompt=prompt, prompt=prompt,
...@@ -476,21 +316,21 @@ class LLMEngine: ...@@ -476,21 +316,21 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request) eos_token_id, lora_request)
# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, prefix) arrival_time, lora_request, multi_modal_data)
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group) self.scheduler.add_seq_group(seq_group)
...@@ -538,15 +378,13 @@ class LLMEngine: ...@@ -538,15 +378,13 @@ class LLMEngine:
if early_stopping is True: if early_stopping is True:
return True return True
current_worst_score = (current_worst_seq.get_beam_search_score( current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.get_tokenizer_for_seq( eos_token_id=current_worst_seq.eos_token_id)
current_worst_seq).eos_token_id))
if early_stopping is False: if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score( highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.get_tokenizer_for_seq( eos_token_id=best_running_seq.eos_token_id)
best_running_seq).eos_token_id))
else: else:
assert early_stopping == "never" assert early_stopping == "never"
if length_penalty > 0.0: if length_penalty > 0.0:
...@@ -560,8 +398,7 @@ class LLMEngine: ...@@ -560,8 +398,7 @@ class LLMEngine:
highest_attainable_score = ( highest_attainable_score = (
best_running_seq.get_beam_search_score( best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.get_tokenizer_for_seq( eos_token_id=best_running_seq.eos_token_id,
best_running_seq).eos_token_id,
seq_len=max_possible_length)) seq_len=max_possible_length))
else: else:
# Otherwise, beam search will prefer shorter sequences. The # Otherwise, beam search will prefer shorter sequences. The
...@@ -570,8 +407,7 @@ class LLMEngine: ...@@ -570,8 +407,7 @@ class LLMEngine:
highest_attainable_score = ( highest_attainable_score = (
best_running_seq.get_beam_search_score( best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.get_tokenizer_for_seq( eos_token_id=best_running_seq.eos_token_id))
best_running_seq).eos_token_id))
return current_worst_score >= highest_attainable_score return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
...@@ -580,6 +416,8 @@ class LLMEngine: ...@@ -580,6 +416,8 @@ class LLMEngine:
# Process prompt logprobs # Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None: if prompt_logprobs is not None:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs seq_group.prompt_logprobs = prompt_logprobs
# Process samples # Process samples
...@@ -623,7 +461,8 @@ class LLMEngine: ...@@ -623,7 +461,8 @@ class LLMEngine:
child_seqs.append((parent, parent)) child_seqs.append((parent, parent))
for seq, _ in child_seqs: for seq, _ in child_seqs:
self._decode_sequence(seq, seq_group.sampling_params) self.detokenizer.decode_sequence_inplace(seq,
seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params) self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case # Non-beam search case
...@@ -662,8 +501,7 @@ class LLMEngine: ...@@ -662,8 +501,7 @@ class LLMEngine:
all_finished_seqs = existing_finished_seqs + new_finished_seqs all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores. # Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True) reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]: for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new: if is_new:
...@@ -690,8 +528,7 @@ class LLMEngine: ...@@ -690,8 +528,7 @@ class LLMEngine:
if not seq.is_finished()] if not seq.is_finished()]
# Sort the running sequences by their scores. # Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True) reverse=True)
# Check if we can stop the beam search. # Check if we can stop the beam search.
...@@ -752,7 +589,11 @@ class LLMEngine: ...@@ -752,7 +589,11 @@ class LLMEngine:
now = time.time() now = time.time()
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, outputs in zip(scheduled_seq_groups, output):
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.update_num_computed_tokens(token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs) self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
...@@ -760,7 +601,8 @@ class LLMEngine: ...@@ -760,7 +601,8 @@ class LLMEngine:
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups: for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
...@@ -768,16 +610,9 @@ class LLMEngine: ...@@ -768,16 +610,9 @@ class LLMEngine:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
# Update prefix state, now all the uncomputed prefixes are computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated
and not seq_group.prefix.computed):
seq_group.prefix.computed = True
# Log stats. # Log stats.
if self.log_stats: if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs)) self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[RequestOutput]:
...@@ -798,7 +633,7 @@ class LLMEngine: ...@@ -798,7 +633,7 @@ class LLMEngine:
- A Sequence Group (SG) refer to a group of sequences - A Sequence Group (SG) refer to a group of sequences
that are generated from the same prompt. that are generated from the same prompt.
- Step 2: Calls the workers to execute the model. - Step 2: Calls the distributed executor to execute the model.
- Step 3: Processes the model output. This mainly includes: - Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs. - Decodes the relevant outputs.
...@@ -834,19 +669,10 @@ class LLMEngine: ...@@ -834,19 +669,10 @@ class LLMEngine:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. output = self.model_executor.execute_model(
all_outputs = self._run_workers( seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
"execute_model", scheduler_outputs.blocks_to_swap_out,
driver_kwargs={ scheduler_outputs.blocks_to_copy)
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results.
output = all_outputs[0]
else: else:
output = [] output = []
...@@ -860,7 +686,7 @@ class LLMEngine: ...@@ -860,7 +686,7 @@ class LLMEngine:
def _get_stats(self, def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
"""Get Stats to be Logged to Prometheus.""" """Get Stats to be Logged to Prometheus."""
now = time.monotonic() now = time.time()
# KV Cache Usage in %. # KV Cache Usage in %.
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
...@@ -891,18 +717,22 @@ class LLMEngine: ...@@ -891,18 +717,22 @@ class LLMEngine:
# Number of Tokens. # Number of Tokens.
if prompt_run: if prompt_run:
num_prompt_tokens = sum( num_prompt_tokens = sum(
len(seq_group.prompt_token_ids) len(scheduled_seq_group.seq_group.prompt_token_ids)
for seq_group in scheduler_outputs.scheduled_seq_groups) for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
num_generation_tokens = sum( num_generation_tokens = sum(
seq_group.num_seqs() scheduled_seq_group.seq_group.num_seqs()
for seq_group in scheduler_outputs.scheduled_seq_groups) for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
else: else:
num_generation_tokens = scheduler_outputs.num_batched_tokens num_generation_tokens = scheduler_outputs.num_batched_tokens
# Latency Timings. # Latency Timings.
time_last_iters = [] time_last_iters = []
for seq_group in scheduler_outputs.scheduled_seq_groups: for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
# Time since last token. (n.b. updates seq_group.metrics.last_token_time) seq_group = scheduled_seq_group.seq_group
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now)) time_last_iters.append(seq_group.get_last_latency(now))
# Time since arrival for all finished requests. # Time since arrival for all finished requests.
if seq_group.is_finished(): if seq_group.is_finished():
...@@ -926,41 +756,9 @@ class LLMEngine: ...@@ -926,41 +756,9 @@ class LLMEngine:
time_e2e_requests=time_e2e_requests, time_e2e_requests=time_e2e_requests,
) )
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
    """Incrementally detokenize ``seq`` and write the result back onto it.

    Calls ``detokenize_incrementally`` with the sequence's tokenizer, its
    full token id list, the tokens and offsets cached on ``seq`` and the
    special-token flags taken from ``prms``. The returned tokens, offsets
    and text are then stored on ``seq``.

    Args:
        seq: Sequence whose ``tokens``, ``prefix_offset``, ``read_offset``
            and ``output_text`` attributes are updated in place.
        prms: Sampling parameters supplying ``skip_special_tokens`` and
            ``spaces_between_special_tokens``.
    """
    tokenizer = self.get_tokenizer_for_seq(seq)
    decoded = detokenize_incrementally(
        tokenizer,
        all_input_ids=seq.get_token_ids(),
        prev_tokens=seq.tokens,
        prefix_offset=seq.prefix_offset,
        read_offset=seq.read_offset,
        skip_special_tokens=prms.skip_special_tokens,
        spaces_between_special_tokens=prms.spaces_between_special_tokens,
    )
    fresh_tokens, fresh_text, next_prefix_offset, next_read_offset = decoded

    # The first decode has no cached token list yet; afterwards we append.
    if seq.tokens is not None:
        seq.tokens.extend(fresh_tokens)
    else:
        seq.tokens = fresh_tokens

    seq.prefix_offset = next_prefix_offset
    seq.read_offset = next_read_offset
    seq.output_text += fresh_text
def _check_stop(self, seq: Sequence, def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> None:
"""Stop the finished sequences.""" """Stop the finished sequences."""
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED
return
if seq.get_last_token_id() in sampling_params.stop_token_ids:
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
seq.get_last_token_id())
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if the sequence has reached max_model_len. # Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len: if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
...@@ -971,9 +769,29 @@ class LLMEngine: ...@@ -971,9 +769,29 @@ class LLMEngine:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
last_token_id)
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos) and seq.get_last_token_id() if ((not sampling_params.ignore_eos)
== self.get_tokenizer_for_seq(seq).eos_token_id): and seq.get_last_token_id() == seq.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
...@@ -989,91 +807,13 @@ class LLMEngine: ...@@ -989,91 +807,13 @@ class LLMEngine:
seq.output_text = seq.output_text[:-len(stop_string)] seq.output_text = seq.output_text[:-len(stop_string)]
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self.model_executor.add_lora(lora_request)
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0." return self.model_executor.remove_lora(lora_id)
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> List[int]: def list_loras(self) -> List[int]:
return self._run_workers("list_loras") return self.model_executor.list_loras()
def _run_workers( def check_health(self) -> None:
self, self.model_executor.check_health()
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker,
method)(*driver_args, **driver_kwargs)
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
def _compiled_ray_dag(self):
    """Build and compile a Ray DAG that sends one input to every worker.

    Returns:
        The result of ``experimental_compile()`` on a ``MultiOutputNode``
        that binds ``execute_model_compiled_dag_remote`` for each worker
        in ``self.workers``.

    Raises:
        ValueError: If the installed Ray distribution is older than 2.9.
    """
    import pkg_resources
    required_version = "2.9"
    current_version = pkg_resources.get_distribution("ray").version
    # Compare parsed versions rather than raw strings: lexicographically
    # "2.10.0" < "2.9", which would wrongly reject newer Ray releases.
    if (pkg_resources.parse_version(current_version) <
            pkg_resources.parse_version(required_version)):
        raise ValueError(f"Ray version {required_version} or greater is "
                         f"required, but found {current_version}")

    from ray.dag import MultiOutputNode, InputNode
    assert self.parallel_config.worker_use_ray

    # Right now, compiled DAG requires at least 1 arg. We send
    # a dummy value for now. It will be fixed soon.
    with InputNode() as input_data:
        forward_dag = MultiOutputNode([
            worker.execute_model_compiled_dag_remote.bind(input_data)
            for worker in self.workers
        ])
    return forward_dag.experimental_compile()
from vllm.logger import init_logger
from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics
import time import time
import numpy as np
from typing import Dict, List
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
disable_created_metrics)
from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -23,6 +25,7 @@ class Metrics: ...@@ -23,6 +25,7 @@ class Metrics:
if hasattr(collector, "_name") and "vllm" in collector._name: if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector) REGISTRY.unregister(collector)
# Config Information
self.info_cache_config = Info( self.info_cache_config = Info(
name='vllm:cache_config', name='vllm:cache_config',
documentation='information of cache_config') documentation='information of cache_config')
...@@ -176,10 +179,12 @@ class StatLogger: ...@@ -176,10 +179,12 @@ class StatLogger:
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval. # Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on the vLLM side. # Support legacy gauge metrics that make throughput calculations on
# Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens # the vLLM side. Moving forward, we should use counters like
# Which log raw data and calculate summaries using rate() on the grafana/prometheus side. # counter_prompt_tokens, counter_generation_tokens
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666 # Which log raw data and calculate summaries using rate() on the
# grafana/prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self.metrics.gauge_avg_prompt_throughput.labels( self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput) **self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels( self.metrics.gauge_avg_generation_throughput.labels(
...@@ -187,7 +192,7 @@ class StatLogger: ...@@ -187,7 +192,7 @@ class StatLogger:
def log(self, stats: Stats) -> None: def log(self, stats: Stats) -> None:
"""Called by LLMEngine. """Called by LLMEngine.
Logs to prometheus and tracked stats every iteration. Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds.""" Logs to Stdout every self.local_interval seconds."""
# Log to prometheus. # Log to prometheus.
...@@ -199,8 +204,8 @@ class StatLogger: ...@@ -199,8 +204,8 @@ class StatLogger:
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now): if self._local_interval_elapsed(stats.now):
# Compute summary metrics for tracked stats (and log them
# Compute summary metrics for tracked stats (and log them to prometheus if applicable). # to prometheus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens, prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now) now=stats.now)
generation_throughput = self._get_throughput( generation_throughput = self._get_throughput(
...@@ -212,7 +217,8 @@ class StatLogger: ...@@ -212,7 +217,8 @@ class StatLogger:
# Log to stdout. # Log to stdout.
logger.info( logger.info(
f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
f"Avg generation throughput: {generation_throughput:.1f} tokens/s, " f"Avg generation throughput: "
f"{generation_throughput:.1f} tokens/s, "
f"Running: {stats.num_running} reqs, " f"Running: {stats.num_running} reqs, "
f"Swapped: {stats.num_swapped} reqs, " f"Swapped: {stats.num_swapped} reqs, "
f"Pending: {stats.num_waiting} reqs, " f"Pending: {stats.num_waiting} reqs, "
......
import pickle import pickle
from typing import List, Optional, Tuple
from typing import Optional, List, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip, set_cuda_visible_devices, get_ip from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,8 +32,17 @@ try: ...@@ -33,8 +32,17 @@ try:
return getattr(self.worker, name) return getattr(self.worker, name)
def execute_method(self, method, *args, **kwargs): def execute_method(self, method, *args, **kwargs):
executor = getattr(self, method) try:
return executor(*args, **kwargs) executor = getattr(self, method)
return executor(*args, **kwargs)
except Exception as e:
# exceptions in ray worker may cause deadlock
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
...@@ -65,45 +73,38 @@ except ImportError as e: ...@@ -65,45 +73,38 @@ except ImportError as e:
ray = None ray = None
RayWorkerVllm = None RayWorkerVllm = None
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
def initialize_cluster( def initialize_ray_cluster(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None, ray_address: Optional[str] = None,
) -> Optional["PlacementGroup"]: ):
"""Initialize the distributed cluster probably with Ray. """Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args: Args:
parallel_config: The configurations for parallel execution. parallel_config: The configurations for parallel execution.
engine_use_ray: Whether to use Ray for async engine.
ray_address: The address of the Ray cluster. If None, uses ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address. the default Ray cluster address.
Returns:
An optional `PlacementGroup`. It includes the specification
of the resources for each distributed worker. None if Ray is
not used.
""" """
if parallel_config.worker_use_ray or engine_use_ray: if ray is None:
if ray is None: raise ImportError(
raise ImportError( "Ray is not installed. Please install Ray to use distributed "
"Ray is not installed. Please install Ray to use distributed " "serving.")
"serving.")
# Connect to a ray cluster. # Connect to a ray cluster.
if is_hip(): if is_hip():
ray.init(address=ray_address, ray.init(address=ray_address,
ignore_reinit_error=True, ignore_reinit_error=True,
num_gpus=parallel_config.world_size) num_gpus=parallel_config.world_size)
else: else:
ray.init(address=ray_address, ignore_reinit_error=True) ray.init(address=ray_address, ignore_reinit_error=True)
if not parallel_config.worker_use_ray: if parallel_config.placement_group:
assert parallel_config.world_size == 1, ( # Placement group is already set.
"Ray is required if parallel_config.world_size > 1.") return
return None
# Create placement group for worker processes # Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group() current_placement_group = ray.util.get_current_placement_group()
...@@ -138,4 +139,5 @@ def initialize_cluster( ...@@ -138,4 +139,5 @@ def initialize_cluster(
# if they cannot be provisioned. # if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800) ray.get(current_placement_group.ready(), timeout=1800)
return current_placement_group # Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
""" """
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks. NOTE: This API server is used only for demonstrating usage of AsyncEngine
It is not intended for production use. For production use, we recommend using our OpenAI compatible server. and simple performance benchmarks. It is not intended for production use.
We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
""" """
import argparse import argparse
import json import json
import ssl
from typing import AsyncGenerator from typing import AsyncGenerator
import uvicorn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
...@@ -39,15 +43,11 @@ async def generate(request: Request) -> Response: ...@@ -39,15 +43,11 @@ async def generate(request: Request) -> Response:
""" """
request_dict = await request.json() request_dict = await request.json()
prompt = request_dict.pop("prompt") prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False) stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
results_generator = engine.generate(prompt, results_generator = engine.generate(prompt, sampling_params, request_id)
sampling_params,
request_id,
prefix_pos=prefix_pos)
# Streaming case # Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]: async def stream_results() -> AsyncGenerator[bytes, None]:
...@@ -84,6 +84,16 @@ if __name__ == "__main__": ...@@ -84,6 +84,16 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument( parser.add_argument(
"--root-path", "--root-path",
type=str, type=str,
...@@ -91,9 +101,9 @@ if __name__ == "__main__": ...@@ -91,9 +101,9 @@ if __name__ == "__main__":
help="FastAPI root_path when app is behind a path based routing proxy") help="FastAPI root_path when app is behind a path based routing proxy")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER)
app.root_path = args.root_path app.root_path = args.root_path
uvicorn.run(app, uvicorn.run(app,
...@@ -102,4 +112,6 @@ if __name__ == "__main__": ...@@ -102,4 +112,6 @@ if __name__ == "__main__":
log_level="debug", log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE, timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile, ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile) ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
from typing import List, Optional, Union from typing import List, Optional, Union
import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter from vllm.utils import Counter
...@@ -83,7 +86,7 @@ class LLM: ...@@ -83,7 +86,7 @@ class LLM:
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = True,
**kwargs, **kwargs,
) -> None: ) -> None:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
...@@ -106,7 +109,8 @@ class LLM: ...@@ -106,7 +109,8 @@ class LLM:
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args(engine_args) self.llm_engine = LLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter() self.request_counter = Counter()
def get_tokenizer( def get_tokenizer(
...@@ -124,9 +128,9 @@ class LLM: ...@@ -124,9 +128,9 @@ class LLM:
prompts: Optional[Union[str, List[str]]] = None, prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -140,13 +144,9 @@ class LLM: ...@@ -140,13 +144,9 @@ class LLM:
None, we use the default sampling parameters. None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns: Returns:
A list of `RequestOutput` objects containing the generated A list of `RequestOutput` objects containing the generated
...@@ -166,19 +166,27 @@ class LLM: ...@@ -166,19 +166,27 @@ class LLM:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len( num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids) prompt_token_ids)
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[ token_ids = None if prompt_token_ids is None else prompt_token_ids[
i] i]
self._add_request(prompt, self._add_request(
sampling_params, prompt,
token_ids, sampling_params,
lora_request=lora_request, token_ids,
prefix_pos=prefix_pos_i) lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
)
return self._run_engine(use_tqdm) return self._run_engine(use_tqdm)
def _add_request( def _add_request(
...@@ -187,7 +195,7 @@ class LLM: ...@@ -187,7 +195,7 @@ class LLM:
sampling_params: SamplingParams, sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]], prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(request_id,
...@@ -195,13 +203,15 @@ class LLM: ...@@ -195,13 +203,15 @@ class LLM:
sampling_params, sampling_params,
prompt_token_ids, prompt_token_ids,
lora_request=lora_request, lora_request=lora_request,
prefix_pos=prefix_pos) multi_modal_data=multi_modal_data)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests() num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts") pbar = tqdm(total=num_requests,
desc="Processed prompts",
dynamic_ncols=True)
# Run the engine. # Run the engine.
outputs: List[RequestOutput] = [] outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
......
import argparse
import asyncio import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib import importlib
import inspect import inspect
import os
from contextlib import asynccontextmanager
from http import HTTPStatus
from prometheus_client import make_asgi_app
import fastapi import fastapi
import uvicorn import uvicorn
from http import HTTPStatus
from fastapi import Request from fastapi import Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.logger import init_logger from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
...@@ -47,95 +48,8 @@ async def lifespan(app: fastapi.FastAPI): ...@@ -47,95 +48,8 @@ async def lifespan(app: fastapi.FastAPI):
app = fastapi.FastAPI(lifespan=lifespan) app = fastapi.FastAPI(lifespan=lifespan)
class LoRAParserAction(argparse.Action):
    """argparse action that converts ``name=path`` items into ``LoRA`` objects."""

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse each ``name=path`` item and store the resulting list.

        Args:
            parser: The invoking argument parser (unused).
            namespace: Namespace on which the parsed list is stored under
                ``self.dest``.
            values: Iterable of strings, each in ``name=path`` form.
            option_string: The option string that triggered the action
                (unused).

        Raises:
            ValueError: If an item contains no ``=`` separator.
        """
        lora_list = []
        for item in values:
            # Split on the first '=' only, so a path that itself contains
            # '=' is kept intact instead of breaking the unpacking.
            name, separator, path = item.partition('=')
            if not separator:
                raise ValueError(
                    f"Invalid LoRA module {item!r}: expected the format "
                    f"name=path.")
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = make_arg_parser()
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--api-key",
type=str,
default=None,
help=
"If provided, the server will require this key to be presented in the header."
)
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args() return parser.parse_args()
...@@ -153,6 +67,7 @@ async def validation_exception_handler(_, exc): ...@@ -153,6 +67,7 @@ async def validation_exception_handler(_, exc):
@app.get("/health") @app.get("/health")
async def health() -> Response: async def health() -> Response:
"""Health check.""" """Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200) return Response(status_code=200)
...@@ -162,6 +77,12 @@ async def show_available_models(): ...@@ -162,6 +77,12 @@ async def show_available_models():
return JSONResponse(content=models.model_dump()) return JSONResponse(content=models.model_dump())
@app.get("/version")
async def show_version():
    """Serve ``vllm.__version__`` as a JSON object under the key ``version``."""
    return JSONResponse(content={"version": vllm.__version__})
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request): raw_request: Request):
...@@ -221,19 +142,19 @@ if __name__ == "__main__": ...@@ -221,19 +142,19 @@ if __name__ == "__main__":
elif inspect.iscoroutinefunction(imported): elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported) app.middleware("http")(imported)
else: else:
raise ValueError( raise ValueError(f"Invalid middleware {middleware}. "
f"Invalid middleware {middleware}. Must be a function or a class." f"Must be a function or a class.")
)
logger.info(f"vLLM API server version {vllm.__version__}")
logger.info(f"args: {args}") logger.info(f"args: {args}")
if args.served_model_name is not None: if args.served_model_name is not None:
served_model = args.served_model_name served_model = args.served_model_name
else: else:
served_model = args.model served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
openai_serving_chat = OpenAIServingChat(engine, served_model, openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role, args.response_role,
args.lora_modules, args.lora_modules,
...@@ -245,7 +166,9 @@ if __name__ == "__main__": ...@@ -245,7 +166,9 @@ if __name__ == "__main__":
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level="info", log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE, timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile, ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile) ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
"""
This file contains the command line arguments for vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.serving_engine import LoRA
class LoRAParserAction(argparse.Action):
    """Argparse action converting ``name=path`` tokens into ``LoRA`` entries.

    Every token supplied to the option is split at its first ``=`` into a
    module name and a path; the resulting list of ``LoRA(name, path)``
    objects is stored on the namespace under ``self.dest``.

    Raises:
        argparse.ArgumentError: if a token contains no ``=`` separator.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            # Split on the first '=' only, so a path that itself contains
            # '=' stays intact instead of failing to unpack.
            name, separator, path = item.partition('=')
            if not separator:
                # ArgumentError is what argparse turns into a usage message;
                # a bare unpacking ValueError would surface as a traceback.
                raise argparse.ArgumentError(
                    self,
                    f"invalid LoRA module {item!r}: expected name=path")
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)
def make_arg_parser():
    """Build the CLI parser for the OpenAI-compatible API server.

    Registers the server-level options below, then passes the parser through
    ``AsyncEngineArgs.add_cli_args`` and returns what that call returns.
    The parser is returned without parsing any arguments.
    """
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    add = parser.add_argument

    # Listening address and uvicorn logging.
    add("--host", type=str, default=None, help="host name")
    add("--port", type=int, default=8000, help="port number")
    add("--uvicorn-log-level",
        type=str,
        default="info",
        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
        help="log level for uvicorn")

    add("--allow-credentials", action="store_true", help="allow credentials")
    # The three allow-lists share one shape: a JSON-decoded value that
    # defaults to a fresh ["*"] list for each option.
    for subject in ("origins", "methods", "headers"):
        add(f"--allowed-{subject}",
            type=json.loads,
            default=["*"],
            help=f"allowed {subject}")

    add("--api-key",
        type=str,
        default=None,
        help="If provided, the server will require this key to be "
        "presented in the header.")
    add("--served-model-name",
        type=str,
        default=None,
        help="The model name used in the API. If not specified, "
        "the model name will be the same as the huggingface name.")
    add("--lora-modules",
        type=str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in the format name=path. "
        "Multiple modules can be specified.")
    add("--chat-template",
        type=str,
        default=None,
        help="The file path to the chat template, or the template in "
        "single-line form for the specified model")
    add("--response-role",
        type=str,
        default="assistant",
        help="The role name to return if "
        "`request.add_generation_prompt=true`.")

    # Plain string options for the SSL files, all defaulting to None.
    for flag, description in (
        ("--ssl-keyfile", "The file path to the SSL key file"),
        ("--ssl-certfile", "The file path to the SSL cert file"),
        ("--ssl-ca-certs", "The CA certificates file"),
    ):
        add(flag, type=str, default=None, help=description)
    # NOTE(review): this help text reads as if it were cut short
    # ("ssl module's" what?) -- confirm the intended wording.
    add("--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required "
        "(see stdlib ssl module's)")

    add("--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based "
        "routing proxy")
    # "append" lets --middleware be given several times; the values
    # accumulate into one list.
    add("--middleware",
        type=str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server "
        "using @app.middleware('http'). "
        "If a class is provided, vLLM will add it to the server "
        "using app.add_middleware(). ")

    return AsyncEngineArgs.add_cli_args(parser)
...@@ -3,12 +3,11 @@ ...@@ -3,12 +3,11 @@
import time import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
import torch
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
...@@ -55,40 +54,87 @@ class UsageInfo(BaseModel): ...@@ -55,40 +54,87 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0 completion_tokens: Optional[int] = 0
class ResponseFormat(BaseModel):
    """Requested output format, carried in a request's ``response_format``.

    ``type`` is restricted to ``"text"`` or ``"json_object"``. The allowed
    values are expressed as the field's annotation so they are validated;
    previously the ``Literal`` object was assigned as the default of a plain
    ``str`` field, which validated nothing. The ``"text"`` default keeps
    no-argument construction working.
    """
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"] = "text"
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[Dict[str, str]] messages: List[Dict[str, str]]
temperature: Optional[float] = 0.7 model: str
top_p: Optional[float] = 1.0 frequency_penalty: Optional[float] = 0.0
n: Optional[int] = 1 logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
logprobs: Optional[bool] = False temperature: Optional[float] = 0.7
top_logprobs: Optional[int] = None top_p: Optional[float] = 1.0
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None user: Optional[str] = None
# Additional parameters supported by vLLM
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
add_generation_prompt: Optional[bool] = True # doc: end-chat-completion-sampling-params
echo: Optional[bool] = False
repetition_penalty: Optional[float] = 1.0 # doc: begin-chat-completion-extra-params
min_p: Optional[float] = 0.0 echo: Optional[bool] = Field(
include_stop_str_in_output: Optional[bool] = False default=False,
length_penalty: Optional[float] = 1.0 description=(
guided_json: Optional[Union[str, dict, BaseModel]] = None "If true, the new message will be prepended with the last message "
guided_regex: Optional[str] = None "if they belong to the same role."),
guided_choice: Optional[List[str]] = None )
add_generation_prompt: Optional[bool] = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
# doc: end-chat-completion-extra-params
def to_sampling_params(self) -> SamplingParams: def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs: if self.logprobs and not self.top_logprobs:
...@@ -120,6 +166,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -120,6 +166,7 @@ class ChatCompletionRequest(BaseModel):
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None, logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None, prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of, best_of=self.best_of,
...@@ -150,39 +197,75 @@ class ChatCompletionRequest(BaseModel): ...@@ -150,39 +197,75 @@ class ChatCompletionRequest(BaseModel):
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str model: str
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]] prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None best_of: Optional[int] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
seed: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = None user: Optional[str] = None
# Additional parameters supported by vLLM
top_k: Optional[int] = -1 # doc: begin-completion-sampling-params
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
repetition_penalty: Optional[float] = 1.0 # doc: end-completion-sampling-params
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False # doc: begin-completion-extra-params
length_penalty: Optional[float] = 1.0 include_stop_str_in_output: Optional[bool] = Field(
guided_json: Optional[Union[str, dict, BaseModel]] = None default=False,
guided_regex: Optional[str] = None description=(
guided_choice: Optional[List[str]] = None "Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
# doc: end-completion-extra-params
def to_sampling_params(self): def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
...@@ -216,6 +299,7 @@ class CompletionRequest(BaseModel): ...@@ -216,6 +299,7 @@ class CompletionRequest(BaseModel):
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1, max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
logprobs=self.logprobs, logprobs=self.logprobs,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping, early_stopping=self.early_stopping,
...@@ -246,7 +330,7 @@ class LogProbs(BaseModel): ...@@ -246,7 +330,7 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):
...@@ -254,6 +338,13 @@ class CompletionResponseChoice(BaseModel): ...@@ -254,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
text: str text: str
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionResponse(BaseModel): class CompletionResponse(BaseModel):
...@@ -270,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel): ...@@ -270,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
text: str text: str
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionStreamResponse(BaseModel): class CompletionStreamResponse(BaseModel):
...@@ -291,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel): ...@@ -291,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
message: ChatMessage message: ChatMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = None
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):
...@@ -312,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): ...@@ -312,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage delta: DeltaMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[Literal["stop", "length"]] = None
stop_reason: Union[None, int, str] = None
class ChatCompletionStreamResponse(BaseModel): class ChatCompletionStreamResponse(BaseModel):
......
import time
import codecs import codecs
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
from fastapi import Request from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo) UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA from vllm.utils import random_uuid
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,8 +40,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -37,8 +40,9 @@ class OpenAIServingChat(OpenAIServing):
ChatCompletionResponse]: ChatCompletionResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API. for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature: NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves) - function_call (Users should implement this by themselves)
...@@ -65,7 +69,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -65,7 +69,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer())) request, await self.engine.get_tokenizer()))
if guided_decode_logits_processor: if guided_decode_logits_processor:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
...@@ -82,8 +86,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -82,8 +86,12 @@ class OpenAIServingChat(OpenAIServing):
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, result_generator, request_id) request, result_generator, request_id)
else: else:
return await self.chat_completion_full_generator( try:
request, raw_request, result_generator, request_id) return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
...@@ -97,119 +105,139 @@ class OpenAIServingChat(OpenAIServing): ...@@ -97,119 +105,139 @@ class OpenAIServingChat(OpenAIServing):
) -> Union[ErrorResponse, AsyncGenerator[str, None]]: ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
model_name = request.model model_name = request.model
created_time = int(time.monotonic()) created_time = int(time.time())
chunk_object_type = "chat.completion.chunk" chunk_object_type = "chat.completion.chunk"
first_iteration = True
# Send first response for each request.n (index) with the role
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
previous_texts = [""] * request.n previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n finish_reason_sent = [False] * request.n
async for res in result_generator: try:
res: RequestOutput async for res in result_generator:
for output in res.outputs: res: RequestOutput
i = output.index # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
if finish_reason_sent[i]: # response (by the try...catch).
continue if first_iteration:
# Send first response for each request.n (index) with
delta_token_ids = output.token_ids[previous_num_tokens[i]:] # the role
top_logprobs = output.logprobs[ role = self.get_chat_request_role(request)
previous_num_tokens[i]:] if output.logprobs else None for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
if request.logprobs: index=i,
logprobs = self._create_logprobs( delta=DeltaMessage(role=role),
token_ids=delta_token_ids, logprobs=None,
top_logprobs=top_logprobs, finish_reason=None)
num_output_top_logprobs=request.logprobs, chunk = ChatCompletionStreamResponse(
initial_text_offset=len(previous_texts[i]), id=request_id,
) object=chunk_object_type,
else: created=created_time,
logprobs = None choices=[choice_data],
model=model_name)
delta_text = output.text[len(previous_texts[i]):] data = chunk.model_dump_json(exclude_unset=True)
previous_texts[i] = output.text yield f"data: {data}\n\n"
previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None: # Send response to echo the input portion of the
# Send token-by-token response for each request.n # last message
choice_data = ChatCompletionResponseStreamChoice( if request.echo:
index=i, last_msg_content = ""
delta=DeltaMessage(content=delta_text), if request.messages and isinstance(
logprobs=logprobs, request.messages,
finish_reason=None) list) and request.messages[-1].get(
chunk = ChatCompletionStreamResponse( "content") and request.messages[-1].get(
id=request_id, "role") == role:
object=chunk_object_type, last_msg_content = request.messages[-1]["content"]
created=created_time,
choices=[choice_data], if last_msg_content:
model=model_name) for i in range(request.n):
data = chunk.model_dump_json(exclude_unset=True) choice_data = (
yield f"data: {data}\n\n" ChatCompletionResponseStreamChoice(
else: index=i,
# Send the finish response for each request.n only once delta=DeltaMessage(
prompt_tokens = len(res.prompt_token_ids) content=last_msg_content),
final_usage = UsageInfo( finish_reason=None))
prompt_tokens=prompt_tokens, chunk = ChatCompletionStreamResponse(
completion_tokens=previous_num_tokens[i], id=request_id,
total_tokens=prompt_tokens + previous_num_tokens[i], object=chunk_object_type,
) created=created_time,
choice_data = ChatCompletionResponseStreamChoice( choices=[choice_data],
index=i, logprobs=None,
delta=DeltaMessage(content=delta_text), model=model_name)
logprobs=logprobs, data = chunk.model_dump_json(
finish_reason=output.finish_reason) exclude_unset=True)
chunk = ChatCompletionStreamResponse( yield f"data: {data}\n\n"
id=request_id, first_iteration = False
object=chunk_object_type,
created=created_time, for output in res.outputs:
choices=[choice_data], i = output.index
model=model_name)
if final_usage is not None: if finish_reason_sent[i]:
chunk.usage = final_usage continue
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True) delta_token_ids = output.token_ids[previous_num_tokens[i]:]
yield f"data: {data}\n\n" top_logprobs = output.logprobs[
finish_reason_sent[i] = True previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished # Send the final done message after all response.n are finished
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
...@@ -219,7 +247,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -219,7 +247,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = request.model model_name = request.model
created_time = int(time.monotonic()) created_time = int(time.time())
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
...@@ -251,6 +279,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -251,6 +279,7 @@ class OpenAIServingChat(OpenAIServing):
message=ChatMessage(role=role, content=output.text), message=ChatMessage(role=role, content=output.text),
logprobs=logprobs, logprobs=logprobs,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
) )
choices.append(choice_data) choices.append(choice_data)
......
import asyncio import asyncio
import time import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
from fastapi import Request from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionRequest, CompletionResponse,
CompletionResponse, CompletionResponseChoice,
CompletionResponseChoice, CompletionResponseStreamChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
CompletionStreamResponse, LogProbs, UsageInfo)
LogProbs, from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
UsageInfo, from vllm.logger import init_logger
) from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA from vllm.utils import random_uuid
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -26,107 +27,6 @@ TypeCreateLogProbsFn = Callable[ ...@@ -26,107 +27,6 @@ TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
async def completion_stream_generator(
    request: CompletionRequest,
    raw_request: Request,
    on_abort,
    result_generator: AsyncIterator[Tuple[int, RequestOutput]],
    create_logprobs_fn: TypeCreateLogProbsFn,
    request_id: str,
    created_time: int,
    model_name: str,
    num_prompts: int,
) -> AsyncGenerator[str, None]:
    """Yield completion results as server-sent-event ``data:`` lines.

    Args:
        request: The completion request; ``n``, ``echo``, ``max_tokens`` and
            ``logprobs`` are read from it.
        raw_request: Used only to poll ``is_disconnected()``.
        on_abort: Awaitable callback invoked with
            ``f"{request_id}-{prompt_idx}"`` when the client has gone away.
        result_generator: Async iterator of ``(prompt_idx, RequestOutput)``.
        create_logprobs_fn: Builds the ``LogProbs`` payload of a chunk.
        request_id: Id copied into every chunk.
        created_time: Timestamp copied into every chunk.
        model_name: Model name copied into every chunk.
        num_prompts: Number of prompts in the request; together with
            ``request.n`` it sizes the per-choice bookkeeping lists.

    Yields:
        One ``data: <json>`` line per output update, one extra usage chunk
        for every finished output, and a final ``data: [DONE]`` line. If the
        client disconnects the generator stops without the final line.
    """
    # One slot per (prompt, choice); slot = output.index + prompt_idx * n.
    previous_texts = [""] * request.n * num_prompts
    previous_num_tokens = [0] * request.n * num_prompts
    has_echoed = [False] * request.n * num_prompts

    async for prompt_idx, res in result_generator:
        # Abort the request if the client disconnects.
        if await raw_request.is_disconnected():
            await on_abort(f"{request_id}-{prompt_idx}")
            # End the generator with a plain return: a StopAsyncIteration
            # raised inside an async generator is converted into a
            # RuntimeError by the interpreter (PEP 525).
            return

        for output in res.outputs:
            i = output.index + prompt_idx * request.n
            # TODO(simon): optimize the performance by avoiding full text
            # O(n^2) sending.
            if request.echo and request.max_tokens == 0:
                # only return the prompt
                delta_text = res.prompt
                delta_token_ids = res.prompt_token_ids
                top_logprobs = res.prompt_logprobs
                has_echoed[i] = True
            elif (request.echo and request.max_tokens > 0
                  and not has_echoed[i]):
                # echo the prompt and first token
                delta_text = res.prompt + output.text
                delta_token_ids = res.prompt_token_ids + output.token_ids
                # Either logprob list may be None (presumably when logprobs
                # were not requested), so both sides are defaulted.
                top_logprobs = ((res.prompt_logprobs or []) +
                                (output.logprobs or []))
                has_echoed[i] = True
            else:
                # return just the delta
                delta_text = output.text[len(previous_texts[i]):]
                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                top_logprobs = output.logprobs[
                    previous_num_tokens[i]:] if output.logprobs else None

            if request.logprobs is not None:
                assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
                logprobs = create_logprobs_fn(
                    token_ids=delta_token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                    initial_text_offset=len(previous_texts[i]),
                )
            else:
                logprobs = None

            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)
            finish_reason = output.finish_reason
            response_json = CompletionStreamResponse(
                id=request_id,
                created=created_time,
                model=model_name,
                choices=[
                    CompletionResponseStreamChoice(
                        index=i,
                        text=delta_text,
                        logprobs=logprobs,
                        finish_reason=finish_reason,
                    )
                ]).model_dump_json()
            yield f"data: {response_json}\n\n"

            if output.finish_reason is not None:  # return final usage
                logprobs = LogProbs() if request.logprobs is not None else None
                prompt_tokens = len(res.prompt_token_ids)
                completion_tokens = len(output.token_ids)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )
                response_json = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[
                        CompletionResponseStreamChoice(
                            index=i,
                            text="",
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                        )
                    ],
                    usage=final_usage,
                ).model_dump_json()
                yield f"data: {response_json}\n\n"

    yield "data: [DONE]\n\n"
def parse_prompt_format(prompt) -> Tuple[bool, list]: def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following # get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays." # "a string, array of strings, array of tokens, or array of token arrays."
...@@ -145,79 +45,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: ...@@ -145,79 +45,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
prompt_is_tokens = True prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays prompts = prompt # case 4: array of token arrays
else: else:
raise ValueError( raise ValueError("prompt must be a string, array of strings, "
"prompt must be a string, array of strings, array of tokens, or array of token arrays" "array of tokens, or array of token arrays")
)
return prompt_is_tokens, prompts return prompt_is_tokens, prompts
def request_output_to_completion_response(
    final_res_batch: List[RequestOutput],
    request: CompletionRequest,
    create_logprobs_fn: TypeCreateLogProbsFn,
    request_id: str,
    created_time: int,
    model_name: str,
) -> CompletionResponse:
    """Build the non-streaming ``CompletionResponse`` for a finished batch.

    Args:
        final_res_batch: One final ``RequestOutput`` per prompt; no entry
            may be ``None``.
        request: The completion request; ``echo``, ``max_tokens`` and
            ``logprobs`` are read from it.
        create_logprobs_fn: Builds the ``LogProbs`` payload of a choice.
        request_id: Id copied into the response.
        created_time: Timestamp copied into the response.
        model_name: Model name copied into the response.

    Returns:
        A response holding one choice per output of every prompt, indexed in
        iteration order, plus token usage summed over the whole batch.
    """
    choices = []
    num_prompt_tokens = 0
    num_generated_tokens = 0

    for final_res in final_res_batch:
        assert final_res is not None
        prompt_token_ids = final_res.prompt_token_ids
        prompt_logprobs = final_res.prompt_logprobs
        prompt_text = final_res.prompt

        for output in final_res.outputs:
            if request.echo and request.max_tokens == 0:
                token_ids = prompt_token_ids
                top_logprobs = prompt_logprobs
                output_text = prompt_text
            elif request.echo and request.max_tokens > 0:
                token_ids = prompt_token_ids + output.token_ids
                # Either list may be None (presumably when logprobs were not
                # requested); default both so echo does not fail with
                # "NoneType + NoneType". Matches the streaming path.
                top_logprobs = ((prompt_logprobs or []) +
                                (output.logprobs or []))
                output_text = prompt_text + output.text
            else:
                token_ids = output.token_ids
                top_logprobs = output.logprobs
                output_text = output.text

            if request.logprobs is not None:
                logprobs = create_logprobs_fn(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                )
            else:
                logprobs = None

            choice_data = CompletionResponseChoice(
                index=len(choices),
                text=output_text,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
            )
            choices.append(choice_data)

        num_prompt_tokens += len(prompt_token_ids)
        num_generated_tokens += sum(
            len(output.token_ids) for output in final_res.outputs)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )

    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )
def merge_async_iterators(*iterators): def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator. """Merge multiple asynchronous iterators into a single iterator.
...@@ -230,8 +62,11 @@ def merge_async_iterators(*iterators): ...@@ -230,8 +62,11 @@ def merge_async_iterators(*iterators):
finished = [False] * len(iterators) finished = [False] * len(iterators)
async def producer(i, iterator): async def producer(i, iterator):
async for item in iterator: try:
await queue.put((i, item)) async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True finished[i] = True
_tasks = [ _tasks = [
...@@ -242,6 +77,8 @@ def merge_async_iterators(*iterators): ...@@ -242,6 +77,8 @@ def merge_async_iterators(*iterators):
async def consumer(): async def consumer():
while not all(finished) or not queue.empty(): while not all(finished) or not queue.empty():
item = await queue.get() item = await queue.get()
if isinstance(item, Exception):
raise item
yield item yield item
await asyncio.gather(*_tasks) await asyncio.gather(*_tasks)
...@@ -280,7 +117,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -280,7 +117,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name = request.model model_name = request.model
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic()) created_time = int(time.time())
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators = [] generators = []
...@@ -289,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -289,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
guided_decode_logit_processor = ( guided_decode_logit_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer())) request, await self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None: if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
...@@ -312,40 +149,43 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -312,40 +149,43 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_token_ids=input_ids, prompt_token_ids=input_ids,
lora_request=lora_request)) lora_request=lora_request))
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators) int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search. # results. In addition, we do not stream the results when use
# beam search.
stream = (request.stream stream = (request.stream
and (request.best_of is None or request.n == request.best_of) and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search) and not request.use_beam_search)
# Streaming response # Streaming response
if stream: if stream:
return completion_stream_generator(request, return self.completion_stream_generator(request,
raw_request, raw_request,
self.engine.abort, result_generator,
result_generator, request_id,
self._create_logprobs, created_time,
request_id, model_name,
created_time, num_prompts=len(prompts))
model_name,
num_prompts=len(prompts))
# Non-streaming response # Non-streaming response
final_res_batch: RequestOutput = [None] * len(prompts) final_res_batch: RequestOutput = [None] * len(prompts)
async for i, res in result_generator: try:
if await raw_request.is_disconnected(): async for i, res in result_generator:
# Abort the request if the client disconnects. if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{i}") # Abort the request if the client disconnects.
return self.create_error_response("Client disconnected") await self.engine.abort(f"{request_id}-{i}")
final_res_batch[i] = res return self.create_error_response("Client disconnected")
response = request_output_to_completion_response( final_res_batch[i] = res
final_res_batch, request, self._create_logprobs, request_id, response = self.request_output_to_completion_response(
created_time, model_name) final_res_batch, request, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to # When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event. # return a streaming response with a single event.
...@@ -359,3 +199,166 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -359,3 +199,166 @@ class OpenAIServingCompletion(OpenAIServing):
return fake_stream_generator() return fake_stream_generator()
return response return response
async def completion_stream_generator(
    self,
    request: CompletionRequest,
    raw_request: Request,
    result_generator: AsyncIterator[Tuple[int, RequestOutput]],
    request_id: str,
    created_time: int,
    model_name: str,
    num_prompts: int,
) -> AsyncGenerator[str, None]:
    """Yield completion results as server-sent-event ``data:`` lines.

    Args:
        request: The completion request; ``n``, ``echo``, ``max_tokens`` and
            ``logprobs`` are read from it.
        raw_request: Used only to poll ``is_disconnected()``.
        result_generator: Async iterator of ``(prompt_idx, RequestOutput)``.
        request_id: Id copied into every chunk and used to build the id
            passed to ``self.engine.abort``.
        created_time: Timestamp copied into every chunk.
        model_name: Model name copied into every chunk.
        num_prompts: Number of prompts in the request; together with
            ``request.n`` it sizes the per-choice bookkeeping lists.

    Yields:
        One ``data: <json>`` line per output update (usage is attached to
        the chunk that carries a finish reason), an error chunk if a
        ``ValueError`` is raised while streaming, and a final
        ``data: [DONE]`` line. If the client disconnects the generator
        stops without the final line.
    """
    # One slot per (prompt, choice); slot = output.index + prompt_idx * n.
    previous_texts = [""] * request.n * num_prompts
    previous_num_tokens = [0] * request.n * num_prompts
    has_echoed = [False] * request.n * num_prompts

    try:
        async for prompt_idx, res in result_generator:

            # Abort the request if the client disconnects.
            if await raw_request.is_disconnected():
                await self.engine.abort(f"{request_id}-{prompt_idx}")
                # End the generator with a plain return: a
                # StopAsyncIteration raised inside an async generator is
                # converted into a RuntimeError by the interpreter
                # (PEP 525).
                return

            for output in res.outputs:
                i = output.index + prompt_idx * request.n
                # TODO(simon): optimize the performance by avoiding full
                # text O(n^2) sending.

                if request.echo and request.max_tokens == 0:
                    # only return the prompt
                    delta_text = res.prompt
                    delta_token_ids = res.prompt_token_ids
                    top_logprobs = res.prompt_logprobs
                    has_echoed[i] = True
                elif (request.echo and request.max_tokens > 0
                      and not has_echoed[i]):
                    # echo the prompt and first token
                    delta_text = res.prompt + output.text
                    delta_token_ids = (res.prompt_token_ids +
                                       output.token_ids)
                    # Either logprob list may be None (presumably when
                    # logprobs were not requested), so both sides are
                    # defaulted.
                    top_logprobs = ((res.prompt_logprobs or []) +
                                    (output.logprobs or []))
                    has_echoed[i] = True
                else:
                    # return just the delta
                    delta_text = output.text[len(previous_texts[i]):]
                    delta_token_ids = output.token_ids[
                        previous_num_tokens[i]:]
                    top_logprobs = output.logprobs[previous_num_tokens[
                        i]:] if output.logprobs else None

                if request.logprobs is not None:
                    logprobs = self._create_logprobs(
                        token_ids=delta_token_ids,
                        top_logprobs=top_logprobs,
                        num_output_top_logprobs=request.logprobs,
                        initial_text_offset=len(previous_texts[i]),
                    )
                else:
                    logprobs = None

                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                finish_reason = output.finish_reason
                stop_reason = output.stop_reason
                if output.finish_reason is not None:  # return final usage
                    prompt_tokens = len(res.prompt_token_ids)
                    completion_tokens = len(output.token_ids)
                    final_usage = UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    )
                else:
                    final_usage = None
                response_json = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[
                        CompletionResponseStreamChoice(
                            index=i,
                            text=delta_text,
                            logprobs=logprobs,
                            finish_reason=finish_reason,
                            stop_reason=stop_reason,
                        )
                    ],
                    usage=final_usage,
                ).model_dump_json(exclude_unset=True)
                yield f"data: {response_json}\n\n"
    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        data = self.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"
    yield "data: [DONE]\n\n"
def request_output_to_completion_response(
    self,
    final_res_batch: List[RequestOutput],
    request: CompletionRequest,
    request_id: str,
    created_time: int,
    model_name: str,
) -> CompletionResponse:
    """Build the non-streaming ``CompletionResponse`` for a finished batch.

    Args:
        final_res_batch: One final ``RequestOutput`` per prompt; no entry
            may be ``None``.
        request: The completion request; ``echo``, ``max_tokens`` and
            ``logprobs`` are read from it.
        request_id: Id copied into the response.
        created_time: Timestamp copied into the response.
        model_name: Model name copied into the response.

    Returns:
        A response holding one choice per output of every prompt, indexed in
        iteration order, plus token usage summed over the whole batch.
    """
    choices = []
    num_prompt_tokens = 0
    num_generated_tokens = 0

    for final_res in final_res_batch:
        assert final_res is not None
        prompt_token_ids = final_res.prompt_token_ids
        prompt_logprobs = final_res.prompt_logprobs
        prompt_text = final_res.prompt

        for output in final_res.outputs:
            if request.echo and request.max_tokens == 0:
                token_ids = prompt_token_ids
                top_logprobs = prompt_logprobs
                output_text = prompt_text
            elif request.echo and request.max_tokens > 0:
                token_ids = prompt_token_ids + output.token_ids
                # Either list may be None (presumably when logprobs were not
                # requested); default both so echo does not fail with
                # "NoneType + NoneType". Matches the streaming path.
                top_logprobs = ((prompt_logprobs or []) +
                                (output.logprobs or []))
                output_text = prompt_text + output.text
            else:
                token_ids = output.token_ids
                top_logprobs = output.logprobs
                output_text = output.text

            if request.logprobs is not None:
                logprobs = self._create_logprobs(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                )
            else:
                logprobs = None

            choice_data = CompletionResponseChoice(
                index=len(choices),
                text=output_text,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason,
            )
            choices.append(choice_data)

        num_prompt_tokens += len(prompt_token_ids)
        num_generated_tokens += sum(
            len(output.token_ids) for output in final_res.outputs)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )

    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )
import asyncio import asyncio
import json
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionRequest, CompletionRequest, ErrorResponse,
ErrorResponse, LogProbs, LogProbs, ModelCard, ModelList,
ModelCard, ModelList,
ModelPermission) ModelPermission)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -48,10 +50,12 @@ class OpenAIServing: ...@@ -48,10 +50,12 @@ class OpenAIServing:
except RuntimeError: except RuntimeError:
event_loop = None event_loop = None
if event_loop is not None and event_loop.is_running( if event_loop is not None and event_loop.is_running():
): # If the current is instanced by Ray Serve, there is already a running event loop # If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init()) event_loop.create_task(self._post_init())
else: # When using single vLLM without engine_use_ray else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init()) asyncio.run(self._post_init())
async def _post_init(self): async def _post_init(self):
...@@ -83,7 +87,7 @@ class OpenAIServing: ...@@ -83,7 +87,7 @@ class OpenAIServing:
def _create_logprobs( def _create_logprobs(
self, self,
token_ids: List[int], token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
num_output_top_logprobs: Optional[int] = None, num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0, initial_text_offset: int = 0,
) -> LogProbs: ) -> LogProbs:
...@@ -95,10 +99,10 @@ class OpenAIServing: ...@@ -95,10 +99,10 @@ class OpenAIServing:
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None: if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id] token_logprob = step_top_logprobs[token_id].logprob
else: else:
token_logprob = None token_logprob = None
token = self.tokenizer.convert_ids_to_tokens(token_id) token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token) logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob) logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0: if len(logprobs.text_offset) == 0:
...@@ -110,7 +114,7 @@ class OpenAIServing: ...@@ -110,7 +114,7 @@ class OpenAIServing:
if num_output_top_logprobs: if num_output_top_logprobs:
logprobs.top_logprobs.append({ logprobs.top_logprobs.append({
self.tokenizer.convert_ids_to_tokens(i): p p.decoded_token: p.logprob
for i, p in step_top_logprobs.items() for i, p in step_top_logprobs.items()
} if step_top_logprobs else None) } if step_top_logprobs else None)
return logprobs return logprobs
...@@ -124,6 +128,19 @@ class OpenAIServing: ...@@ -124,6 +128,19 @@ class OpenAIServing:
type=err_type, type=err_type,
code=status_code.value) code=status_code.value)
def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
    """Serialize an error for a streaming response.

    The error object comes from ``self.create_error_response``; its
    ``model_dump()`` is nested under an ``"error"`` key and the result is
    returned as a JSON string.
    """
    error = self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code)
    payload = {"error": error.model_dump()}
    return json.dumps(payload)
async def _check_model(self, request) -> Optional[ErrorResponse]: async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model: if request.model == self.served_model:
return return
...@@ -163,8 +180,9 @@ class OpenAIServing: ...@@ -163,8 +180,9 @@ class OpenAIServing:
if token_num + request.max_tokens > self.max_model_len: if token_num + request.max_tokens > self.max_model_len:
raise ValueError( raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. " f"This model's maximum context length is "
f"However, you requested {request.max_tokens + token_num} tokens " f"{self.max_model_len} tokens. However, you requested "
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, " f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). " f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", ) f"Please reduce the length of the messages or completion.", )
......
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
class ExecutorBase(ABC):
    """Base class for all executors.

    An executor is responsible for executing the model on a specific device
    type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
    that can execute the model on multiple devices.

    Every method here is abstract and raises ``NotImplementedError``;
    concrete executors must override all of them, ``__init__`` included.
    """

    @abstractmethod
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
    ) -> None:
        """Declare the constructor signature shared by all executors."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Executes one model step on the given sequences."""
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA described by ``lora_request``."""
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA identified by ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> List[int]:
        """Return LoRA ids as a list of ints."""
        raise NotImplementedError

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError
class ExecutorAsyncBase(ExecutorBase):
    """Abstract async counterpart of ``ExecutorBase``.

    Adds coroutine versions of the model step and the health check; both
    are abstract and raise ``NotImplementedError`` here.
    """

    @abstractmethod
    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Executes one model step on the given sequences."""
        raise NotImplementedError

    @abstractmethod
    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError
from typing import Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class GPUExecutor(ExecutorBase):
    """Executor that drives one in-process ``Worker`` (world size 1)."""

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
    ) -> None:
        """Keep the configs, then create the worker and set up its cache."""
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config

        # Instantiate the worker and load the model to GPU.
        self._init_worker()

        # Profile the memory usage and initialize the cache.
        self._init_cache()

    def _init_worker(self):
        """Create the driver worker, initialize its device, load the model."""
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        from vllm.worker.worker import Worker

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        host_ip = get_ip()
        open_port = get_open_port()
        distributed_init_method = get_distributed_init_method(
            host_ip, open_port)

        worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )
        self.driver_worker = worker
        worker.init_device()
        worker.load_model()

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        The engine first profiles the existing memory usage.
        Then, it allocates the remaining memory for KV blocks.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        cache_config = self.cache_config

        # Get the maximum number of blocks that can be allocated on GPU and
        # CPU.
        profiled = self.driver_worker.profile_num_available_blocks(
            block_size=cache_config.block_size,
            gpu_memory_utilization=cache_config.gpu_memory_utilization,
            cpu_swap_space=cache_config.swap_space_bytes,
            cache_dtype=cache_config.cache_dtype,
        )
        num_gpu_blocks, num_cpu_blocks = profiled

        # An explicitly forced GPU block count overrides the profiled one.
        if cache_config.forced_num_gpu_blocks is not None:
            forced_num_gpu_blocks = cache_config.forced_num_gpu_blocks
            logger.info(f"Replacing profiled {num_gpu_blocks=} with "
                        f"{forced_num_gpu_blocks=}")
            num_gpu_blocks = forced_num_gpu_blocks

        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        check_block_size_valid(num_gpu_blocks, cache_config.block_size,
                               self.model_config.max_model_len)

        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self.driver_worker.init_cache_engine(cache_config=cache_config)
        # Warm up the model. This includes capturing the model into CUDA graph
        # if enforce_eager is False.
        self.driver_worker.warm_up_model()

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Forward one model step to the driver worker, return its output."""
        return self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward the LoRA request to the driver worker (id must be > 0)."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward the LoRA removal to the driver worker (id must be > 0)."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> List[int]:
        """Return whatever the driver worker's ``list_loras`` reports."""
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        """No-op health check; this executor never reports unhealthy."""
        # GPUExecutor will always be healthy as long as
        # it's running.
        return None
class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """``GPUExecutor`` with coroutine versions of the step and health check."""

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Await the driver worker's ``execute_model`` via ``make_async``."""
        run_step = make_async(self.driver_worker.execute_model)
        return await run_step(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

    async def check_health_async(self) -> None:
        """No-op health check; this executor never reports unhealthy."""
        # GPUExecutor will always be healthy as long as
        # it's running.
        return None
from typing import Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase):
    """Executor that drives one ``NeuronWorker``; LoRA is rejected."""

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
    ) -> None:
        """Keep the configs, size the cache config, and create the worker.

        ``lora_config`` must be ``None``. ``vision_language_config`` is
        accepted for signature compatibility and is not stored.
        """
        assert lora_config is None, "LoRA is not supported for Neuron backend."
        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config

        # Set the number of GPU blocks to be the same as the maximum number of
        # sequences that can be processed in a single batch. This is equivalent
        # to schedule without PagedAttention.
        max_num_seqs = self.scheduler_config.max_num_seqs
        self.cache_config.num_gpu_blocks = max_num_seqs
        self.cache_config.num_cpu_blocks = 0

        # Instantiate the worker and load the model to the device.
        self._init_worker()

    def _init_worker(self):
        """Create the driver worker, initialize its device, load the model."""
        from vllm.worker.neuron_worker import NeuronWorker

        worker = NeuronWorker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
        )
        self.driver_worker = worker
        worker.init_device()
        worker.load_model()

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Run one model step; all three block mappings must equal ``{}``."""
        block_ops = (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
        assert all(op == {} for op in block_ops), (
            "Cache operations are not supported for Neuron backend.")

        return self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is not implemented for neuron backend.")

    def remove_lora(self, lora_id: int) -> bool:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is not implemented for neuron backend.")

    def list_loras(self) -> List[int]:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "LoRA is not implemented for neuron backend.")

    def check_health(self) -> None:
        """No-op health check; this executor never reports unhealthy."""
        # NeuronExecutor will always be healthy as long as
        # it's running.
        return None
import asyncio
import copy
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async, set_cuda_visible_devices)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
# If the env var is set to a truthy value, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# NOTE: os.getenv returns a string and bool() of any non-empty string is
# True, so "0" would switch the feature on. The value is therefore compared
# against the explicit "off" spellings; any other value still enables it.
USE_RAY_COMPILED_DAG = os.getenv(
    "VLLM_USE_RAY_COMPILED_DAG",
    "0").strip().lower() not in ("", "0", "false", "no", "off")
class RayGPUExecutor(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self._init_cache()
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
def _init_workers_ray(self, placement_group: "PlacementGroup",
                      **ray_remote_kwargs):
    """Create the Ray worker actors and the in-process driver worker.

    One ``RayWorkerVllm`` actor is created for every bundle of
    ``placement_group`` that requests a GPU. The first actor found on the
    driver's node is kept only as a resource holder
    (``self.driver_dummy_worker``); every other actor is appended to
    ``self.workers``. The driver runs a ``Worker`` in this process with
    rank 0, and the remote actors get ranks 1..N in ``self.workers`` order.
    Finally ``init_device`` and ``load_model`` are run on all workers.

    Args:
        placement_group: Ray placement group whose GPU bundles host the
            workers.
        **ray_remote_kwargs: Extra keyword arguments forwarded to
            ``ray.remote`` when each actor is created.

    Raises:
        ValueError: If no GPU bundle was placed on the driver's node.
    """
    if self.parallel_config.tensor_parallel_size == 1:
        # For single GPU case, we use a ray worker with constrained memory.
        num_gpus = self.cache_config.gpu_memory_utilization
    else:
        # Otherwise, the ray workers are allocated with a full GPU.
        num_gpus = 1

    # The driver dummy worker does not actually use any resources.
    # It holds the resource for the driver worker.
    self.driver_dummy_worker: RayWorkerVllm = None
    # The remaining workers are the actual ray actors.
    self.workers: List[RayWorkerVllm] = []

    # Create the workers.
    driver_ip = get_ip()
    for bundle_id, bundle in enumerate(placement_group.bundle_specs):
        # Bundles without a GPU cannot host a worker.
        if not bundle.get("GPU", 0):
            continue
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )
        worker = ray.remote(
            num_cpus=0,
            num_gpus=num_gpus,
            scheduling_strategy=scheduling_strategy,
            **ray_remote_kwargs,
        )(RayWorkerVllm).remote(self.model_config.trust_remote_code)

        worker_ip = ray.get(worker.get_node_ip.remote())
        if worker_ip == driver_ip and self.driver_dummy_worker is None:
            # If the worker is on the same node as the driver, we use it
            # as the resource holder for the driver process.
            self.driver_dummy_worker = worker
        else:
            # Else, added to the list of workers.
            self.workers.append(worker)

    if self.driver_dummy_worker is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node. Consider "
            "adjusting the Ray placement group or running the driver on a "
            "GPU node.")

    # Get the set of GPU IDs used on each node.
    driver_node_id, driver_gpu_ids = ray.get(
        self.driver_dummy_worker.get_node_and_gpu_ids.remote())
    worker_node_and_gpu_ids = ray.get(
        [worker.get_node_and_gpu_ids.remote() for worker in self.workers])

    # node id -> global ranks placed on that node / GPU ids used on it.
    node_workers = defaultdict(list)
    node_gpus = defaultdict(list)

    # The driver is global rank 0; remote workers are numbered from 1.
    node_workers[driver_node_id].append(0)
    node_gpus[driver_node_id].extend(driver_gpu_ids)
    for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
                                           start=1):
        node_workers[node_id].append(i)
        node_gpus[node_id].extend(gpu_ids)
    for node_id, gpu_ids in node_gpus.items():
        node_gpus[node_id] = sorted(gpu_ids)

    # Set CUDA_VISIBLE_DEVICES for the driver and workers.
    set_cuda_visible_devices(node_gpus[driver_node_id])
    for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
        worker.set_cuda_visible_devices.remote(node_gpus[node_id])

    distributed_init_method = get_distributed_init_method(
        driver_ip, get_open_port())

    # Lazy import the Worker to avoid importing torch.cuda/xformers
    # before CUDA_VISIBLE_DEVICES is set in the Worker
    from vllm.worker.worker import Worker

    # Remote workers get private copies of the configs captured by the
    # factory lambda below.
    model_config = copy.deepcopy(self.model_config)
    parallel_config = copy.deepcopy(self.parallel_config)
    scheduler_config = copy.deepcopy(self.scheduler_config)
    device_config = copy.deepcopy(self.device_config)
    lora_config = copy.deepcopy(self.lora_config)
    vision_language_config = copy.deepcopy(self.vision_language_config)
    kv_cache_dtype = self.cache_config.cache_dtype

    # Initialize the actual workers with the Worker class.
    for rank, (worker, (node_id, _)) in enumerate(
            zip(self.workers, worker_node_and_gpu_ids),
            start=1,
    ):
        # Local rank is the position of this global rank on its node.
        local_rank = node_workers[node_id].index(rank)
        # rank/local_rank are bound as lambda defaults so each factory
        # keeps its own values instead of the loop's last ones.
        # Every rank must build the model from the same configuration, so
        # the vision-language config is shipped to the remote workers
        # exactly as it is given to the driver worker.
        worker.init_worker.remote(
            lambda rank=rank, local_rank=local_rank: Worker(
                model_config,
                parallel_config,
                scheduler_config,
                device_config,
                local_rank,
                rank,
                distributed_init_method,
                lora_config=lora_config,
                vision_language_config=vision_language_config,
                kv_cache_dtype=kv_cache_dtype,
            ))

    # Initialize the driver worker with the Worker class.
    driver_rank = 0
    driver_local_rank = node_workers[driver_node_id].index(driver_rank)
    self.driver_worker = Worker(
        self.model_config,
        self.parallel_config,
        self.scheduler_config,
        self.device_config,
        driver_local_rank,
        driver_rank,
        distributed_init_method,
        lora_config=self.lora_config,
        vision_language_config=self.vision_language_config,
        kv_cache_dtype=kv_cache_dtype,
        is_driver_worker=True,
    )

    self._run_workers("init_device")
    self._run_workers(
        "load_model",
        max_concurrent_workers=self.parallel_config.
        max_parallel_loading_workers,
    )
def _init_cache(self) -> None:
    """Size the KV cache, allocate it on every worker and warm up the model.

    Every worker is asked (``profile_num_available_blocks``) how many GPU
    and CPU cache blocks it can hold. Because all workers are driven by one
    shared controller, the smallest answer across workers is used so the
    same block operations are valid everywhere. A configured
    ``forced_num_gpu_blocks`` overrides the profiled GPU count. The final
    counts are validated, written back to ``self.cache_config``, and then
    used to initialize the cache engine on each worker before warm-up.

    .. tip::
        You may limit the usage of GPU memory
        by adjusting the `gpu_memory_utilization` parameter.
    """
    cache_config = self.cache_config

    # One (gpu_blocks, cpu_blocks) entry per worker.
    per_worker_blocks = self._run_workers(
        "profile_num_available_blocks",
        block_size=cache_config.block_size,
        gpu_memory_utilization=cache_config.gpu_memory_utilization,
        cpu_swap_space=cache_config.swap_space_bytes,
        cache_dtype=cache_config.cache_dtype,
    )

    # The most constrained worker decides the size for everyone.
    num_gpu_blocks = min(blocks[0] for blocks in per_worker_blocks)
    num_cpu_blocks = min(blocks[1] for blocks in per_worker_blocks)

    forced_num_gpu_blocks = cache_config.forced_num_gpu_blocks
    if forced_num_gpu_blocks is not None:
        logger.info(f"Replacing profiled {num_gpu_blocks=} with "
                    f"{forced_num_gpu_blocks=}")
        num_gpu_blocks = forced_num_gpu_blocks

    logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                f"# CPU blocks: {num_cpu_blocks}")

    # Fails fast if the cache cannot hold even one max-length sequence.
    check_block_size_valid(num_gpu_blocks, cache_config.block_size,
                           self.model_config.max_model_len)

    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = num_cpu_blocks

    # Initialize the cache.
    self._run_workers("init_cache_engine", cache_config=cache_config)
    # Warm up the model. This includes capturing the model into CUDA graph
    # if enforce_eager is False.
    self._run_workers("warm_up_model")
def execute_model(self,
                  seq_group_metadata_list: List[SequenceGroupMetadata],
                  blocks_to_swap_in: Dict[int, int],
                  blocks_to_swap_out: Dict[int, int],
                  blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
    """Run one ``execute_model`` step on all workers.

    The scheduling inputs are passed to the driver worker only (as
    ``driver_kwargs``); the remote workers are invoked without arguments.
    Whether the call goes through the compiled DAG is decided by
    ``USE_RAY_COMPILED_DAG``.

    Returns:
        The driver worker's output, i.e. the first entry of the per-worker
        result list.
    """
    driver_inputs = dict(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )
    per_worker_outputs = self._run_workers(
        "execute_model",
        driver_kwargs=driver_inputs,
        use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

    # Only the driver worker returns the sampling results.
    return per_worker_outputs[0]
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Invoke ``add_lora`` with ``lora_request`` on every worker.

    Raises:
        AssertionError: If ``lora_request.lora_int_id`` is not positive.
    """
    assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
    # NOTE(review): _run_workers returns a list with one entry per worker,
    # so this hands back that list rather than the single bool the
    # annotation promises -- confirm what callers rely on.
    per_worker_results = self._run_workers("add_lora",
                                           lora_request=lora_request)
    return per_worker_results
def remove_lora(self, lora_id: int) -> bool:
    """Invoke ``remove_lora`` for ``lora_id`` on every worker.

    Raises:
        AssertionError: If ``lora_id`` is not positive.
    """
    assert lora_id > 0, "lora_id must be greater than 0."
    # NOTE(review): _run_workers returns a list with one entry per worker,
    # so this hands back that list rather than the single bool the
    # annotation promises -- confirm what callers rely on.
    per_worker_results = self._run_workers("remove_lora", lora_id=lora_id)
    return per_worker_results
def list_loras(self) -> List[int]:
    """Invoke ``list_loras`` on every worker and return the collected results.

    NOTE(review): the value comes straight from ``_run_workers``, which
    returns one entry per worker, so the result is presumably a list of
    per-worker answers rather than a flat list of ids -- confirm against
    callers.
    """
    per_worker_results = self._run_workers("list_loras")
    return per_worker_results
def _run_workers(
    self,
    method: str,
    *args,
    driver_args: Optional[List[Any]] = None,
    driver_kwargs: Optional[Dict[str, Any]] = None,
    max_concurrent_workers: Optional[int] = None,
    use_ray_compiled_dag: bool = False,
    **kwargs,
) -> Any:
    """Runs the given method on all workers.

    The remote Ray workers are started first so they run concurrently with
    the driver worker, which executes ``method`` synchronously in this
    process.

    Args:
        method: Name of the worker method to call.
        *args: Positional arguments for the remote workers (and for the
            driver when ``driver_args`` is None).
        driver_args: Positional arguments for the driver worker only.
        driver_kwargs: Keyword arguments for the driver worker only.
        max_concurrent_workers: Not supported; any truthy value raises.
        use_ray_compiled_dag: Trigger the remote workers through
            ``self.forward_dag`` instead of ``execute_method``. In that
            mode ``method``/``args``/``kwargs`` are not sent to the remote
            workers.
        **kwargs: Keyword arguments for the remote workers (and for the
            driver when ``driver_kwargs`` is None).

    Returns:
        A list whose first element is the driver worker's result, followed
        by one result per remote worker.

    Raises:
        NotImplementedError: If ``max_concurrent_workers`` is truthy.
    """
    if max_concurrent_workers:
        raise NotImplementedError(
            "max_concurrent_workers is not supported yet.")

    # Start from "no remote results" so the final concatenation is always
    # defined, including the compiled-DAG path with no remote workers.
    ray_worker_outputs: List[Any] = []
    output_channels = None

    if use_ray_compiled_dag:
        # Right now, compiled DAG can only accept a single
        # input. TODO(sang): Fix it.
        output_channels = self.forward_dag.execute(1)
    else:
        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]

    if driver_args is None:
        driver_args = args
    if driver_kwargs is None:
        driver_kwargs = kwargs

    # Start the driver worker after all the ray workers.
    driver_worker_output = getattr(self.driver_worker,
                                   method)(*driver_args, **driver_kwargs)

    # Get the results of the ray workers.
    if self.workers:
        if use_ray_compiled_dag:
            try:
                # The payloads are produced by this executor's own workers.
                ray_worker_outputs = [
                    pickle.loads(chan.begin_read())
                    for chan in output_channels
                ]
            finally:
                # Has to call end_read in order to reuse the DAG.
                for chan in output_channels:
                    chan.end_read()
        else:
            ray_worker_outputs = ray.get(ray_worker_outputs)

    return [driver_worker_output] + ray_worker_outputs
def _compiled_ray_dag(self):
    """Build and compile the Ray DAG that triggers every remote worker.

    The DAG fans a single input out to
    ``execute_model_compiled_dag_remote`` on each worker in
    ``self.workers`` and gathers their outputs.

    Returns:
        The result of ``experimental_compile()`` on the multi-output DAG.

    Raises:
        ValueError: If the installed Ray is older than the required version.
    """
    import pkg_resources
    required_version = "2.9"
    current_version = pkg_resources.get_distribution("ray").version
    # Compare parsed versions, not raw strings: as text "2.10" sorts before
    # "2.9", which would wrongly reject newer Ray releases.
    if (pkg_resources.parse_version(current_version) <
            pkg_resources.parse_version(required_version)):
        raise ValueError(f"Ray version {required_version} or greater is "
                         f"required, but found {current_version}")

    from ray.dag import InputNode, MultiOutputNode
    assert self.parallel_config.worker_use_ray

    # Right now, compiled DAG requires at least 1 arg. We send
    # a dummy value for now. It will be fixed soon.
    with InputNode() as input_data:
        forward_dag = MultiOutputNode([
            worker.execute_model_compiled_dag_remote.bind(input_data)
            for worker in self.workers
        ])
    return forward_dag.experimental_compile()
def check_health(self) -> None:
    """Verify the engine is healthy; raise if it is not.

    Health is currently defined by ``_check_if_any_actor_is_dead``, whose
    exception (if any) propagates to the caller.
    """
    self._check_if_any_actor_is_dead()
def _check_if_any_actor_is_dead(self):
    """Raise ``RuntimeError`` if any remote worker actor is reported DEAD.

    Returns without querying Ray when there are no remote workers.
    """
    if not self.workers:
        return

    # Look up each actor's state by its id and keep the dead ones.
    dead_workers = [
        actor for actor in self.workers
        if ray.state.actors(actor._ray_actor_id.hex())["State"] == "DEAD"  # pylint: disable=protected-access
    ]
    if dead_workers:
        raise RuntimeError("At least one Worker is dead. "
                           f"Dead Workers: {dead_workers}. ")
class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
async def _run_workers_async(
    self,
    method: str,
    *args,
    driver_args: Optional[List[Any]] = None,
    driver_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Any:
    """Runs the given method on all workers and awaits every result.

    Args:
        method: Name of the worker method to call.
        *args: Positional arguments for the remote workers (and for the
            driver when ``driver_args`` is None).
        driver_args: Positional arguments for the driver worker only.
        driver_kwargs: Keyword arguments for the driver worker only.
        **kwargs: Keyword arguments for the remote workers (and for the
            driver when ``driver_kwargs`` is None).

    Returns:
        The gathered results, driver worker first, then one per remote
        worker in ``self.workers`` order.
    """
    if driver_args is None:
        driver_args = args
    if driver_kwargs is None:
        driver_kwargs = kwargs

    # Run the driver worker asynchronously; make_async wraps the bound
    # method into an awaitable call.
    run_on_driver = make_async(getattr(self.driver_worker, method))
    pending = [run_on_driver(*driver_args, **driver_kwargs)]

    # Run the ray workers asynchronously.
    pending.extend(
        worker.execute_method.remote(method, *args, **kwargs)
        for worker in self.workers)

    return await asyncio.gather(*pending)
async def execute_model_async(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Dict[int, int],
    blocks_to_swap_out: Dict[int, int],
    blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
    """Run one ``execute_model`` step on all workers without blocking.

    The scheduling inputs are passed to the driver worker only (as
    ``driver_kwargs``); the remote workers are invoked without arguments.

    Returns:
        The driver worker's output, i.e. the first gathered result.
    """
    driver_inputs = dict(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )
    per_worker_outputs = await self._run_workers_async(
        "execute_model", driver_kwargs=driver_inputs)

    # Only the driver worker returns the sampling results.
    return per_worker_outputs[0]
async def check_health_async(self) -> None:
    """Verify the engine is healthy; raise if it is not.

    Coroutine counterpart of the synchronous health check: it calls
    ``_check_if_any_actor_is_dead`` directly and lets its exception (if
    any) propagate.
    """
    self._check_if_any_actor_is_dead()
def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
    """Validate that the GPU KV cache can hold one full-length sequence.

    Args:
        num_gpu_blocks: Number of GPU cache blocks available.
        block_size: Number of tokens stored per block.
        max_model_len: Maximum sequence length the model may be asked for.

    Raises:
        ValueError: If there are no GPU blocks at all, or if
            ``block_size * num_gpu_blocks`` is smaller than
            ``max_model_len``.
    """
    if num_gpu_blocks <= 0:
        no_memory_message = ("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")
        raise ValueError(no_memory_message)

    # Total number of tokens the cache can store across all blocks.
    kv_cache_capacity = block_size * num_gpu_blocks
    if max_model_len > kv_cache_capacity:
        too_long_message = (
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({kv_cache_capacity}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")
        raise ValueError(too_long_message)
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM.""" """Logging configuration for vLLM."""
import logging import logging
import sys
import os import os
import sys
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
......
# pylint: disable=unused-argument # pylint: disable=unused-argument
import inspect
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,20 +11,20 @@ from transformers import PretrainedConfig ...@@ -10,20 +11,20 @@ from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather,
)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
...@@ -84,7 +85,8 @@ def _apply_lora_packed_nslice( ...@@ -84,7 +85,8 @@ def _apply_lora_packed_nslice(
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size) indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size) output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...), where n is number of slices output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
""" """
org_output = output org_output = output
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
...@@ -113,8 +115,11 @@ class LoRAMapping: ...@@ -113,8 +115,11 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module): class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, def create_lora_weights(
model_config: PretrainedConfig) -> None: self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
"""Initializes lora matrices.""" """Initializes lora matrices."""
... ...
...@@ -143,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module): ...@@ -143,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module):
"""Sets the mapping indices.""" """Sets the mapping indices."""
... ...
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...@@ -277,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -277,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.indices[:self.indices_len[0]], 0, 1.0) self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org) return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is VocabParallelEmbedding
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None: def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
def create_lora_weights( def create_lora_weights(
self, self,
...@@ -308,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -308,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1] self.output_dim = self.lora_b_stacked.shape[2]
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0 self.lora_a_stacked[index] = 0
...@@ -322,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -322,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor], embeddings_tensor: Optional[torch.Tensor],
): ):
self.reset_lora(index) self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
self.lora_a_stacked[index, self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True) lora_a.T, non_blocking=True)
...@@ -382,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -382,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def linear_weights(self): def linear_weights(self):
return self.base_layer.linear_weights return self.base_layer.linear_weights
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 1)
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices) """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
...@@ -484,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ...@@ -484,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) )
return output return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is MergedColumnParallelLinear and len(
packed_modules_list) == 2
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
self.tp_size = get_tensor_model_parallel_world_size()
self.q_proj_total_size = (self.base_layer.total_num_heads *
self.base_layer.head_size)
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size)
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices) """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj). (q_proj + k_proj + v_proj -> qkv_proj).
...@@ -653,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): ...@@ -653,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
) )
return output return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 3
class RowParallelLinearWithLoRA(BaseLayerWithLoRA): class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...@@ -779,12 +890,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): ...@@ -779,12 +890,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def weight(self): def weight(self):
return self.base_layer.weight return self.base_layer.weight
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is RowParallelLinear
class SamplerWithLoRA(BaseLayerWithLoRA):
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def __init__( def __init__(
self, self,
base_layer: Sampler, base_layer: LogitsProcessor,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
...@@ -796,13 +913,17 @@ class SamplerWithLoRA(BaseLayerWithLoRA): ...@@ -796,13 +913,17 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
self.device = device self.device = device
@property @property
def logits_as_hidden_states(self): def logits_as_input(self):
return self.base_layer.logits_as_hidden_states return self.base_layer.logits_as_input
@property @property
def vocab_size(self): def vocab_size(self):
return self.base_layer.vocab_size return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property @property
def org_vocab_size(self): def org_vocab_size(self):
return self.base_layer.org_vocab_size return self.base_layer.org_vocab_size
...@@ -819,9 +940,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA): ...@@ -819,9 +940,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
) -> None: ) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024: if 32000 < self.base_layer.vocab_size > 33024:
raise ValueError( raise ValueError("When using LoRA, vocab size must be "
"When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" "32000 >= vocab_size <= 33024")
)
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
...@@ -896,7 +1016,7 @@ class SamplerWithLoRA(BaseLayerWithLoRA): ...@@ -896,7 +1016,7 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
embedding: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None: if embedding_bias is not None:
...@@ -945,35 +1065,43 @@ class SamplerWithLoRA(BaseLayerWithLoRA): ...@@ -945,35 +1065,43 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs) return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def from_layer( def can_replace_layer(cls, source_layer: nn.Module,
layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List,
max_loras: int, model_config: Optional[PretrainedConfig]) -> bool:
lora_config: LoRAConfig, # Special handling for the LogitsProcessor.
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: return False
supported_layer_types = {
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA, _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
QKVParallelLinear: QKVParallelLinearWithLora, cls
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, for cls in globals().values() if inspect.isclass(cls)
RowParallelLinear: RowParallelLinearWithLoRA, and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
} }
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer) def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config) ret.create_lora_weights(max_loras, lora_config, model_config)
return ret return ret
return layer return layer
def from_layer_sampler( def from_layer_logits_processor(
layer: Sampler, layer: LogitsProcessor,
lm_head: ParallelLMHead, lm_head: ParallelLMHead,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None, model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA: ) -> LogitsProcessorWithLoRA:
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.device) lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config) ret.create_lora_weights(max_loras, lora_config, model_config)
return ret return ret
from typing import List, Optional from typing import List, Optional
import torch import torch
from vllm.utils import in_wsl
from vllm.utils import is_pin_memory_available
class LoRALayerWeights: class LoRALayerWeights:
...@@ -64,7 +65,7 @@ class LoRALayerWeights: ...@@ -64,7 +65,7 @@ class LoRALayerWeights:
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and not in_wsl() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank], lora_a = torch.zeros([input_dim, rank],
dtype=dtype, dtype=dtype,
device=device, device=device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment