Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
import inspect import inspect
from typing import Dict, List, Union from typing import Dict, List, Optional, Union
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo from vllm.utils import supports_dynamo
logger = init_logger(__name__)
def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers. integer or a list of integers.
Depending on the value of arguments: if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
of the `forward` method, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors`, the first
dimension of all the tensors in the intermediate tensors
will be marked as dynamic.
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument - if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic. will be marked as dynamic.
...@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]): ...@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
if not hasattr(cls, 'forward'): if not hasattr(cls, 'forward'):
raise TypeError("decorated class should have a forward method.") raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward) sig = inspect.signature(cls.forward)
for k in dynamic_arg_dims: inferred_dynamic_arg_dims = dynamic_arg_dims
if inferred_dynamic_arg_dims is None:
inferred_dynamic_arg_dims = {}
for k, v in sig.parameters.items():
if v.annotation in [
torch.Tensor, Optional[torch.Tensor],
IntermediateTensors, Optional[IntermediateTensors]
]:
inferred_dynamic_arg_dims[k] = 0
logger.debug(("Inferred dynamic dimensions for "
"forward method of %s: %s"), cls,
list(inferred_dynamic_arg_dims.keys()))
if len(inferred_dynamic_arg_dims) == 0:
raise ValueError(
"No dynamic dimensions found in the forward method of "
f"{cls}. Please provide dynamic_arg_dims explicitly.")
for k in inferred_dynamic_arg_dims:
if k not in sig.parameters: if k not in sig.parameters:
raise ValueError( raise ValueError(
f"Argument {k} not found in the forward method of {cls}") f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, dynamic_arg_dims) return _support_torch_compile(cls, inferred_dynamic_arg_dims)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
assert isinstance(cls, type)
return cls_decorator_helper(cls)
return cls_decorator_helper return cls_decorator_helper
......
import enum import enum
import json import json
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Optional, Tuple, Type, Union) Mapping, Optional, Set, Tuple, Type, Union)
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, ...@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu, is_hip, is_openvino, is_xpu, print_warning_once)
print_warning_once)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -33,6 +32,11 @@ logger = init_logger(__name__) ...@@ -33,6 +32,11 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
class ModelConfig: class ModelConfig:
"""Configuration for the model. """Configuration for the model.
...@@ -40,7 +44,11 @@ class ModelConfig: ...@@ -40,7 +44,11 @@ class ModelConfig:
Args: Args:
model: Name or path of the huggingface model to use. model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified. output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use. tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and available, "slow" will always use the slow tokenizer, and
...@@ -108,6 +116,7 @@ class ModelConfig: ...@@ -108,6 +116,7 @@ class ModelConfig:
def __init__(self, def __init__(self,
model: str, model: str,
task: Union[TaskOption, _Task],
tokenizer: str, tokenizer: str,
tokenizer_mode: str, tokenizer_mode: str,
trust_remote_code: bool, trust_remote_code: bool,
...@@ -205,9 +214,15 @@ class ModelConfig: ...@@ -205,9 +214,15 @@ class ModelConfig:
self.is_attention_free = self._init_attention_free() self.is_attention_free = self._init_attention_free()
self.has_inner_state = self._init_has_inner_state() self.has_inner_state = self._init_has_inner_state()
self.override_neuron_config = override_neuron_config if is_neuron( if current_platform.is_neuron():
) else None self.override_neuron_config = override_neuron_config
self._verify_embedding_mode() else:
self.override_neuron_config = None
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
self._verify_bnb_config() self._verify_bnb_config()
...@@ -241,18 +256,44 @@ class ModelConfig: ...@@ -241,18 +256,44 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'.") "either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None: def _resolve_task(
architectures = getattr(self.hf_config, "architectures", []) self,
task_option: Union[TaskOption, _Task],
hf_config: PretrainedConfig,
) -> Tuple[Set[_Task], _Task]:
if task_option == "draft":
return {"draft"}, "draft"
architectures = getattr(hf_config, "architectures", [])
task_support: Dict[_Task, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
}
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported
]
supported_tasks = set(supported_tasks_lst)
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
# TODO: Allow the same model architecture to be specified as either if len(supported_tasks) > 1:
# generation or embedding model logger.info(
if "Phi3VForCausalLM" in architectures: "This model supports multiple tasks: %s. "
# Match both remote and local names "Defaulting to '%s'.", supported_tasks, selected_task)
embedding_mode = "/VLM2Vec" in self.model
else: else:
embedding_mode = ModelRegistry.is_embedding_model(architectures) if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
self.embedding_mode = embedding_mode selected_task = task_option
return supported_tasks, selected_task
def _parse_quant_hf_config(self): def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None) quant_cfg = getattr(self.hf_config, "quantization_config", None)
...@@ -337,7 +378,7 @@ class ModelConfig: ...@@ -337,7 +378,7 @@ class ModelConfig:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.") " is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True envs.VLLM_USE_TRITON_AWQ = True
if is_neuron( if current_platform.is_neuron(
) and self.quantization not in neuron_supported_quantization: ) and self.quantization not in neuron_supported_quantization:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
...@@ -410,7 +451,7 @@ class ModelConfig: ...@@ -410,7 +451,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode # Async postprocessor is not necessary with embedding mode
# since there is no token generation # since there is no token generation
if self.embedding_mode: if self.task == "embedding":
self.use_async_output_proc = False self.use_async_output_proc = False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst # Reminder: Please update docs/source/serving/compatibility_matrix.rst
...@@ -591,11 +632,6 @@ class ModelConfig: ...@@ -591,11 +632,6 @@ class ModelConfig:
(hasattr(self.hf_config, "text_config") and getattr( (hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False))) self.hf_config.text_config, "is_encoder_decoder", False)))
@property
def is_embedding_model(self) -> bool:
"""Extract the embedding model flag."""
return self.embedding_mode
@property @property
def is_multimodal_model(self) -> bool: def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None return self.multimodal_config is not None
...@@ -952,6 +988,7 @@ class SchedulerConfig: ...@@ -952,6 +988,7 @@ class SchedulerConfig:
"""Scheduler configuration. """Scheduler configuration.
Args: Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration. a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single max_num_seqs: Maximum number of sequences to be processed in a single
...@@ -966,7 +1003,6 @@ class SchedulerConfig: ...@@ -966,7 +1003,6 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt. prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens. on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows: recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than We use recomputation by default since it incurs lower overhead than
...@@ -981,13 +1017,13 @@ class SchedulerConfig: ...@@ -981,13 +1017,13 @@ class SchedulerConfig:
""" """
def __init__(self, def __init__(self,
task: _Task,
max_num_batched_tokens: Optional[int], max_num_batched_tokens: Optional[int],
max_num_seqs: int, max_num_seqs: int,
max_model_len: int, max_model_len: int,
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
delay_factor: float = 0.0, delay_factor: float = 0.0,
enable_chunked_prefill: bool = False, enable_chunked_prefill: bool = False,
embedding_mode: bool = False,
is_multimodal_model: bool = False, is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None, preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1, num_scheduler_steps: int = 1,
...@@ -1011,7 +1047,7 @@ class SchedulerConfig: ...@@ -1011,7 +1047,7 @@ class SchedulerConfig:
# for higher throughput. # for higher throughput.
max_num_batched_tokens = max(max_model_len, 2048) max_num_batched_tokens = max(max_model_len, 2048)
if embedding_mode: if task == "embedding":
# For embedding, choose specific value for higher throughput # For embedding, choose specific value for higher throughput
max_num_batched_tokens = max( max_num_batched_tokens = max(
max_num_batched_tokens, max_num_batched_tokens,
...@@ -1031,12 +1067,12 @@ class SchedulerConfig: ...@@ -1031,12 +1067,12 @@ class SchedulerConfig:
"Chunked prefill is enabled with max_num_batched_tokens=%d.", "Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens) self.max_num_batched_tokens)
self.task: Final = task
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.num_lookahead_slots = num_lookahead_slots self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs self.multi_step_stream_outputs = multi_step_stream_outputs
...@@ -1086,7 +1122,7 @@ class DeviceConfig: ...@@ -1086,7 +1122,7 @@ class DeviceConfig:
# Automated device type detection # Automated device type detection
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
self.device_type = "cuda" self.device_type = "cuda"
elif is_neuron(): elif current_platform.is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_openvino(): elif is_openvino():
self.device_type = "openvino" self.device_type = "openvino"
...@@ -1248,6 +1284,7 @@ class SpeculativeConfig: ...@@ -1248,6 +1284,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min = 0 ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig( draft_model_config = ModelConfig(
model=speculative_model, model=speculative_model,
task="draft",
tokenizer=target_model_config.tokenizer, tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode, tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code, trust_remote_code=target_model_config.trust_remote_code,
...@@ -1381,11 +1418,11 @@ class SpeculativeConfig: ...@@ -1381,11 +1418,11 @@ class SpeculativeConfig:
else: else:
speculative_draft_tensor_parallel_size = \ speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1: elif speculative_draft_tensor_parallel_size not in (
# TODO(wooyeon): allow tp values larger than 1 1, target_parallel_config.tensor_parallel_size):
raise ValueError( raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be " f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1") f"other value than 1 or target model tensor_parallel_size")
draft_parallel_config = ParallelConfig( draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config. pipeline_parallel_size=target_parallel_config.
......
...@@ -7,7 +7,7 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, ...@@ -7,7 +7,7 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock, from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator) NaiveBlockAllocator)
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
PrefixHash = int PrefixHash = int
......
import enum
from abc import ABC, abstractmethod
from typing import OrderedDict
from vllm.block import PhysicalTokenBlock
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_hash: int) -> bool:
pass
@abstractmethod
def evict(self) -> PhysicalTokenBlock:
"""Runs the eviction algorithm and returns the evicted block"""
pass
@abstractmethod
def add(self, block: PhysicalTokenBlock):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def remove(self, block_hash: int) -> PhysicalTokenBlock:
"""Simply removes the block with the hash value block_hash from the
evictor. Caller is responsible for making sure that block_hash is
contained in the evictor before calling remove. Should be used to
"bring back" blocks that have been freed but not evicted yet.
"""
pass
@property
@abstractmethod
def num_blocks(self) -> int:
pass
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily
"""
def __init__(self):
self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()
def __contains__(self, block_hash: int) -> bool:
return block_hash in self.free_table
def evict(self) -> PhysicalTokenBlock:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _, block in self.free_table.items():
if evicted_block.last_accessed < block.last_accessed:
break
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block = block
self.free_table.pop(evicted_block.block_hash)
evicted_block.computed = False
return evicted_block
def add(self, block: PhysicalTokenBlock):
self.free_table[block.block_hash] = block
def remove(self, block_hash: int) -> PhysicalTokenBlock:
if block_hash not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
block: PhysicalTokenBlock = self.free_table[block_hash]
self.free_table.pop(block_hash)
return block
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
...@@ -313,7 +313,7 @@ class Scheduler: ...@@ -313,7 +313,7 @@ class Scheduler:
self.lora_config = lora_config self.lora_config = lora_config
version = "selfattn" version = "selfattn"
if (self.scheduler_config.embedding_mode if (self.scheduler_config.task == "embedding"
or self.cache_config.is_attention_free): or self.cache_config.is_attention_free):
version = "placeholder" version = "placeholder"
......
...@@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch. ...@@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch.
The typical workflow is: The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment. - call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups. initialize the model parallel groups.
- any code dealing with the distributed stuff - any code dealing with the distributed stuff
...@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline ...@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
steps. steps.
""" """
import contextlib import contextlib
import gc
import pickle import pickle
import weakref import weakref
from collections import namedtuple from collections import namedtuple
...@@ -1129,6 +1130,19 @@ def destroy_distributed_environment(): ...@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
if shutdown_ray:
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
if not current_platform.is_cpu():
torch.cuda.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
""" """
This is a collective operation that returns if each rank is in the same node This is a collective operation that returns if each rank is in the same node
......
...@@ -3,7 +3,7 @@ import dataclasses ...@@ -3,7 +3,7 @@ import dataclasses
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast) Tuple, Type, Union, cast, get_args)
import torch import torch
...@@ -12,10 +12,12 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, ...@@ -12,10 +12,12 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat, DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig, LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig) SpeculativeConfig, TaskOption, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -84,6 +86,7 @@ class EngineArgs: ...@@ -84,6 +86,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m' model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None tokenizer: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto' tokenizer_mode: str = 'auto'
trust_remote_code: bool = False trust_remote_code: bool = False
...@@ -198,6 +201,15 @@ class EngineArgs: ...@@ -198,6 +201,15 @@ class EngineArgs:
type=str, type=str,
default=EngineArgs.model, default=EngineArgs.model,
help='Name or path of the huggingface model to use.') help='Name or path of the huggingface model to use.')
parser.add_argument(
'--task',
default=EngineArgs.task,
choices=get_args(TaskOption),
help='The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=nullable_str, type=nullable_str,
...@@ -418,7 +430,11 @@ class EngineArgs: ...@@ -418,7 +430,11 @@ class EngineArgs:
help='The fraction of GPU memory to be used for the model ' help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of ' 'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, ' '0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9.') 'will use the default value of 0.9. This is a global gpu memory '
'utilization limit, for example if 50%% of the gpu memory is '
'already used before vLLM starts and --gpu-memory-utilization is '
'set to 0.9, then only 40%% of the gpu memory will be allocated '
'to the model executor.')
parser.add_argument( parser.add_argument(
'--num-gpu-blocks-override', '--num-gpu-blocks-override',
type=int, type=int,
...@@ -838,6 +854,7 @@ class EngineArgs: ...@@ -838,6 +854,7 @@ class EngineArgs:
def create_model_config(self) -> ModelConfig: def create_model_config(self) -> ModelConfig:
return ModelConfig( return ModelConfig(
model=self.model, model=self.model,
task=self.task,
# We know this is not None because we set it in __post_init__ # We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer), tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode, tokenizer_mode=self.tokenizer_mode,
...@@ -909,6 +926,8 @@ class EngineArgs: ...@@ -909,6 +926,8 @@ class EngineArgs:
"supported for multimodal models and has been disabled.") "supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False self.enable_prefix_caching = False
maybe_register_config_serialize_by_value(self.trust_remote_code)
cache_config = CacheConfig( cache_config = CacheConfig(
# neuron needs block_size = max_model_len # neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else
...@@ -1026,13 +1045,13 @@ class EngineArgs: ...@@ -1026,13 +1045,13 @@ class EngineArgs:
" please file an issue with detailed information.") " please file an issue with detailed information.")
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
task=model_config.task,
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs, max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len, max_model_len=model_config.max_model_len,
num_lookahead_slots=num_lookahead_slots, num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor, delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
is_multimodal_model=model_config.is_multimodal_model, is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode, preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps, num_scheduler_steps=self.num_scheduler_steps,
......
import time import time
from collections import Counter as collectionsCounter
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -43,8 +44,10 @@ from vllm.pooling_params import PoolingParams ...@@ -43,8 +44,10 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, ParallelSampleSequenceGroup, Sequence,
SequenceGroupOutput, SequenceStatus) SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
...@@ -344,7 +347,7 @@ class LLMEngine: ...@@ -344,7 +347,7 @@ class LLMEngine:
observability_config=self.observability_config, observability_config=self.observability_config,
) )
if not self.model_config.embedding_mode: if self.model_config.task != "embedding":
self._initialize_kv_caches() self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
...@@ -473,6 +476,8 @@ class LLMEngine: ...@@ -473,6 +476,8 @@ class LLMEngine:
), ),
)) ))
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -641,7 +646,10 @@ class LLMEngine: ...@@ -641,7 +646,10 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> SequenceGroup:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
self._validate_model_inputs(processed_inputs) self._validate_model_inputs(processed_inputs)
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
...@@ -695,6 +703,8 @@ class LLMEngine: ...@@ -695,6 +703,8 @@ class LLMEngine:
min_cost_scheduler = self.scheduler[costs.index(min(costs))] min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group) min_cost_scheduler.add_seq_group(seq_group)
return seq_group
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
...@@ -710,7 +720,7 @@ class LLMEngine: ...@@ -710,7 +720,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> Optional[SequenceGroup]:
... ...
@overload @overload
...@@ -724,7 +734,7 @@ class LLMEngine: ...@@ -724,7 +734,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> Optional[SequenceGroup]:
... ...
@deprecate_kwargs( @deprecate_kwargs(
...@@ -743,7 +753,7 @@ class LLMEngine: ...@@ -743,7 +753,7 @@ class LLMEngine:
priority: int = 0, priority: int = 0,
*, *,
inputs: Optional[PromptType] = None, # DEPRECATED inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> Optional[SequenceGroup]:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the The request is added to the request pool and will be processed by the
...@@ -787,6 +797,22 @@ class LLMEngine: ...@@ -787,6 +797,22 @@ class LLMEngine:
>>> # continue the request processing >>> # continue the request processing
>>> ... >>> ...
""" """
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
prompt=prompt,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
)
return None
if inputs is not None: if inputs is not None:
prompt = inputs prompt = inputs
assert prompt is not None and params is not None assert prompt is not None and params is not None
...@@ -817,7 +843,7 @@ class LLMEngine: ...@@ -817,7 +843,7 @@ class LLMEngine:
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs") "mm_processor_kwargs")
self._add_processed_request( return self._add_processed_request(
request_id=request_id, request_id=request_id,
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
params=params, params=params,
...@@ -1116,7 +1142,7 @@ class LLMEngine: ...@@ -1116,7 +1142,7 @@ class LLMEngine:
seq_group.metrics.model_execute_time = ( seq_group.metrics.model_execute_time = (
o.model_execute_time) o.model_execute_time)
if self.model_config.embedding_mode: if self.model_config.task == "embedding":
self._process_sequence_group_outputs(seq_group, output) self._process_sequence_group_outputs(seq_group, output)
else: else:
self.output_processor.process_prompt_logprob(seq_group, output) self.output_processor.process_prompt_logprob(seq_group, output)
...@@ -1134,7 +1160,9 @@ class LLMEngine: ...@@ -1134,7 +1160,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create( request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs) seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
...@@ -1174,7 +1202,9 @@ class LLMEngine: ...@@ -1174,7 +1202,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create( request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs) seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
...@@ -1193,7 +1223,10 @@ class LLMEngine: ...@@ -1193,7 +1223,10 @@ class LLMEngine:
continue continue
request_output = RequestOutputFactory.create( request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs) seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs,
)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
...@@ -1212,7 +1245,7 @@ class LLMEngine: ...@@ -1212,7 +1245,7 @@ class LLMEngine:
skip) skip)
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs, finished_before)
return None return None
...@@ -1617,6 +1650,25 @@ class LLMEngine: ...@@ -1617,6 +1650,25 @@ class LLMEngine:
n_requests: List[int] = [] n_requests: List[int] = []
finished_reason_requests: List[str] = [] finished_reason_requests: List[str] = []
# Lora requests
running_lora_adapters = dict(
collectionsCounter([
running_request.lora_request.lora_name
for scheduler in self.scheduler
for running_request in scheduler.running
if running_request.lora_request
]))
waiting_lora_adapters = dict(
collectionsCounter([
waiting_request.lora_request.lora_name
for scheduler in self.scheduler
for waiting_request in scheduler.waiting
if waiting_request.lora_request
]))
max_lora_stat = "0"
if self.lora_config:
max_lora_stat = str(self.lora_config.max_loras)
# NOTE: This loop assumes prefill seq_groups are before # NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups. # decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None: if scheduler_outputs is not None:
...@@ -1666,6 +1718,15 @@ class LLMEngine: ...@@ -1666,6 +1718,15 @@ class LLMEngine:
# TPOTs. # TPOTs.
latency = seq_group.get_last_latency(now) latency = seq_group.get_last_latency(now)
time_per_output_tokens_iter.append(latency) time_per_output_tokens_iter.append(latency)
if seq_group.state.current_step == 0:
# For async_output_proc, the do_log_stats()
# is called following init_multi_step(), which
# sets the current_step to zero.
actual_num_batched_tokens +=\
seq_group.state.num_steps - 1
else:
actual_num_batched_tokens +=\
seq_group.state.current_step - 1
# Because of chunked prefill, we can have a single sequence # Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging # group that does multiple prompt_runs. To prevent logging
...@@ -1738,7 +1799,9 @@ class LLMEngine: ...@@ -1738,7 +1799,9 @@ class LLMEngine:
num_generation_tokens_requests=num_generation_tokens_requests, num_generation_tokens_requests=num_generation_tokens_requests,
n_requests=n_requests, n_requests=n_requests,
finished_reason_requests=finished_reason_requests, finished_reason_requests=finished_reason_requests,
) max_lora=str(max_lora_stat),
waiting_lora_adapters=list(waiting_lora_adapters.keys()),
running_lora_adapters=list(running_lora_adapters.keys()))
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request) return self.model_executor.add_lora(lora_request)
...@@ -1786,11 +1849,18 @@ class LLMEngine: ...@@ -1786,11 +1849,18 @@ class LLMEngine:
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:
return self.tracer is not None return self.tracer is not None
def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None: def do_tracing(self,
scheduler_outputs: SchedulerOutputs,
finished_before: Optional[List[int]] = None) -> None:
if self.tracer is None: if self.tracer is None:
return return
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
# Skip double tracing when using async output proc
if finished_before and idx in finished_before:
continue
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished(): if seq_group.is_finished():
self.create_trace_span(seq_group) self.create_trace_span(seq_group)
...@@ -1855,9 +1925,6 @@ class LLMEngine: ...@@ -1855,9 +1925,6 @@ class LLMEngine:
def is_encoder_decoder_model(self): def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model() return self.input_preprocessor.is_encoder_decoder_model()
def is_embedding_model(self):
return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]): EncoderDecoderInputs]):
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
......
...@@ -34,7 +34,11 @@ class Metrics: ...@@ -34,7 +34,11 @@ class Metrics:
See https://prometheus.github.io/client_python/multiprocess/ for more See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations. details on limitations.
""" """
labelname_finish_reason = "finished_reason" labelname_finish_reason = "finished_reason"
labelname_waiting_lora_adapters = "waiting_lora_adapters"
labelname_running_lora_adapters = "running_lora_adapters"
labelname_max_lora = "max_lora"
_gauge_cls = prometheus_client.Gauge _gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter _counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram _histogram_cls = prometheus_client.Histogram
...@@ -55,6 +59,16 @@ class Metrics: ...@@ -55,6 +59,16 @@ class Metrics:
documentation="Number of requests waiting to be processed.", documentation="Number of requests waiting to be processed.",
labelnames=labelnames, labelnames=labelnames,
multiprocess_mode="sum") multiprocess_mode="sum")
self.gauge_lora_info = self._gauge_cls(
name="vllm:lora_requests_info",
documentation="Running stats on lora requests.",
labelnames=[
self.labelname_running_lora_adapters,
self.labelname_max_lora,
self.labelname_waiting_lora_adapters,
],
multiprocess_mode="livemostrecent",
)
self.gauge_scheduler_swapped = self._gauge_cls( self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped", name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.", documentation="Number of requests swapped to CPU.",
...@@ -426,6 +440,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -426,6 +440,9 @@ class PrometheusStatLogger(StatLoggerBase):
for datum in data: for datum in data:
histogram.labels(**self.labels).observe(datum) histogram.labels(**self.labels).observe(datum)
def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
gauge.labels(**data).set(1)
def _log_prometheus(self, stats: Stats) -> None: def _log_prometheus(self, stats: Stats) -> None:
# System state data # System state data
self._log_gauge(self.metrics.gauge_scheduler_running, self._log_gauge(self.metrics.gauge_scheduler_running,
...@@ -442,7 +459,17 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -442,7 +459,17 @@ class PrometheusStatLogger(StatLoggerBase):
stats.cpu_prefix_cache_hit_rate) stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate) stats.gpu_prefix_cache_hit_rate)
# Including max-lora in metric, in future this property of lora
# config maybe extended to be dynamic.
lora_info = {
self.metrics.labelname_running_lora_adapters:
",".join(stats.running_lora_adapters),
self.metrics.labelname_waiting_lora_adapters:
",".join(stats.waiting_lora_adapters),
self.metrics.labelname_max_lora:
stats.max_lora,
}
self._log_gauge_string(self.metrics.gauge_lora_info, lora_info)
# Iteration level data # Iteration level data
self._log_counter(self.metrics.counter_num_preemption, self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter) stats.num_preemption_iter)
......
...@@ -51,6 +51,9 @@ class Stats: ...@@ -51,6 +51,9 @@ class Stats:
num_generation_tokens_requests: List[int] num_generation_tokens_requests: List[int]
n_requests: List[int] n_requests: List[int]
finished_reason_requests: List[str] finished_reason_requests: List[str]
waiting_lora_adapters: List[str]
running_lora_adapters: List[str]
max_lora: str
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
......
...@@ -204,8 +204,20 @@ class MQLLMEngineClient(EngineClient): ...@@ -204,8 +204,20 @@ class MQLLMEngineClient(EngineClient):
# (and record only the first one) # (and record only the first one)
if is_engine_errored and not self._errored_with: if is_engine_errored and not self._errored_with:
self._errored_with = exception self._errored_with = exception
# If engine is errored, no matter the type of exception
# it will no longer be able to receive new requests,
# therefore we have to inform that the current
# processed requests failed as well. Send back a dead
# engine error give this feedback and also give a
# 'hint' to the server to shutdown next.
exception = self.dead_error
if request_id is None: if request_id is None:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, neither if it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception for all requests.
for queue_i in tuple(self.output_queues.values()): for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception) queue_i.put_nowait(exception)
else: else:
......
...@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union ...@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
import zmq import zmq
from vllm import AsyncEngineArgs, LLMEngine, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig] SchedulerConfig, LoRAConfig]
...@@ -136,14 +141,16 @@ class MQLLMEngine: ...@@ -136,14 +141,16 @@ class MQLLMEngine:
executor_class = LLMEngine._get_executor_cls(engine_config) executor_class = LLMEngine._get_executor_cls(engine_config)
return cls( use_async_sockets = (engine_config.model_config.use_async_output_proc
ipc_path=ipc_path, and not VLLM_USE_V1)
use_async_sockets=engine_config.model_config.use_async_output_proc,
**engine_config.to_dict(), return cls(ipc_path=ipc_path,
executor_class=executor_class, use_async_sockets=use_async_sockets,
log_requests=not engine_args.disable_log_requests, **engine_config.to_dict(),
log_stats=not engine_args.disable_log_stats, executor_class=executor_class,
usage_context=usage_context) log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
def start(self): def start(self):
try: try:
......
...@@ -59,7 +59,7 @@ class EngineClient(ABC): ...@@ -59,7 +59,7 @@ class EngineClient(ABC):
async def beam_search( async def beam_search(
self, self,
prompt: Union[PromptType, List[int]], prompt: Union[str, List[int]],
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
...@@ -71,9 +71,13 @@ class EngineClient(ABC): ...@@ -71,9 +71,13 @@ class EngineClient(ABC):
length_penalty = params.length_penalty length_penalty = params.length_penalty
tokenizer = await self.get_tokenizer(lora_request=None) tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance( if isinstance(prompt, str):
prompt, list) else tokenizer.encode(prompt) tokenized_prompt = tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt) prompt_text = prompt
else:
tokenized_prompt = prompt
prompt_text = None
tokenized_length = len(tokenized_prompt)
sort_beams_key = create_sort_beams_key_function( sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty) tokenizer.eos_token_id, length_penalty)
...@@ -81,7 +85,11 @@ class EngineClient(ABC): ...@@ -81,7 +85,11 @@ class EngineClient(ABC):
beam_search_params = SamplingParams(logprobs=2 * beam_width, beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1, max_tokens=1,
temperature=temperature) temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] all_beams = [
BeamSearchSequence(tokens=tokenized_prompt,
logprobs=[],
cum_logprob=0)
]
completed = [] completed = []
for _ in range(max_tokens): for _ in range(max_tokens):
...@@ -114,6 +122,7 @@ class EngineClient(ABC): ...@@ -114,6 +122,7 @@ class EngineClient(ABC):
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence( new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob + cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob) logprob_obj.logprob)
...@@ -131,22 +140,22 @@ class EngineClient(ABC): ...@@ -131,22 +140,22 @@ class EngineClient(ABC):
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
for beam in best_beams: for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
beam_search_output = RequestOutput( beam_search_output = RequestOutput(
request_id=request_id, request_id=request_id,
prompt=prompt, prompt=prompt_text,
outputs=[ outputs=[
CompletionOutput( CompletionOutput(
text=beam.text, text=beam.text,
cumulative_logprob=beam.cum_logprob, cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens, token_ids=beam.tokens[tokenized_length:],
index=i, index=i,
logprobs=beam.cum_logprob, logprobs=beam.logprobs,
) for (i, beam) in enumerate(best_beams) ) for (i, beam) in enumerate(best_beams)
], ],
finished=True, finished=True,
prompt_token_ids=tokenizedPrompt, prompt_token_ids=tokenized_prompt,
prompt_logprobs=None) prompt_logprobs=None)
yield beam_search_output yield beam_search_output
......
...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod ...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from functools import lru_cache, partial from functools import lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Mapping, Optional, Tuple, TypeVar, Union, cast) Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -33,6 +33,7 @@ from vllm.multimodal.utils import (async_get_and_parse_audio, ...@@ -33,6 +33,7 @@ from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image, async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image) get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -58,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ...@@ -58,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part.""" """The type of the content part."""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
ChatCompletionContentPartParam: TypeAlias = Union[ ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartRefusalParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam] CustomChatCompletionContentPartParam,
CustomChatCompletionContentSimpleImageParam,
CustomChatCompletionContentSimpleAudioParam, str]
class CustomChatCompletionMessageParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False):
...@@ -386,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) ...@@ -386,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
}
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
"""
Parses a given multi modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
Returns:
A tuple (part_type, content) where:
- part_type: Type of the part (e.g., 'text', 'image_url').
- content: Parsed content (e.g., text, image URL).
Raises:
ValueError: If the 'type' field is missing and no direct URL is found.
"""
assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
content = MM_PARSER_MAP[part_type](part)
# Special case for 'image_url.detail'
if part_type == "image_url" and part.get("detail") != "auto":
logger.warning("'image_url.detail' is currently not supported "
"and will be ignored.")
return part_type, content
# Handle missing 'type' but provided direct URL fields.
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
part)
return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")
if not isinstance(part_type, str):
raise ValueError("Invalid 'type' field in multimodal part.")
return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"audio_url")
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
role: str, role: str,
...@@ -401,29 +492,28 @@ def _parse_chat_message_content_parts( ...@@ -401,29 +492,28 @@ def _parse_chat_message_content_parts(
has_image = False has_image = False
for part in parts: for part in parts:
part_type = part["type"] if isinstance(part, str): # Handle plain text parts
if part_type == "text": text = _TextParser(part)
text = _TextParser(part)["text"]
texts.append(text) texts.append(text)
elif part_type == "image_url": else: # Handle structured dictionary parts
image_url = _ImageParser(part)["image_url"] part_type, content = _parse_chat_message_content_mm_part(part)
if image_url.get("detail", "auto") != "auto": # if part_type is text/refusal/image_url/audio_url but
logger.warning( # content is empty, logg a warning and skip
"'image_url.detail' is currently not supported and " if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
"will be ignored.") logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
mm_parser.parse_image(image_url["url"]) continue
has_image = True
elif part_type == "audio_url": if part_type in ("text", "refusal"):
audio_url = _AudioParser(part)["audio_url"] texts.append(content)
elif part_type == "image_url":
mm_parser.parse_audio(audio_url["url"]) mm_parser.parse_image(content)
elif part_type == "refusal": has_image = True
text = _RefusalParser(part)["refusal"] elif part_type == "audio_url":
texts.append(text) mm_parser.parse_audio(content)
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
if keep_multimodal_content: if keep_multimodal_content:
...@@ -564,14 +654,14 @@ def apply_mistral_chat_template( ...@@ -564,14 +654,14 @@ def apply_mistral_chat_template(
**kwargs: Any, **kwargs: Any,
) -> List[int]: ) -> List[int]:
if chat_template is not None: if chat_template is not None:
logger.warning( print_warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.") "'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs: if "add_generation_prompt" in kwargs:
logger.warning( print_warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, " "'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored.")
if "continue_final_message" in kwargs: if "continue_final_message" in kwargs:
logger.warning( print_warning_once(
"'continue_final_message' is not supported for mistral tokenizer, " "'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored.")
......
...@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, ...@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
from tqdm import tqdm from tqdm import tqdm
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template, apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
...@@ -29,7 +29,12 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -29,7 +29,12 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -108,6 +113,12 @@ class LLM: ...@@ -108,6 +113,12 @@ class LLM:
DEPRECATE_LEGACY: ClassVar[bool] = False DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API.""" """A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
"""
A flag to toggle whether to deprecate positional arguments in
:meth:`LLM.__init__`.
"""
@classmethod @classmethod
@contextmanager @contextmanager
def deprecate_legacy_api(cls): def deprecate_legacy_api(cls):
...@@ -117,6 +128,13 @@ class LLM: ...@@ -117,6 +128,13 @@ class LLM:
cls.DEPRECATE_LEGACY = False cls.DEPRECATE_LEGACY = False
@deprecate_args(
start_index=2, # Ignore self and model
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
additional_message=(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."),
)
def __init__( def __init__(
self, self,
model: str, model: str,
...@@ -139,6 +157,8 @@ class LLM: ...@@ -139,6 +157,8 @@ class LLM:
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
**kwargs, **kwargs,
) -> None: ) -> None:
''' '''
...@@ -153,6 +173,7 @@ class LLM: ...@@ -153,6 +173,7 @@ class LLM:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
task=task,
tokenizer=tokenizer, tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
...@@ -316,10 +337,21 @@ class LLM: ...@@ -316,10 +337,21 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter. instead pass them via the ``inputs`` parameter.
""" """
if self.llm_engine.model_config.embedding_mode: task = self.llm_engine.model_config.task
raise ValueError( if task != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation " "LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).") "models (XForCausalLM, XForConditionalGeneration).",
]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None: if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs( parsed_prompts = self._convert_v1_inputs(
...@@ -433,6 +465,7 @@ class LLM: ...@@ -433,6 +465,7 @@ class LLM:
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence( new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob + cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob) logprob_obj.logprob)
...@@ -691,10 +724,18 @@ class LLM: ...@@ -691,10 +724,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter. instead pass them via the ``inputs`` parameter.
""" """
if not self.llm_engine.model_config.embedding_mode: task = self.llm_engine.model_config.task
raise ValueError( if task != "embedding":
"LLM.encode() is only supported for embedding models (XModel)." messages = ["LLM.encode() is only supported for embedding models."]
)
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None: if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs( parsed_prompts = self._convert_v1_inputs(
...@@ -904,6 +945,3 @@ class LLM: ...@@ -904,6 +945,3 @@ class LLM:
def _is_encoder_decoder_model(self): def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model() return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()
...@@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."))
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
...@@ -314,9 +320,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -314,9 +320,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs = self.top_logprobs prompt_logprobs = self.top_logprobs
guided_json_object = None guided_json_object = None
if (self.response_format is not None if self.response_format is not None:
and self.response_format.type == "json_object"): if self.response_format.type == "json_object":
guided_json_object = True guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "lm-format-enforcer"
guided_decoding = GuidedDecodingParams.from_optional( guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json, json=self._get_guided_json_from_tool() or self.guided_json,
...@@ -537,8 +549,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -537,8 +549,8 @@ class CompletionRequest(OpenAIBaseModel):
default=None, default=None,
description= description=
("Similar to chat completion, this parameter specifies the format of " ("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is " "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"supported."), "{'type': 'text' } is supported."),
) )
guided_json: Optional[Union[str, dict, BaseModel]] = Field( guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None, default=None,
......
...@@ -38,7 +38,7 @@ from vllm.sequence import Logprob ...@@ -38,7 +38,7 @@ from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid from vllm.utils import iterate_with_cancellation
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -176,7 +176,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -176,7 +176,7 @@ class OpenAIServingChat(OpenAIServing):
"\"auto\" tool choice requires " "\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set") "--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{random_uuid()}" request_id = f"chat-{request.request_id}"
request_metadata = RequestResponseMetadata(request_id=request_id) request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request: if raw_request:
......
...@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed = [False] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts num_prompt_tokens = [0] * num_prompts
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
include_usage, include_continuous_usage = False, False
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids prompt_token_ids = res.prompt_token_ids
...@@ -276,28 +284,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -276,28 +284,25 @@ class OpenAIServingCompletion(OpenAIServing):
i = output.index + prompt_idx * num_choices i = output.index + prompt_idx * num_choices
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
# only return the prompt if request.max_tokens == 0:
delta_text = prompt_text # only return the prompt
delta_token_ids = prompt_token_ids delta_text = prompt_text
out_logprobs = prompt_logprobs delta_token_ids = prompt_token_ids
has_echoed[i] = True out_logprobs = prompt_logprobs
elif (request.echo and request.max_tokens > 0 else:
and not has_echoed[i]): assert prompt_logprobs is not None
assert prompt_token_ids is not None # echo the prompt and first token
assert prompt_text is not None delta_text = prompt_text + output.text
assert prompt_logprobs is not None delta_token_ids = [
# echo the prompt and first token *prompt_token_ids, *output.token_ids
delta_text = prompt_text + output.text ]
delta_token_ids = [ out_logprobs = [
*prompt_token_ids, *output.token_ids *prompt_logprobs,
] *(output.logprobs or []),
out_logprobs = [ ]
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
...@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing):
stop_reason=stop_reason, stop_reason=stop_reason,
) )
]) ])
if (request.stream_options if include_continuous_usage:
and request.stream_options.include_usage): prompt_tokens = num_prompt_tokens[prompt_idx]
if (request.stream_options.continuous_usage_stats completion_tokens = previous_num_tokens[i]
or output.finish_reason is not None): chunk.usage = UsageInfo(
prompt_tokens = num_prompt_tokens[prompt_idx] prompt_tokens=prompt_tokens,
completion_tokens = previous_num_tokens[i] completion_tokens=completion_tokens,
usage = UsageInfo( total_tokens=prompt_tokens + completion_tokens,
prompt_tokens=prompt_tokens, )
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
if request.stream_options.continuous_usage_stats:
chunk.usage = usage
else:
chunk.usage = None
response_json = chunk.model_dump_json(exclude_unset=False) response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
if (request.stream_options total_prompt_tokens = sum(num_prompt_tokens)
and request.stream_options.include_usage): total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
if include_usage:
final_usage_chunk = CompletionStreamResponse( final_usage_chunk = CompletionStreamResponse(
id=request_id, id=request_id,
created=created_time, created=created_time,
model=model_name, model=model_name,
choices=[], choices=[],
usage=usage, usage=final_usage_info,
) )
final_usage_data = (final_usage_chunk.model_dump_json( final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True)) exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n" yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices # report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens) request_metadata.final_usage_info = final_usage_info
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
...@@ -413,26 +412,26 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -413,26 +412,26 @@ class OpenAIServingCompletion(OpenAIServing):
for output in final_res.outputs: for output in final_res.outputs:
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo:
assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
assert prompt_text is not None assert prompt_text is not None
token_ids = [*prompt_token_ids, *output.token_ids] if request.max_tokens == 0:
token_ids = prompt_token_ids
if request.logprobs is None: out_logprobs = prompt_logprobs
out_logprobs = None output_text = prompt_text
else: else:
assert prompt_logprobs is not None token_ids = [*prompt_token_ids, *output.token_ids]
assert output.logprobs is not None
out_logprobs = [ if request.logprobs is None:
*prompt_logprobs, out_logprobs = None
*output.logprobs, else:
] assert prompt_logprobs is not None
assert output.logprobs is not None
output_text = prompt_text + output.text out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
out_logprobs = output.logprobs out_logprobs = output.logprobs
......
...@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None, lora_modules=None,
prompt_adapters=None, prompt_adapters=None,
request_logger=request_logger) request_logger=request_logger)
self._enabled = self._check_embedding_mode(model_config.embedding_mode) self._enabled = self._check_embedding_mode(
model_config.task == "embedding")
async def create_embedding( async def create_embedding(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment