Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
import inspect
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
logger = init_logger(__name__)
def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
"""
A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
Depending on the value of arguments:
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
of the `forward` method, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors`, the first
dimension of all the tensors in the intermediate tensors
will be marked as dynamic.
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
......@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
if not hasattr(cls, 'forward'):
raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward)
for k in dynamic_arg_dims:
inferred_dynamic_arg_dims = dynamic_arg_dims
if inferred_dynamic_arg_dims is None:
inferred_dynamic_arg_dims = {}
for k, v in sig.parameters.items():
if v.annotation in [
torch.Tensor, Optional[torch.Tensor],
IntermediateTensors, Optional[IntermediateTensors]
]:
inferred_dynamic_arg_dims[k] = 0
logger.debug(("Inferred dynamic dimensions for "
"forward method of %s: %s"), cls,
list(inferred_dynamic_arg_dims.keys()))
if len(inferred_dynamic_arg_dims) == 0:
raise ValueError(
"No dynamic dimensions found in the forward method of "
f"{cls}. Please provide dynamic_arg_dims explicitly.")
for k in inferred_dynamic_arg_dims:
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, dynamic_arg_dims)
return _support_torch_compile(cls, inferred_dynamic_arg_dims)
if cls is not None:
# use `support_torch_compile` as a decorator without arguments
assert isinstance(cls, type)
return cls_decorator_helper(cls)
return cls_decorator_helper
......
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
Optional, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
import torch
from transformers import PretrainedConfig
......@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once)
is_hip, is_openvino, is_xpu, print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
......@@ -33,6 +32,11 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
class ModelConfig:
"""Configuration for the model.
......@@ -40,7 +44,11 @@ class ModelConfig:
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
......@@ -108,6 +116,7 @@ class ModelConfig:
def __init__(self,
model: str,
task: Union[TaskOption, _Task],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
......@@ -205,9 +214,15 @@ class ModelConfig:
self.is_attention_free = self._init_attention_free()
self.has_inner_state = self._init_has_inner_state()
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode()
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
self.override_neuron_config = None
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
......@@ -241,18 +256,44 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
def _resolve_task(
self,
task_option: Union[TaskOption, _Task],
hf_config: PretrainedConfig,
) -> Tuple[Set[_Task], _Task]:
if task_option == "draft":
return {"draft"}, "draft"
architectures = getattr(hf_config, "architectures", [])
task_support: Dict[_Task, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
}
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported
]
supported_tasks = set(supported_tasks_lst)
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
# TODO: Allow the same model architecture to be specified as either
# generation or embedding model
if "Phi3VForCausalLM" in architectures:
# Match both remote and local names
embedding_mode = "/VLM2Vec" in self.model
if len(supported_tasks) > 1:
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
embedding_mode = ModelRegistry.is_embedding_model(architectures)
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
self.embedding_mode = embedding_mode
selected_task = task_option
return supported_tasks, selected_task
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
......@@ -337,7 +378,7 @@ class ModelConfig:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True
if is_neuron(
if current_platform.is_neuron(
) and self.quantization not in neuron_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
......@@ -410,7 +451,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.embedding_mode:
if self.task == "embedding":
self.use_async_output_proc = False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
......@@ -591,11 +632,6 @@ class ModelConfig:
(hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False)))
@property
def is_embedding_model(self) -> bool:
"""Extract the embedding model flag."""
return self.embedding_mode
@property
def is_multimodal_model(self) -> bool:
    """Report whether this config carries a multimodal configuration.

    Returns:
        True when ``self.multimodal_config`` is set (not ``None``).
    """
    has_multimodal_config = self.multimodal_config is not None
    return has_multimodal_config
......@@ -952,6 +988,7 @@ class SchedulerConfig:
"""Scheduler configuration.
Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
......@@ -966,7 +1003,6 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
......@@ -981,13 +1017,13 @@ class SchedulerConfig:
"""
def __init__(self,
task: _Task,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: bool = False,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1,
......@@ -1011,7 +1047,7 @@ class SchedulerConfig:
# for higher throughput.
max_num_batched_tokens = max(max_model_len, 2048)
if embedding_mode:
if task == "embedding":
# For embedding, choose specific value for higher throughput
max_num_batched_tokens = max(
max_num_batched_tokens,
......@@ -1031,12 +1067,12 @@ class SchedulerConfig:
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens)
self.task: Final = task
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
......@@ -1086,7 +1122,7 @@ class DeviceConfig:
# Automated device type detection
if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif is_neuron():
elif current_platform.is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
......@@ -1248,6 +1284,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
task="draft",
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
......@@ -1381,11 +1418,11 @@ class SpeculativeConfig:
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
elif speculative_draft_tensor_parallel_size not in (
1, target_parallel_config.tensor_parallel_size):
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1")
f"other value than 1 or target model tensor_parallel_size")
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
......
......@@ -7,7 +7,7 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
PrefixHash = int
......
import enum
from abc import ABC, abstractmethod
from typing import OrderedDict
from vllm.block import PhysicalTokenBlock
class EvictionPolicy(enum.Enum):
    """Names the eviction strategies that make_evictor knows how to build.

    make_evictor uses the member passed to it to pick the matching Evictor
    subclass. LRU is the only member defined here.
    """

    LRU = enum.auto()
class Evictor(ABC):
    """Abstract interface for eviction of freed PhysicalTokenBlocks.

    Concrete subclasses are meant to be used by the BlockAllocator class.
    Every member below is abstract and must be overridden.
    """

    @abstractmethod
    def __init__(self):
        """Set up whatever state the concrete evictor needs."""

    @abstractmethod
    def __contains__(self, block_hash: int) -> bool:
        """Tell whether a block with hash ``block_hash`` is in the evictor."""

    @abstractmethod
    def evict(self) -> PhysicalTokenBlock:
        """Run the eviction algorithm and return the block that was evicted."""

    @abstractmethod
    def add(self, block: PhysicalTokenBlock):
        """Register ``block`` with the evictor as a candidate for eviction."""

    @abstractmethod
    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Take the block whose hash is ``block_hash`` out of the evictor.

        The caller must make sure ``block_hash`` is contained in the evictor
        before calling this. It is meant for "bringing back" blocks that
        were freed but have not been evicted yet.
        """

    @property
    @abstractmethod
    def num_blocks(self) -> int:
        """Block count reported by the concrete evictor."""
class LRUEvictor(Evictor):
    """Least-recently-used evictor keyed by block hash.

    Eviction order follows the ``last_accessed`` timestamp stored on each
    PhysicalTokenBlock. When several blocks share the lowest timestamp, the
    one with the largest ``num_hashed_tokens`` is evicted; if that is still
    a tie, one of them is chosen arbitrarily.
    """

    def __init__(self):
        # Maps block_hash -> freed block, in insertion order.
        self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()

    def __contains__(self, block_hash: int) -> bool:
        return block_hash in self.free_table

    def evict(self) -> PhysicalTokenBlock:
        """Remove and return the block selected by the LRU rule.

        Raises:
            ValueError: if the evictor holds no blocks.
        """
        if not self.free_table:
            raise ValueError("No usable cache memory left")

        candidates = iter(self.free_table.values())
        victim = next(candidates)
        # The blocks with the lowest timestamps are expected to sit
        # consecutively at the front of the OrderedDict. Walk over them and
        # keep the one with the most hashed tokens; stop at the first block
        # that was accessed later than the current choice.
        for candidate in candidates:
            if victim.last_accessed < candidate.last_accessed:
                break
            if victim.num_hashed_tokens < candidate.num_hashed_tokens:
                victim = candidate

        self.free_table.pop(victim.block_hash)
        # The evicted block is flagged as no longer computed.
        victim.computed = False
        return victim

    def add(self, block: PhysicalTokenBlock):
        """Store ``block`` under its ``block_hash`` as an eviction candidate."""
        self.free_table[block.block_hash] = block

    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Take the block with hash ``block_hash`` out of the evictor.

        Raises:
            ValueError: if no block with that hash is held.
        """
        if block_hash not in self.free_table:
            raise ValueError(
                "Attempting to remove block that's not in the evictor")
        return self.free_table.pop(block_hash)

    @property
    def num_blocks(self) -> int:
        """Number of blocks currently held in the free table."""
        return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
    """Build the Evictor implementation that matches ``eviction_policy``.

    Returns:
        A new LRUEvictor when the policy is ``EvictionPolicy.LRU``.

    Raises:
        ValueError: for any other policy value.
    """
    if eviction_policy == EvictionPolicy.LRU:
        return LRUEvictor()
    raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
......@@ -313,7 +313,7 @@ class Scheduler:
self.lora_config = lora_config
version = "selfattn"
if (self.scheduler_config.embedding_mode
if (self.scheduler_config.task == "embedding"
or self.cache_config.is_attention_free):
version = "placeholder"
......
......@@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
......@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
steps.
"""
import contextlib
import gc
import pickle
import weakref
from collections import namedtuple
......@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
torch.distributed.destroy_process_group()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down the distributed state and free cached memory.

    Args:
        shutdown_ray: when True, ``ray.shutdown()`` is called as well.
    """
    destroy_model_parallel()
    destroy_distributed_environment()

    # An AssertionError from this extra destroy call is deliberately
    # ignored (presumably the process group is already gone by now --
    # TODO confirm).
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()

    gc.collect()

    # Nothing more to release on a CPU platform; otherwise drop the
    # allocator's cached CUDA memory.
    if current_platform.is_cpu():
        return
    torch.cuda.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
"""
This is a collective operation that returns if each rank is in the same node
......
......@@ -3,7 +3,7 @@ import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast)
Tuple, Type, Union, cast, get_args)
import torch
......@@ -12,10 +12,12 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser
......@@ -84,6 +86,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
......@@ -198,6 +201,15 @@ class EngineArgs:
type=str,
default=EngineArgs.model,
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--task',
default=EngineArgs.task,
choices=get_args(TaskOption),
help='The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.')
parser.add_argument(
'--tokenizer',
type=nullable_str,
......@@ -418,7 +430,11 @@ class EngineArgs:
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9.')
'will use the default value of 0.9. This is a global gpu memory '
'utilization limit, for example if 50%% of the gpu memory is '
'already used before vLLM starts and --gpu-memory-utilization is '
'set to 0.9, then only 40%% of the gpu memory will be allocated '
'to the model executor.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
......@@ -838,6 +854,7 @@ class EngineArgs:
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
......@@ -909,6 +926,8 @@ class EngineArgs:
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False
maybe_register_config_serialize_by_value(self.trust_remote_code)
cache_config = CacheConfig(
# neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else
......@@ -1026,13 +1045,13 @@ class EngineArgs:
" please file an issue with detailed information.")
scheduler_config = SchedulerConfig(
task=model_config.task,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
......
import time
from collections import Counter as collectionsCounter
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -43,8 +44,10 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceGroupOutput, SequenceStatus)
ParallelSampleSequenceGroup, Sequence,
SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
......@@ -344,7 +347,7 @@ class LLMEngine:
observability_config=self.observability_config,
)
if not self.model_config.embedding_mode:
if self.model_config.task != "embedding":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
......@@ -473,6 +476,8 @@ class LLMEngine:
),
))
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
......@@ -641,7 +646,10 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> None:
) -> SequenceGroup:
"""Add a processed request to the engine's request pool.
Return the created sequence group.
"""
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
......@@ -695,6 +703,8 @@ class LLMEngine:
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
return seq_group
def stop_remote_worker_execution_loop(self) -> None:
    """Forward the stop request to the model executor."""
    executor = self.model_executor
    executor.stop_remote_worker_execution_loop()
......@@ -710,7 +720,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
) -> Optional[SequenceGroup]:
...
@overload
......@@ -724,7 +734,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
) -> Optional[SequenceGroup]:
...
@deprecate_kwargs(
......@@ -743,7 +753,7 @@ class LLMEngine:
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
) -> Optional[SequenceGroup]:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
......@@ -787,6 +797,22 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
prompt=prompt,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
)
return None
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
......@@ -817,7 +843,7 @@ class LLMEngine:
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")
self._add_processed_request(
return self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
......@@ -1116,7 +1142,7 @@ class LLMEngine:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
if self.model_config.task == "embedding":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
......@@ -1134,7 +1160,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)
......@@ -1174,7 +1202,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)
......@@ -1193,7 +1223,10 @@ class LLMEngine:
continue
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs,
)
if request_output:
ctx.request_outputs.append(request_output)
......@@ -1212,7 +1245,7 @@ class LLMEngine:
skip)
# Tracing
self.do_tracing(scheduler_outputs)
self.do_tracing(scheduler_outputs, finished_before)
return None
......@@ -1617,6 +1650,25 @@ class LLMEngine:
n_requests: List[int] = []
finished_reason_requests: List[str] = []
# Lora requests
running_lora_adapters = dict(
collectionsCounter([
running_request.lora_request.lora_name
for scheduler in self.scheduler
for running_request in scheduler.running
if running_request.lora_request
]))
waiting_lora_adapters = dict(
collectionsCounter([
waiting_request.lora_request.lora_name
for scheduler in self.scheduler
for waiting_request in scheduler.waiting
if waiting_request.lora_request
]))
max_lora_stat = "0"
if self.lora_config:
max_lora_stat = str(self.lora_config.max_loras)
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None:
......@@ -1666,6 +1718,15 @@ class LLMEngine:
# TPOTs.
latency = seq_group.get_last_latency(now)
time_per_output_tokens_iter.append(latency)
if seq_group.state.current_step == 0:
# For async_output_proc, the do_log_stats()
# is called following init_multi_step(), which
# sets the current_step to zero.
actual_num_batched_tokens +=\
seq_group.state.num_steps - 1
else:
actual_num_batched_tokens +=\
seq_group.state.current_step - 1
# Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging
......@@ -1738,7 +1799,9 @@ class LLMEngine:
num_generation_tokens_requests=num_generation_tokens_requests,
n_requests=n_requests,
finished_reason_requests=finished_reason_requests,
)
max_lora=str(max_lora_stat),
waiting_lora_adapters=list(waiting_lora_adapters.keys()),
running_lora_adapters=list(running_lora_adapters.keys()))
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Pass ``lora_request`` to the model executor and return its result."""
    executor = self.model_executor
    return executor.add_lora(lora_request)
......@@ -1786,11 +1849,18 @@ class LLMEngine:
def is_tracing_enabled(self) -> bool:
    """Return True when a tracer is set on this engine (``self.tracer``)."""
    tracer_present = self.tracer is not None
    return tracer_present
def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
def do_tracing(self,
scheduler_outputs: SchedulerOutputs,
finished_before: Optional[List[int]] = None) -> None:
if self.tracer is None:
return
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
# Skip double tracing when using async output proc
if finished_before and idx in finished_before:
continue
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
self.create_trace_span(seq_group)
......@@ -1855,9 +1925,6 @@ class LLMEngine:
def is_encoder_decoder_model(self):
    """Return what the input preprocessor reports for the same query."""
    preprocessor = self.input_preprocessor
    return preprocessor.is_encoder_decoder_model()
def is_embedding_model(self):
return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
if self.model_config.is_multimodal_model:
......
......@@ -34,7 +34,11 @@ class Metrics:
See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations.
"""
labelname_finish_reason = "finished_reason"
labelname_waiting_lora_adapters = "waiting_lora_adapters"
labelname_running_lora_adapters = "running_lora_adapters"
labelname_max_lora = "max_lora"
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
......@@ -55,6 +59,16 @@ class Metrics:
documentation="Number of requests waiting to be processed.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_lora_info = self._gauge_cls(
name="vllm:lora_requests_info",
documentation="Running stats on lora requests.",
labelnames=[
self.labelname_running_lora_adapters,
self.labelname_max_lora,
self.labelname_waiting_lora_adapters,
],
multiprocess_mode="livemostrecent",
)
self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
......@@ -426,6 +440,9 @@ class PrometheusStatLogger(StatLoggerBase):
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
    """Set the labelled child of ``gauge`` selected by ``data`` to 1.

    Args:
        gauge: object exposing ``labels(**kwargs)`` whose result has ``set``.
        data: label name -> label value mapping, expanded into ``labels``.
    """
    labelled_gauge = gauge.labels(**data)
    labelled_gauge.set(1)
def _log_prometheus(self, stats: Stats) -> None:
# System state data
self._log_gauge(self.metrics.gauge_scheduler_running,
......@@ -442,7 +459,17 @@ class PrometheusStatLogger(StatLoggerBase):
stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate)
# Including max-lora in the metric; in the future this property of the
# lora config may be extended to be dynamic.
lora_info = {
self.metrics.labelname_running_lora_adapters:
",".join(stats.running_lora_adapters),
self.metrics.labelname_waiting_lora_adapters:
",".join(stats.waiting_lora_adapters),
self.metrics.labelname_max_lora:
stats.max_lora,
}
self._log_gauge_string(self.metrics.gauge_lora_info, lora_info)
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,
stats.num_preemption_iter)
......
......@@ -51,6 +51,9 @@ class Stats:
num_generation_tokens_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
waiting_lora_adapters: List[str]
running_lora_adapters: List[str]
max_lora: str
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
......
......@@ -204,8 +204,20 @@ class MQLLMEngineClient(EngineClient):
# (and record only the first one)
if is_engine_errored and not self._errored_with:
self._errored_with = exception
# If engine is errored, no matter the type of exception
# it will no longer be able to receive new requests,
# therefore we have to inform that the current
# processed requests failed as well. Send back a dead
# engine error to give this feedback and also give a
# 'hint' to the server to shut down next.
exception = self.dead_error
if request_id is None:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, nor whether it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception for all requests.
for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception)
else:
......
......@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
......@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
......@@ -136,14 +141,16 @@ class MQLLMEngine:
executor_class = LLMEngine._get_executor_cls(engine_config)
return cls(
ipc_path=ipc_path,
use_async_sockets=engine_config.model_config.use_async_output_proc,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
use_async_sockets = (engine_config.model_config.use_async_output_proc
and not VLLM_USE_V1)
return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
def start(self):
try:
......
......@@ -59,7 +59,7 @@ class EngineClient(ABC):
async def beam_search(
self,
prompt: Union[PromptType, List[int]],
prompt: Union[str, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
......@@ -71,9 +71,13 @@ class EngineClient(ABC):
length_penalty = params.length_penalty
tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)
if isinstance(prompt, str):
tokenized_prompt = tokenizer.encode(prompt)
prompt_text = prompt
else:
tokenized_prompt = prompt
prompt_text = None
tokenized_length = len(tokenized_prompt)
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)
......@@ -81,7 +85,11 @@ class EngineClient(ABC):
beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
all_beams = [
BeamSearchSequence(tokens=tokenized_prompt,
logprobs=[],
cum_logprob=0)
]
completed = []
for _ in range(max_tokens):
......@@ -114,6 +122,7 @@ class EngineClient(ABC):
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)
......@@ -131,22 +140,22 @@ class EngineClient(ABC):
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
prompt=prompt_text,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.cum_logprob,
logprobs=beam.logprobs,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_token_ids=tokenized_prompt,
prompt_logprobs=None)
yield beam_search_output
......
......@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union, cast)
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
......@@ -33,6 +33,7 @@ from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
logger = init_logger(__name__)
......@@ -58,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam]
CustomChatCompletionContentPartParam,
CustomChatCompletionContentSimpleImageParam,
CustomChatCompletionContentSimpleAudioParam, str]
class CustomChatCompletionMessageParam(TypedDict, total=False):
......@@ -386,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
# Dispatch table keyed by a content part's 'type'. Each entry pulls the
# payload string out of the part (the text, the refusal message, or the media
# URL) and falls back to "" when the expected key is absent.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
    "text": lambda p: _TextParser(p).get("text", ""),
    "image_url": lambda p: _ImageParser(p).get("image_url", {}).get("url", ""),
    "audio_url": lambda p: _AudioParser(p).get("audio_url", {}).get("url", ""),
    "refusal": lambda p: _RefusalParser(p).get("refusal", ""),
}
def _parse_chat_message_content_mm_part(
        part: ChatCompletionContentPartParam) -> Tuple[str, str]:
    """
    Parses a given multi modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type'
            field.

    Returns:
        A tuple (part_type, content) where:
            - part_type: Type of the part (e.g., 'text', 'image_url').
            - content: Parsed content (e.g., text, image URL). For a string
              'type' that has no entry in MM_PARSER_MAP, a placeholder
              string is returned so the caller can decide how to handle it.

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is
            found, or if the 'type' field is present but is not a string.
    """
    assert isinstance(
        part, dict)  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'. The option lives inside the
        # nested 'image_url' object and defaults to "auto" when omitted, so
        # only an explicit non-"auto" value should trigger the warning.
        if part_type == "image_url":
            image_url = part.get("image_url")
            detail = (image_url.get("detail", "auto") if isinstance(
                image_url, dict) else "auto")
            if detail != "auto":
                logger.warning("'image_url.detail' is currently not supported "
                               "and will be ignored.")

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    if part_type is None:
        if part.get("image_url") is not None:
            image_params = cast(CustomChatCompletionContentSimpleImageParam,
                                part)
            return "image_url", image_params.get("image_url", "")
        if part.get("audio_url") is not None:
            audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
                                part)
            return "audio_url", audio_params.get("audio_url", "")

        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")

    # A string 'type' without a registered parser: hand it back unparsed.
    return part_type, "unknown part_type content"
# Content part types recognised when parsing chat message parts.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
    "text",
    "refusal",
    "image_url",
    "audio_url",
)
def _parse_chat_message_content_parts(
role: str,
......@@ -401,29 +492,28 @@ def _parse_chat_message_content_parts(
has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
text = _TextParser(part)["text"]
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
elif part_type == "image_url":
image_url = _ImageParser(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue
if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if keep_multimodal_content:
......@@ -564,14 +654,14 @@ def apply_mistral_chat_template(
**kwargs: Any,
) -> List[int]:
if chat_template is not None:
logger.warning(
print_warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning(
print_warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning(
print_warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
......
......@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
from tqdm import tqdm
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
......@@ -29,7 +29,12 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore
logger = init_logger(__name__)
......@@ -108,6 +113,12 @@ class LLM:
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
"""
A flag to toggle whether to deprecate positional arguments in
:meth:`LLM.__init__`.
"""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
......@@ -117,6 +128,13 @@ class LLM:
cls.DEPRECATE_LEGACY = False
@deprecate_args(
start_index=2, # Ignore self and model
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
additional_message=(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."),
)
def __init__(
self,
model: str,
......@@ -139,6 +157,8 @@ class LLM:
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
**kwargs,
) -> None:
'''
......@@ -153,6 +173,7 @@ class LLM:
engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
......@@ -316,10 +337,21 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
task = self.llm_engine.model_config.task
if task != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).")
"models (XForCausalLM, XForConditionalGeneration).",
]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
......@@ -433,6 +465,7 @@ class LLM:
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)
......@@ -691,10 +724,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if not self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.encode() is only supported for embedding models (XModel)."
)
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.encode() is only supported for embedding models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
......@@ -904,6 +945,3 @@ class LLM:
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()
......@@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
# doc: end-chat-completion-extra-params
......@@ -314,9 +320,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs = self.top_logprobs
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
if self.response_format is not None:
if self.response_format.type == "json_object":
guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "lm-format-enforcer"
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
......@@ -537,8 +549,8 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"{'type': 'text' } is supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
......
......@@ -38,7 +38,7 @@ from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid
from vllm.utils import iterate_with_cancellation
logger = init_logger(__name__)
......@@ -176,7 +176,7 @@ class OpenAIServingChat(OpenAIServing):
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{random_uuid()}"
request_id = f"chat-{request.request_id}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
......
......@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
include_usage, include_continuous_usage = False, False
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
......@@ -276,28 +284,25 @@ class OpenAIServingCompletion(OpenAIServing):
i = output.index + prompt_idx * num_choices
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
assert prompt_text is not None
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
]
out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
if request.max_tokens == 0:
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
else:
assert prompt_logprobs is not None
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
]
out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True
else:
# return just the delta
......@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing):
stop_reason=stop_reason,
)
])
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
if request.stream_options.continuous_usage_stats:
chunk.usage = usage
else:
chunk.usage = None
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
if (request.stream_options
and request.stream_options.include_usage):
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=usage,
usage=final_usage_info,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
request_metadata.final_usage_info = final_usage_info
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
......@@ -413,26 +412,26 @@ class OpenAIServingCompletion(OpenAIServing):
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
if request.echo:
assert prompt_text is not None
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
out_logprobs = output.logprobs
......
......@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._enabled = self._check_embedding_mode(model_config.embedding_mode)
self._enabled = self._check_embedding_mode(
model_config.task == "embedding")
async def create_embedding(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment