Commit e661d594 authored by zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -355,6 +355,7 @@ class StatLoggerBase(ABC): ...@@ -355,6 +355,7 @@ class StatLoggerBase(ABC):
self.num_generation_tokens: List[int] = [] self.num_generation_tokens: List[int] = []
self.last_local_log = time.time() self.last_local_log = time.time()
self.local_interval = local_interval self.local_interval = local_interval
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@abstractmethod @abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
...@@ -364,6 +365,12 @@ class StatLoggerBase(ABC): ...@@ -364,6 +365,12 @@ class StatLoggerBase(ABC):
def log(self, stats: Stats) -> None: def log(self, stats: Stats) -> None:
raise NotImplementedError raise NotImplementedError
def maybe_update_spec_decode_metrics(self, stats: Stats):
    """Remember the spec decode metrics carried by ``stats``, if any.

    Spec decode metrics are unlikely to be emitted on the same call on
    which the log interval elapses, so a non-None value is stored on
    ``self.spec_decode_metrics``. A ``None`` value leaves the stored
    metrics untouched.
    """
    latest = stats.spec_decode_metrics
    if latest is None:
        return
    self.spec_decode_metrics = latest
class LoggingStatLogger(StatLoggerBase): class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout.""" """LoggingStatLogger is used in LLMEngine to log to Stdout."""
...@@ -379,6 +386,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -379,6 +386,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log, if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval): self.local_interval):
...@@ -408,15 +418,16 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -408,15 +418,16 @@ class LoggingStatLogger(StatLoggerBase):
stats.cpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100,
) )
if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))
# Reset tracked stats for next interval. # Reset tracked stats for next interval.
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
self.spec_decode_metrics = None
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str( def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str: self, metrics: "SpecDecodeWorkerMetrics") -> str:
...@@ -533,6 +544,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -533,6 +544,9 @@ class PrometheusStatLogger(StatLoggerBase):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log, if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval): self.local_interval):
...@@ -550,26 +564,27 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -550,26 +564,27 @@ class PrometheusStatLogger(StatLoggerBase):
prompt_throughput=prompt_throughput, prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput) generation_throughput=generation_throughput)
# Reset tracked stats for next interval. if self.spec_decode_metrics is not None:
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
self._log_gauge( self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate, self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate) self.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency, self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency) self.spec_decode_metrics.system_efficiency)
self._log_counter( self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens, self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens) self.spec_decode_metrics.accepted_tokens)
self._log_counter( self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens, self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens) self.spec_decode_metrics.draft_tokens)
self._log_counter( self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens, self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens) self.spec_decode_metrics.emitted_tokens)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None
class RayPrometheusStatLogger(PrometheusStatLogger): class RayPrometheusStatLogger(PrometheusStatLogger):
......
...@@ -81,6 +81,29 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -81,6 +81,29 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None: outputs: SequenceGroupOutput) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.n == 1 and not sampling_params.use_beam_search:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
sampling_params,
lora_req=seq_group.lora_request,
)
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# Process samples # Process samples
samples = outputs.samples samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
...@@ -127,20 +150,20 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -127,20 +150,20 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
child_seqs.append((parent, parent)) child_seqs.append((parent, parent))
for seq, _ in child_seqs: for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize and self.detokenizer: if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace( new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params) seq, sampling_params)
else: else:
new_char_count = 0 new_char_count = 0
self.stop_checker.maybe_stop_sequence( self.stop_checker.maybe_stop_sequence(
seq, seq,
new_char_count, new_char_count,
seq_group.sampling_params, sampling_params,
lora_req=seq_group.lora_request, lora_req=seq_group.lora_request,
) )
# Non-beam search case # Non-beam search case
if not seq_group.sampling_params.use_beam_search: if not sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group # For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished. # and fork them in block manager if they are not finished.
for seq, parent in child_seqs: for seq, parent in child_seqs:
...@@ -164,8 +187,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -164,8 +187,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Select the child sequences to keep in the sequence group. # Select the child sequences to keep in the sequence group.
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = seq_group.sampling_params.best_of beam_width = sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty length_penalty = sampling_params.length_penalty
# Select the newly finished sequences with the highest scores # Select the newly finished sequences with the highest scores
# to replace existing finished sequences. # to replace existing finished sequences.
...@@ -219,8 +242,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -219,8 +242,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
best_running_seq = running_child_seqs[0][0] best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0] current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping( stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping, sampling_params.early_stopping, sampling_params,
seq_group.sampling_params, best_running_seq, current_worst_seq) best_running_seq, current_worst_seq)
if stop_beam_search: if stop_beam_search:
# Stop the beam search and remove all the running sequences from # Stop the beam search and remove all the running sequences from
......
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
@runtime_checkable
class AsyncEngineClient(Protocol):
    """Structural interface for clients that talk to an AsyncLLMEngine.

    The protocol is runtime-checkable, so ``isinstance`` can be used to
    verify that an object exposes these members.
    """

    @property
    def is_running(self) -> bool:
        ...

    @property
    def is_stopped(self) -> bool:
        ...

    @property
    def errored(self) -> bool:
        ...

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Produce the outputs for a generation request."""
        ...

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Produce the outputs for a request sent to an embedding model."""
        ...

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """
        ...

    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the vLLM engine."""
        ...

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the vLLM engine."""
        ...

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> PreTrainedTokenizer:
        """Return the tokenizer appropriate for the request."""
        ...

    async def is_tracing_enabled(self) -> bool:
        ...

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        ...

    async def check_health(self) -> None:
        """Raise if the engine is unhealthy."""
        ...
...@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server. ...@@ -5,21 +5,23 @@ For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead. change `vllm/entrypoints/openai/api_server.py` instead.
""" """
import asyncio
import json import json
import ssl import ssl
from typing import AsyncGenerator from argparse import Namespace
from typing import Any, AsyncGenerator, Optional
import uvicorn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server") logger = init_logger("vllm.entrypoints.api_server")
...@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response: ...@@ -81,6 +83,53 @@ async def generate(request: Request) -> Response:
return JSONResponse(ret) return JSONResponse(ret)
def build_app(args: Namespace) -> FastAPI:
    """Configure the module-level FastAPI ``app`` from ``args`` and return it.

    Only ``args.root_path`` is read; it is copied onto ``app.root_path``.
    """
    global app

    configured_app = app
    configured_app.root_path = args.root_path
    return configured_app
async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Build the app and bind the module-level ``engine``.

    Args:
        args: Parsed command-line arguments.
        llm_engine: Engine to reuse. When ``None``, a new AsyncLLMEngine
            is created from the engine arguments derived from ``args``.

    Returns:
        The FastAPI application returned by ``build_app``.
    """
    global engine

    fastapi_app = build_app(args)

    # Engine args are derived unconditionally, even when an engine is given.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is None:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER)
    else:
        engine = llm_engine

    return fastapi_app
async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Initialise the app, serve it over HTTP, then await its shutdown.

    Args:
        args: Parsed command-line arguments (host, port, log level and SSL
            settings are read from it).
        llm_engine: Optional engine forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    app = await init_app(args, llm_engine)

    # Options taken from the CLI; caller-supplied kwargs are passed alongside.
    server_options = dict(
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
    )
    shutdown_task = await serve_http(app, **server_options, **uvicorn_kwargs)

    # serve_http hands back an awaitable that completes the shutdown.
    await shutdown_task
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
...@@ -105,25 +154,5 @@ if __name__ == "__main__": ...@@ -105,25 +154,5 @@ if __name__ == "__main__":
parser.add_argument("--log-level", type=str, default="debug") parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER)
app.root_path = args.root_path
logger.info("Available routes are:") asyncio.run(run_server(args))
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)
uvicorn.run(app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
import codecs import codecs
from dataclasses import dataclass, field from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, Union, cast, final from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict): ...@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatMessageParseResult: class ChatMessageParseResult:
messages: List[ConversationMessage] messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]] = field( mm_futures: List[Awaitable[MultiModalDataDict]]
default_factory=list)
def load_chat_template(chat_template: Optional[str]) -> Optional[str]: def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
...@@ -100,14 +100,16 @@ def _image_token_str(model_config: ModelConfig, ...@@ -100,14 +100,16 @@ def _image_token_str(model_config: ModelConfig,
if model_type == "phi3_v": if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer # Workaround since this token is not defined in the tokenizer
return "<|image_1|>" return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"): if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt # These models do not use image tokens in the prompt
return None return None
if model_type.startswith("llava"): if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index) return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type == "chameleon": if model_type in ("chameleon", "internvl_chat"):
return "<image>" return "<image>"
raise TypeError("Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
# TODO: Let user specify how to insert image tokens into prompt # TODO: Let user specify how to insert image tokens into prompt
...@@ -172,7 +174,7 @@ def _parse_chat_message_content_parts( ...@@ -172,7 +174,7 @@ def _parse_chat_message_content_parts(
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def parse_chat_message_content( def _parse_chat_message_content(
message: ChatCompletionMessageParam, message: ChatCompletionMessageParam,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
...@@ -188,3 +190,21 @@ def parse_chat_message_content( ...@@ -188,3 +190,21 @@ def parse_chat_message_content(
return _parse_chat_message_content_parts(role, content, model_config, return _parse_chat_message_content_parts(role, content, model_config,
tokenizer) tokenizer)
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    """Parse every chat message and merge the per-message results.

    Each message is parsed with ``_parse_chat_message_content``; the
    resulting conversation messages and multi-modal futures are
    concatenated in input order.

    Returns:
        A ``(conversation, mm_futures)`` pair.
    """
    conversation: List[ConversationMessage] = []
    mm_futures: List[Awaitable[MultiModalDataDict]] = []

    for message in messages:
        parsed = _parse_chat_message_content(message, model_config, tokenizer)
        conversation += parsed.messages
        mm_futures += parsed.mm_futures

    return conversation, mm_futures
import asyncio
import signal
from typing import Any
import uvicorn
from fastapi import FastAPI
from vllm.logger import init_logger
logger = init_logger(__name__)
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Serve ``app`` with uvicorn until it finishes or is cancelled.

    The available routes are logged first. SIGINT and SIGTERM handlers are
    installed on the running loop; they cancel the server task.

    Args:
        app: The FastAPI application to serve.
        **uvicorn_kwargs: Keyword arguments forwarded to ``uvicorn.Config``.

    Returns:
        An un-awaited coroutine for the caller to await: a no-op when the
        server task completed on its own, or ``server.shutdown()`` when the
        task was cancelled.
    """
    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        # Routes lacking either attribute are not listed.
        if methods is not None and path is not None:
            logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    server = uvicorn.Server(uvicorn.Config(app, **uvicorn_kwargs))

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # prevents the uvicorn signal handler to exit early
        server_task.cancel()

    async def dummy_shutdown() -> None:
        pass

    for signum in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(signum, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("Gracefully stopping http server")
        return server.shutdown()
    else:
        return dummy_shutdown()
...@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, ...@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt) parse_and_batch_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -262,6 +265,8 @@ class LLM: ...@@ -262,6 +265,8 @@ class LLM:
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -303,6 +308,14 @@ class LLM: ...@@ -303,6 +308,14 @@ class LLM:
else: else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
...@@ -311,7 +324,8 @@ class LLM: ...@@ -311,7 +324,8 @@ class LLM:
inputs=inputs, inputs=inputs,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request)
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return LLMEngine.validate_outputs(outputs, RequestOutput)
...@@ -508,6 +522,7 @@ class LLM: ...@@ -508,6 +522,7 @@ class LLM:
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
...@@ -523,6 +538,15 @@ class LLM: ...@@ -523,6 +538,15 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request " raise ValueError("The lengths of prompts and lora_request "
"must be the same.") "must be the same.")
if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, request_inputs in enumerate(inputs):
self._add_request( self._add_request(
...@@ -548,6 +572,24 @@ class LLM: ...@@ -548,6 +572,24 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
def _add_guided_processor(
        self,
        params: SamplingParams,
        guided_options: Optional[GuidedDecodingRequest] = None):
    """Attach a guided-decoding logits processor to ``params`` if requested.

    Args:
        params: Sampling parameters; modified in place when a processor is
            added.
        guided_options: Guided decoding request. When its backend is
            ``None`` it is filled in (in place) from the engine's decoding
            config.

    Returns:
        The same ``params`` object.
    """
    if not guided_options:
        return params

    if guided_options.guided_decoding_backend is None:
        # Fall back to the backend configured on the engine.
        guided_options.guided_decoding_backend = (
            self.llm_engine.get_decoding_config().guided_decoding_backend)

    processor = get_local_guided_decoding_logits_processor(
        guided_options.guided_decoding_backend, guided_options,
        self.get_tokenizer())
    if not processor:
        return params

    if params.logits_processors is None:
        params.logits_processors = []
    params.logits_processors.append(processor)
    return params
def _run_engine( def _run_engine(
self, *, use_tqdm: bool self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
......
...@@ -2,13 +2,13 @@ import asyncio ...@@ -2,13 +2,13 @@ import asyncio
import importlib import importlib
import inspect import inspect
import re import re
from argparse import Namespace
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Set from multiprocessing import Process
from typing import AsyncIterator, Set
import fastapi from fastapi import APIRouter, FastAPI, Request
import uvicorn
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
...@@ -16,8 +16,11 @@ from prometheus_client import make_asgi_app ...@@ -16,8 +16,11 @@ from prometheus_client import make_asgi_app
from starlette.routing import Mount from starlette.routing import Mount
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -30,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -30,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest, ErrorResponse, EmbeddingRequest, ErrorResponse,
TokenizeRequest, TokenizeRequest,
TokenizeResponse) TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -38,12 +43,12 @@ from vllm.entrypoints.openai.serving_tokenization import ( ...@@ -38,12 +43,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization) OpenAIServingTokenization)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
engine: AsyncLLMEngine async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion openai_serving_completion: OpenAIServingCompletion
...@@ -55,13 +60,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server') ...@@ -55,13 +60,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set() _running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    """Return the ``embedding_mode`` flag of a ModelConfig for ``model_name``.

    The config is built with ``model_name`` as both model and tokenizer,
    using fixed values for tokenizer mode, seed and dtype.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode
@asynccontextmanager @asynccontextmanager
async def lifespan(app: fastapi.FastAPI): async def lifespan(app: FastAPI):
async def _force_log(): async def _force_log():
while True: while True:
await asyncio.sleep(10) await asyncio.sleep(10)
await engine.do_log_stats() await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats: if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log()) task = asyncio.create_task(_force_log())
...@@ -71,10 +85,56 @@ async def lifespan(app: fastapi.FastAPI): ...@@ -71,10 +85,56 @@ async def lifespan(app: fastapi.FastAPI):
yield yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Context manager that owns the lifecycle of the async engine client.

    Sets the module-level ``engine_args`` and ``async_engine_client``.

    When the model is an embedding model, or frontend multiprocessing is
    disabled via ``args``, an in-process AsyncLLMEngine is yielded.
    Otherwise an RPC server process holding the engine is started and an
    AsyncEngineRPCClient connected to it is yielded. On exit or error,
    including a failure while the client is being set up, the server
    process is terminated and joined and the client is closed.

    Args:
        args: Parsed command-line arguments.

    Yields:
        The engine client to serve requests with.
    """
    # The backend stays global because the health handler reads it.
    global engine_args, async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, start the RPC server (holds the AsyncLLMEngine) in a
    # separate process.
    port = get_open_port(envs.VLLM_RPC_PORT)
    rpc_server_process = Process(target=run_rpc_server,
                                 args=(engine_args,
                                       UsageContext.OPENAI_API_SERVER,
                                       port))
    rpc_server_process.start()

    # Everything after the process has started is inside the try block, so
    # a failure while building or setting up the client cannot leak the
    # server process.
    rpc_client = None
    try:
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        rpc_client = AsyncEngineRPCClient(port)
        async_engine_client = rpc_client
        await rpc_client.setup()

        yield rpc_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join
        rpc_server_process.join()
router = APIRouter() router = APIRouter()
def mount_metrics(app: fastapi.FastAPI): def mount_metrics(app: FastAPI):
# Add prometheus asgi middleware to route /metrics requests # Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app()) metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics # Workaround for 307 Redirect for /metrics
...@@ -85,7 +145,7 @@ def mount_metrics(app: fastapi.FastAPI): ...@@ -85,7 +145,7 @@ def mount_metrics(app: fastapi.FastAPI):
@router.get("/health") @router.get("/health")
async def health() -> Response: async def health() -> Response:
"""Health check.""" """Health check."""
await openai_serving_chat.engine.check_health() await async_engine_client.check_health()
return Response(status_code=200) return Response(status_code=200)
...@@ -164,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ...@@ -164,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
def build_app(args): def build_app(args: Namespace) -> FastAPI:
app = fastapi.FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.include_router(router) app.include_router(router)
app.root_path = args.root_path app.root_path = args.root_path
...@@ -213,37 +273,18 @@ def build_app(args): ...@@ -213,37 +273,18 @@ def build_app(args):
return app return app
def run_server(args, llm_engine=None): async def init_app(
async_engine_client: AsyncEngineClient,
args: Namespace,
) -> FastAPI:
app = build_app(args) app = build_app(args)
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if args.served_model_name is not None: if args.served_model_name is not None:
served_model_names = args.served_model_name served_model_names = args.served_model_name
else: else:
served_model_names = [args.model] served_model_names = [args.model]
global engine, engine_args model_config = await async_engine_client.get_model_config()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
event_loop: Optional[asyncio.AbstractEventLoop]
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
model_config = event_loop.run_until_complete(engine.get_model_config())
else:
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
if args.disable_log_requests: if args.disable_log_requests:
request_logger = None request_logger = None
...@@ -256,7 +297,7 @@ def run_server(args, llm_engine=None): ...@@ -256,7 +297,7 @@ def run_server(args, llm_engine=None):
global openai_serving_tokenization global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat( openai_serving_chat = OpenAIServingChat(
engine, async_engine_client,
model_config, model_config,
served_model_names, served_model_names,
args.response_role, args.response_role,
...@@ -264,23 +305,25 @@ def run_server(args, llm_engine=None): ...@@ -264,23 +305,25 @@ def run_server(args, llm_engine=None):
prompt_adapters=args.prompt_adapters, prompt_adapters=args.prompt_adapters,
request_logger=request_logger, request_logger=request_logger,
chat_template=args.chat_template, chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) )
openai_serving_completion = OpenAIServingCompletion( openai_serving_completion = OpenAIServingCompletion(
engine, async_engine_client,
model_config, model_config,
served_model_names, served_model_names,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters, prompt_adapters=args.prompt_adapters,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) )
openai_serving_embedding = OpenAIServingEmbedding( openai_serving_embedding = OpenAIServingEmbedding(
engine, async_engine_client,
model_config, model_config,
served_model_names, served_model_names,
request_logger=request_logger, request_logger=request_logger,
) )
openai_serving_tokenization = OpenAIServingTokenization( openai_serving_tokenization = OpenAIServingTokenization(
engine, async_engine_client,
model_config, model_config,
served_model_names, served_model_names,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
...@@ -289,22 +332,31 @@ def run_server(args, llm_engine=None): ...@@ -289,22 +332,31 @@ def run_server(args, llm_engine=None):
) )
app.root_path = args.root_path app.root_path = args.root_path
logger.info("Available routes are:") return app
for route in app.routes:
if not hasattr(route, 'methods'):
continue async def run_server(args, **uvicorn_kwargs) -> None:
methods = ', '.join(route.methods) logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("Route: %s, Methods: %s", route.path, methods) logger.info("args: %s", args)
uvicorn.run(app, async with build_async_engine_client(args) as async_engine_client:
host=args.host, app = await init_app(async_engine_client, args)
port=args.port,
log_level=args.uvicorn_log_level, shutdown_task = await serve_http(
timeout_keep_alive=TIMEOUT_KEEP_ALIVE, app,
ssl_keyfile=args.ssl_keyfile, host=args.host,
ssl_certfile=args.ssl_certfile, port=args.port,
ssl_ca_certs=args.ssl_ca_certs, log_level=args.uvicorn_log_level,
ssl_cert_reqs=args.ssl_cert_reqs) timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
if __name__ == "__main__": if __name__ == "__main__":
...@@ -314,4 +366,5 @@ if __name__ == "__main__": ...@@ -314,4 +366,5 @@ if __name__ == "__main__":
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser) parser = make_arg_parser(parser)
args = parser.parse_args() args = parser.parse_args()
run_server(args)
asyncio.run(run_server(args))
...@@ -128,6 +128,17 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -128,6 +128,17 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using @app.middleware('http'). " "using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server " "If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ") "using app.add_middleware(). ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified.")
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
......
from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer
from vllm.sampling_params import LogitsProcessor
class AllowedTokenIdsLogitsProcessor:
    """Logits processor for constraining generated tokens to a
    specific set of token ids.

    The boolean mask is built lazily on the first call, because the width
    and device of the logits tensor are only known then, and is reused on
    every later call.
    """

    def __init__(self, allowed_ids: Iterable[int]):
        # Kept only until the mask has been built, then dropped.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        # True marks a position that must be suppressed.
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        """Set every logit outside the allowed set to ``-inf`` in place.

        Args:
            token_ids: Tokens generated so far (not used).
            logits: Logits tensor; it is modified in place.

        Returns:
            The same ``logits`` tensor that was passed in.

        Raises:
            IndexError: If an allowed id does not fit in the last dimension
                of ``logits``. The processor is left untouched in that case.
        """
        if self.mask is None:
            # Build into a local first: if indexing raises, self.mask must
            # stay None so that a half-initialised all-True mask can never
            # be applied by a later call.
            mask = torch.ones((logits.shape[-1], ),
                              dtype=torch.bool,
                              device=logits.device)
            mask[self.allowed_ids] = False
            self.mask = mask
            self.allowed_ids = None
        logits.masked_fill_(self.mask, float("-inf"))
        return logits
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    """Return a (cached) processor restricting output to the given ids.

    Results are memoised on ``(allowed_token_ids, vocab_size)``, so equal
    arguments return the same processor object.

    Raises:
        ValueError: If the set is empty, or if any id lies outside
            ``[0, vocab_size)``.
    """
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")

    has_out_of_vocab = any(token_id < 0 or token_id >= vocab_size
                           for token_id in allowed_token_ids)
    if has_out_of_vocab:
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")

    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[int, float],
                                token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    """Add a per-token bias to ``logits`` in place.

    Args:
        logit_bias: Mapping from token id to the bias added to that token's
            logit. Keys are used directly as tensor indices, so they must
            be integers.
        token_ids: Tokens generated so far (not used).
        logits: Logits tensor; it is modified in place.

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits
def get_logits_processors(
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
        allowed_token_ids: Optional[List[int]],
        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
    """Build the logits processors implied by a request's sampling options.

    Args:
        logit_bias: Optional mapping of token id (int, or a string holding
            an int) to bias. Nothing is added when it is ``None`` or empty.
        allowed_token_ids: Optional list of the only token ids that may be
            generated. Nothing is added when it is ``None``.
        tokenizer: Only ``tokenizer.vocab_size`` is read, to validate ids.

    Returns:
        A new list holding zero, one or two processors: the bias processor
        first (if any), then the allowed-ids processor (if any).

    Raises:
        ValueError: If a ``logit_bias`` key cannot be converted to an
            integer, or if any token id is outside ``[0, vocab_size)``.
    """
    processors: List[LogitsProcessor] = []

    if logit_bias:
        clamped: Dict[int, float] = {}
        try:
            for raw_token_id, raw_bias in logit_bias.items():
                # Convert token_id to integer
                # Clamp the bias between -100 and 100 per OpenAI API spec
                clamped[int(raw_token_id)] = min(100.0, max(-100.0, raw_bias))
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Check if token_id is within the vocab size
        # NOTE(review): vocab_size may not count tokens added on top of the
        # base vocabulary for some tokenizers; confirm this is the intended
        # upper bound.
        vocab_size = tokenizer.vocab_size
        if any(tid < 0 or tid >= vocab_size for tid in clamped):
            raise ValueError("token_id in logit_bias contains "
                             "out-of-vocab token id")

        processors.append(partial(logit_bias_logits_processor, clamped))

    if allowed_token_ids is not None:
        allowed_processor = _get_allowed_token_ids_logits_processor(
            frozenset(allowed_token_ids), tokenizer.vocab_size)
        processors.append(allowed_processor)

    return processors
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
try:
from sphinx.ext.autodoc.mock import _MockModule
if isinstance(torch, _MockModule):
_LONG_INFO = _MOCK_LONG_INFO
else:
_LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
_LONG_INFO = torch.iinfo(torch.long)
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class OpenAIBaseModel(BaseModel): class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields # OpenAI API does not allow extra fields
...@@ -106,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -106,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n: Optional[int] = 1 n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None, seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None stream_options: Optional[StreamOptions] = None
...@@ -213,30 +231,22 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -213,30 +231,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
def to_sampling_params(self) -> SamplingParams: def to_sampling_params(
# We now allow logprobs being true without top_logrobs. self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
logits_processors = None # We now allow logprobs being true without top_logrobs.
if self.logit_bias: logits_processors = get_logits_processors(
logit_bias: Dict[int, float] = {} logit_bias=self.logit_bias,
try: allowed_token_ids=None,
for token_id, bias in self.logit_bias.items(): tokenizer=tokenizer,
# Convert token_id to integer before we add to LLMEngine )
# Clamp the bias between -100 and 100 per OpenAI API spec if guided_decode_logits_processor:
logit_bias[int(token_id)] = min(100, max(-100, bias)) logits_processors.append(guided_decode_logits_processor)
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams( return SamplingParams(
n=self.n, n=self.n,
...@@ -254,7 +264,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -254,7 +264,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
logprobs=self.top_logprobs if self.logprobs else None, logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None, prompt_logprobs=self.top_logprobs if self.echo else None,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens, max_tokens=max_tokens,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping, early_stopping=self.early_stopping,
...@@ -333,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -333,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
n: int = 1 n: int = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None, seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None stream_options: Optional[StreamOptions] = None
...@@ -358,6 +366,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -358,6 +366,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params # doc: end-completion-sampling-params
# doc: begin-completion-extra-params # doc: begin-completion-extra-params
...@@ -407,30 +416,23 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -407,30 +416,23 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params # doc: end-completion-extra-params
def to_sampling_params(self): def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = None logits_processors = get_logits_processors(
if self.logit_bias: logit_bias=self.logit_bias,
logit_bias: Dict[int, float] = {} allowed_token_ids=self.allowed_token_ids,
try: tokenizer=tokenizer,
for token_id, bias in self.logit_bias.items(): )
# Convert token_id to integer if guided_decode_logits_processor:
# Clamp the bias between -100 and 100 per OpenAI API spec logits_processors.append(guided_decode_logits_processor)
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams( return SamplingParams(
n=self.n, n=self.n,
...@@ -447,7 +449,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -447,7 +449,7 @@ class CompletionRequest(OpenAIBaseModel):
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs, logprobs=self.logprobs,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1, max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping, early_stopping=self.early_stopping,
......
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
# Acknowledgement payload the RPC server sends back for one-way requests.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Payload the RPC server sends back when its health check passes.
VLLM_RPC_HEALTHY_STR = "HEALTHY"
@dataclass
class RPCGenerateRequest:
    """Payload asking the RPC server to run one generation request.

    The fields are forwarded one-to-one to the engine's ``generate`` call.
    """

    # Required fields.
    inputs: PromptInputs
    sampling_params: SamplingParams
    request_id: str

    # Optional fields.
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCAbortRequest:
    """Payload asking the RPC server to abort a request by its id."""

    # Id of the request to abort.
    request_id: str
class RPCUtilityRequest(Enum):
    """Argument-less requests a client can send to the RPC server."""

    # Startup handshake.
    IS_SERVER_READY = 1

    # Configuration queries.
    GET_MODEL_CONFIG = 2
    GET_DECODING_CONFIG = 3
    GET_PARALLEL_CONFIG = 4
    GET_SCHEDULER_CONFIG = 5
    GET_LORA_CONFIG = 6

    # Runtime actions and probes.
    DO_LOG_STATS = 7
    CHECK_HEALTH = 8
    IS_TRACING_ENABLED = 9
# Union of every request object a client may send to the RPC server.
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
                         RPCUtilityRequest]
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional
import cloudpickle
import zmq
import zmq.asyncio
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
class AsyncEngineRPCClient:
    """ZeroMQ client that talks to an engine running behind an RPC server.

    Every request opens a fresh DEALER socket to ``tcp://localhost:<port>``,
    sends one cloudpickle-serialised request and reads the reply (or a
    stream of replies for ``generate``), then closes the socket.

    NOTE(review): replies are deserialised with ``cloudpickle.loads``, which
    executes whatever the peer sent. This presumes the peer on the port is
    the trusted local RPC server; confirm nothing else can bind that port.
    """

    def __init__(self, port: int):
        self.context = zmq.asyncio.Context()
        self.path = f"tcp://localhost:{port}"

    async def setup(self):
        """Setup the client before it starts sending server requests."""

        # Wait until server is ready.
        await self.wait_for_server()

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        """Yield a connected DEALER socket that is closed on exit."""
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.path)
            yield socket
        finally:
            socket.close()

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Raises:
            ValueError: With ``error_message`` if the reply is not an
                instance of ``expected_type`` (``None`` is additionally
                accepted when ``expected_type`` is ``LoRAConfig``).
        """

        with self.socket() as socket:

            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

        if not isinstance(data, expected_type):
            # LoRAConfig can be None.
            lora_disabled = expected_type == LoRAConfig and data is None
            if not lora_disabled:
                raise ValueError(error_message)

        return data

    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                        error_message: str):
        """Send one-way RPC request to trigger an action.

        Raises:
            ValueError: With ``error_message`` if the reply is anything
                other than the success string.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            response = cloudpickle.loads(await socket.recv())

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            raise ValueError(error_message)

        return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the tokenizer for ``lora_request`` from the local group."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig fetched during ``setup``."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig fetched during ``setup``."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag fetched during ``setup``."""
        return self.tracing_flag

    async def wait_for_server(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self):
        """Get LoRAConfig from the RPCServer (may be ``None``)"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> bool:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Yields every output received, including the one whose ``finished``
        flag is set, after which the stream ends. If the server replies
        with an exception object, that exception is raised here.
        """

        with self.socket() as socket:

            # Send RPCGenerateRequest to the RPCServer.
            await socket.send_multipart([
                cloudpickle.dumps(
                    RPCGenerateRequest(
                        inputs=inputs,
                        sampling_params=sampling_params,
                        request_id=request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request))
            ])

            # Stream back the results from the RPC Server.
            while True:
                message = await socket.recv()
                request_output = cloudpickle.loads(message)

                if isinstance(request_output, Exception):
                    raise request_output

                yield request_output

                # The finished output is the last message of the stream.
                if request_output.finished:
                    break

    async def check_health(self) -> None:
        """Raise if unhealthy"""

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(
                cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH))

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

        if isinstance(health_message, Exception):
            raise health_message

        if health_message != VLLM_RPC_HEALTHY_STR:
            # The unexpected payload is interpolated into the message so
            # the failure can be diagnosed from the error alone.
            raise ValueError("Expected healthy response from backend but got "
                             f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
        """Not available on this backend; always raises."""
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")
import asyncio
import signal
from typing import Any, Coroutine
import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
# Note numeric form of localhost should be used for zmq bind(),
# see https://stackoverflow.com/a/8958414
self.socket.bind(f"tcp://127.0.0.1:{port}")
def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])
async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])
async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
sampling_params=generate_request.sampling_params,
request_id=generate_request.request_id,
lora_request=generate_request.lora_request,
trace_headers=generate_request.trace_headers,
prompt_adapter_request=generate_request.prompt_adapter_request)
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
else:
raise ValueError(f"Unknown RPCRequest type: {request}")
async def run_server_loop(self):
    """Inner RPC Server Loop

    Loops forever: receives one ``(identity, message)`` pair from the
    socket, turns it into a handler coroutine and schedules that
    coroutine as a task, without waiting for it to finish.
    """
    in_flight = set()

    while True:
        # Block until the next request arrives.
        identity, message = await self.socket.recv_multipart()

        # Handle the request concurrently with receiving the next one.
        # NOTE: a ValueError raised by _make_handler_coro for an unknown
        # request is not caught here and therefore ends this loop.
        handler = self._make_handler_coro(identity, message)
        task = asyncio.create_task(handler)

        # Hold a strong reference until the task completes; otherwise a
        # running "fire-and-forget" task may be garbage collected
        # mid-execution. See
        # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)
async def run_server(server: AsyncEngineRPCServer):
    """Drive ``server.run_server_loop()`` until it ends or is interrupted.

    SIGINT and SIGTERM cancel the server task. ``server.cleanup()`` is
    always called on the way out, whether the loop was cancelled, failed
    or returned.
    """
    # Schedule the server loop on the currently running event loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # On interrupt / terminate, kill the server by cancelling its task.
    for signum in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(signum, server_task.cancel)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Release every resource held by the server.
        server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, port: int):
    """Build an ``AsyncEngineRPCServer`` and run it in a new event loop.

    Returns once ``run_server`` has completed.
    """
    rpc_server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
    serve_forever = run_server(rpc_server)
    asyncio.run(serve_forever)
import time import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
...@@ -8,10 +7,10 @@ from fastapi import Request ...@@ -8,10 +7,10 @@ from fastapi import Request
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template, load_chat_template,
parse_chat_message_content) parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProb, ChatCompletionLogProbs,
...@@ -25,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, ...@@ -25,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath) PromptAdapterPath)
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
...@@ -41,7 +38,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -41,7 +38,7 @@ class OpenAIServingChat(OpenAIServing):
def __init__( def __init__(
self, self,
engine: AsyncLLMEngine, async_engine_client: AsyncEngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
response_role: str, response_role: str,
...@@ -50,13 +47,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -50,13 +47,15 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapters: Optional[List[PromptAdapterPath]], prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
): ):
super().__init__(engine=engine, super().__init__(async_engine_client=async_engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=prompt_adapters, prompt_adapters=prompt_adapters,
request_logger=request_logger) request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role self.response_role = response_role
...@@ -89,17 +88,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -89,17 +88,11 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
model_config = self.model_config model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
conversation: List[ConversationMessage] = [] conversation, mm_futures = parse_chat_messages(
mm_futures: List[Awaitable[MultiModalDataDict]] = [] request.messages, model_config, tokenizer)
for msg in request.messages:
chat_parsed_result = parse_chat_message_content(
msg, model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
tool_dicts = None if request.tools is None else [ tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools tool.model_dump() for tool in request.tools
...@@ -114,6 +107,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -114,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template=request.chat_template or self.chat_template, chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}), **(request.chat_template_kwargs or {}),
) )
assert isinstance(prompt, str)
except Exception as e: except Exception as e:
logger.error("Error in applying chat template from request: %s", e) logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -132,28 +126,23 @@ class OpenAIServingChat(OpenAIServing): ...@@ -132,28 +126,23 @@ class OpenAIServingChat(OpenAIServing):
request_id = f"chat-{random_uuid()}" request_id = f"chat-{random_uuid()}"
try: try:
sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await await self._guided_decode_logits_processor(request, tokenizer))
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
prompt_inputs = self._tokenize_prompt_input( prompt_inputs = self._tokenize_prompt_input(
request, request,
tokenizer, tokenizer,
prompt, prompt,
truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
self._log_inputs(request_id, self._log_inputs(request_id,
prompt_inputs, prompt_inputs,
params=sampling_params, params=sampling_params,
...@@ -166,7 +155,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -166,7 +155,8 @@ class OpenAIServingChat(OpenAIServing):
if mm_data is not None: if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled() is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None trace_headers = None
if is_tracing_enabled and raw_request: if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers) trace_headers = extract_trace_headers(raw_request.headers)
...@@ -174,7 +164,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -174,7 +164,7 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)): and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning() log_tracing_disabled_warning()
result_generator = self.engine.generate( result_generator = self.async_engine_client.generate(
engine_inputs, engine_inputs,
sampling_params, sampling_params,
request_id, request_id,
...@@ -247,7 +237,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -247,7 +237,15 @@ class OpenAIServingChat(OpenAIServing):
model=model_name) model=model_name)
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
chunk.usage = None if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
...@@ -277,7 +275,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -277,7 +275,18 @@ class OpenAIServingChat(OpenAIServing):
model=model_name) model=model_name)
if (request.stream_options and if (request.stream_options and
request.stream_options.include_usage): request.stream_options.include_usage):
chunk.usage = None if (request.stream_options.
continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json( data = chunk.model_dump_json(
exclude_unset=True) exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
...@@ -336,7 +345,19 @@ class OpenAIServingChat(OpenAIServing): ...@@ -336,7 +345,19 @@ class OpenAIServingChat(OpenAIServing):
model=model_name) model=model_name)
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
chunk.usage = None if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
else: else:
...@@ -356,7 +377,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -356,7 +377,18 @@ class OpenAIServingChat(OpenAIServing):
model=model_name) model=model_name)
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
chunk.usage = None if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
finish_reason_sent[i] = True finish_reason_sent[i] = True
...@@ -404,7 +436,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -404,7 +436,7 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator: async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected(): if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await self.engine.abort(request_id) await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res = res final_res = res
assert final_res is not None assert final_res is not None
...@@ -480,11 +512,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -480,11 +512,14 @@ class OpenAIServingChat(OpenAIServing):
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [ return [
ChatCompletionLogProb( ChatCompletionLogProb(token=(token := self._get_decoded_token(
token=(token := self._get_decoded_token(p[1], p[0], p[1],
tokenizer)), p[0],
logprob=max(p[1].logprob, -9999.0), tokenizer,
bytes=list(token.encode("utf-8", errors="replace"))) return_as_token_id=self.return_tokens_as_token_ids)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items()) for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs if top_logprobs and i < top_logprobs
] ]
...@@ -504,6 +539,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -504,6 +539,8 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None: if step_top_logprobs is None:
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=token, token=token,
...@@ -511,7 +548,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -511,7 +548,9 @@ class OpenAIServingChat(OpenAIServing):
else: else:
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token, token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob, logprob=max(step_top_logprobs[token_id].logprob,
-9999.0), -9999.0),
bytes=list( bytes=list(
......
...@@ -8,7 +8,7 @@ from fastapi import Request ...@@ -8,7 +8,7 @@ from fastapi import Request
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, ...@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath) PromptAdapterPath)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
...@@ -44,20 +42,22 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -44,20 +42,22 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__( def __init__(
self, self,
engine: AsyncLLMEngine, async_engine_client: AsyncEngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
*, *,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]], prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
): ):
super().__init__(engine=engine, super().__init__(async_engine_client=async_engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=prompt_adapters, prompt_adapters=prompt_adapters,
request_logger=request_logger) request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest, async def create_completion(self, request: CompletionRequest,
raw_request: Request): raw_request: Request):
...@@ -91,33 +91,27 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -91,33 +91,27 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list( prompts = list(
self._tokenize_prompt_input_or_inputs( self._tokenize_prompt_input_or_inputs(
request, request,
tokenizer, tokenizer,
request.prompt, request.prompt,
truncate_prompt_tokens=sampling_params. truncate_prompt_tokens=request.truncate_prompt_tokens,
truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
)) ))
for i, prompt_inputs in enumerate(prompts): for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
...@@ -126,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -126,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = await self.engine.is_tracing_enabled() is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None trace_headers = None
if is_tracing_enabled: if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers) trace_headers = extract_trace_headers(raw_request.headers)
...@@ -134,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -134,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers): raw_request.headers):
log_tracing_disabled_warning() log_tracing_disabled_warning()
generator = self.engine.generate( generator = self.async_engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params, sampling_params,
request_id_item, request_id_item,
...@@ -175,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -175,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
async for i, res in result_generator: async for i, res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}") await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
...@@ -237,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -237,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{prompt_idx}") await self.async_engine_client.abort(
f"{request_id}-{prompt_idx}")
raise StopAsyncIteration() raise StopAsyncIteration()
for output in res.outputs: for output in res.outputs:
...@@ -430,12 +426,17 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -430,12 +426,17 @@ class OpenAIServingCompletion(OpenAIServing):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None: if step_top_logprobs is None:
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(None) out_token_logprobs.append(None)
out_top_logprobs.append(None) out_top_logprobs.append(None)
else: else:
token = self._get_decoded_token(step_top_logprobs[token_id], token = self._get_decoded_token(
token_id, tokenizer) step_top_logprobs[token_id],
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob, token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0) -9999.0)
out_tokens.append(token) out_tokens.append(token)
...@@ -448,7 +449,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -448,7 +449,11 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({ out_top_logprobs.append({
# Convert float("-inf") to the # Convert float("-inf") to the
# JSON-serializable float that OpenAI uses # JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer): self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
max(top_lp[1].logprob, -9999.0) max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items()) for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i if num_output_top_logprobs >= i
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest, from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
...@@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):
def __init__( def __init__(
self, self,
engine: AsyncLLMEngine, async_engine_client: AsyncEngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
): ):
super().__init__(engine=engine, super().__init__(async_engine_client=async_engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=None, lora_modules=None,
...@@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
...@@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported " "Prompt adapter is not supported "
"for embedding models") "for embedding models")
generator = self.engine.encode( generator = self.async_engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params, pooling_params,
request_id_item, request_id_item,
...@@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async for i, res in result_generator: async for i, res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}") await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
......
...@@ -5,11 +5,10 @@ from http import HTTPStatus ...@@ -5,11 +5,10 @@ from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -26,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -26,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.inputs import parse_and_batch_prompt from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,8 +51,6 @@ class LoRAModulePath: ...@@ -49,8 +51,6 @@ class LoRAModulePath:
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest] EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict): class TextTokensPrompt(TypedDict):
prompt: str prompt: str
...@@ -61,17 +61,18 @@ class OpenAIServing: ...@@ -61,17 +61,18 @@ class OpenAIServing:
def __init__( def __init__(
self, self,
engine: AsyncLLMEngine, async_engine_client: AsyncEngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
*, *,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]], prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
): ):
super().__init__() super().__init__()
self.engine = engine self.async_engine_client = async_engine_client
self.model_config = model_config self.model_config = model_config
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
...@@ -102,6 +103,7 @@ class OpenAIServing: ...@@ -102,6 +103,7 @@ class OpenAIServing:
prompt_adapter_num_virtual_tokens=num_virtual_tokens)) prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
...@@ -150,6 +152,15 @@ class OpenAIServing: ...@@ -150,6 +152,15 @@ class OpenAIServing:
}) })
return json_str return json_str
async def _guided_decode_logits_processor(
        self, request: Union[ChatCompletionRequest, CompletionRequest],
        tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
    """Build the guided-decoding logits processor for ``request``.

    The backend named on the request is used when it is set (truthy);
    otherwise the backend from the engine's decoding config is used. The
    decoding config is fetched from the engine client in either case.
    """
    engine_decoding_config = (
        await self.async_engine_client.get_decoding_config())

    # A per-request backend overrides the engine-wide default.
    backend = (request.guided_decoding_backend
               or engine_decoding_config.guided_decoding_backend)

    processor = await get_guided_decoding_logits_processor(
        backend, request, tokenizer)
    return processor
async def _check_model( async def _check_model(
self, self,
request: AnyRequest, request: AnyRequest,
...@@ -254,9 +265,7 @@ class OpenAIServing: ...@@ -254,9 +265,7 @@ class OpenAIServing:
f"{self.max_model_len} tokens. However, you requested " f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, " f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.") f"Please reduce the length of the messages.")
request.max_tokens = self.max_model_len - token_num elif token_num + request.max_tokens > self.max_model_len:
if token_num + request.max_tokens > self.max_model_len:
raise ValueError( raise ValueError(
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested " f"{self.max_model_len} tokens. However, you requested "
...@@ -384,11 +393,13 @@ class OpenAIServing: ...@@ -384,11 +393,13 @@ class OpenAIServing:
) )
@staticmethod @staticmethod
def _get_decoded_token( def _get_decoded_token(logprob: Logprob,
logprob: Logprob, token_id: int,
token_id: int, tokenizer: AnyTokenizer,
tokenizer: AnyTokenizer, return_as_token_id: bool = False) -> str:
) -> str: if return_as_token_id:
return f"token_id:{token_id}"
if logprob.decoded_token is not None: if logprob.decoded_token is not None:
return logprob.decoded_token return logprob.decoded_token
return tokenizer.decode(token_id) return tokenizer.decode(token_id)
from typing import List, Optional, Union from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest, from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
ErrorResponse, ErrorResponse,
...@@ -17,14 +15,17 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, ...@@ -17,14 +15,17 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.logger import init_logger
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__)
class OpenAIServingTokenization(OpenAIServing): class OpenAIServingTokenization(OpenAIServing):
def __init__( def __init__(
self, self,
engine: AsyncLLMEngine, async_engine_client: AsyncEngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
*, *,
...@@ -32,7 +33,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -32,7 +33,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
): ):
super().__init__(engine=engine, super().__init__(async_engine_client=async_engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules, lora_modules=lora_modules,
...@@ -57,17 +58,17 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -57,17 +58,17 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
model_config = self.model_config model_config = self.model_config
conversation: List[ConversationMessage] = [] conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer)
for message in request.messages: if mm_futures:
result = parse_chat_message_content(message, model_config, logger.warning(
tokenizer) "Multi-modal inputs are ignored during tokenization")
conversation.extend(result.messages)
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
...@@ -113,7 +114,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -113,7 +114,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id, self._log_inputs(request_id,
request.tokens, request.tokens,
......
...@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional ...@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_HOST_IP: str = "" VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None VLLM_PORT: Optional[int] = None
VLLM_RPC_PORT: int = 5570
VLLM_USE_MODELSCOPE: bool = False VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None VLLM_INSTANCE_ID: Optional[str] = None
...@@ -28,7 +29,9 @@ if TYPE_CHECKING: ...@@ -28,7 +29,9 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_KVCACHE_SPACE: int = 0 VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
...@@ -36,6 +39,7 @@ if TYPE_CHECKING: ...@@ -36,6 +39,7 @@ if TYPE_CHECKING:
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True
VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
...@@ -43,10 +47,10 @@ if TYPE_CHECKING: ...@@ -43,10 +47,10 @@ if TYPE_CHECKING:
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
VLLM_INSTALL_PUNICA_KERNELS: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -92,10 +96,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -92,10 +96,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_PRECOMPILED": "VLLM_USE_PRECOMPILED":
lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),
# If set, vllm will install Punica kernels
"VLLM_INSTALL_PUNICA_KERNELS":
lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))),
# CMake build type # CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo" # If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo"
...@@ -142,6 +142,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -142,6 +142,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv('VLLM_PORT', '0')) lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None, if 'VLLM_PORT' in os.environ else None,
# used when the frontend api server is running in multi-processing mode,
# to communicate with the backend engine process over ZMQ.
'VLLM_RPC_PORT':
lambda: int(os.getenv('VLLM_PORT', '5570')),
# If true, will load models from ModelScope instead of Hugging Face Hub. # If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers # note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE": "VLLM_USE_MODELSCOPE":
...@@ -181,6 +186,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -181,6 +186,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "False").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "False").lower() in
("true", "1")), ("true", "1")),
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": "LOCAL_RANK":
...@@ -246,11 +255,20 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -246,11 +255,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND": "VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
# CPU key-value cache space # Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# (CPU backend only) CPU key-value cache space.
# default is 4GB # default is 4GB
"VLLM_CPU_KVCACHE_SPACE": "VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
# OpenVINO key-value cache space # OpenVINO key-value cache space
# default is 4GB # default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE": "VLLM_OPENVINO_KVCACHE_SPACE":
...@@ -272,13 +290,20 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -272,13 +290,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# execution on all workers. # execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER": "VLLM_USE_RAY_SPMD_WORKER":
lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)), lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))),
# If the env var is set, it uses the Ray's compiled DAG API # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead. # which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
"VLLM_USE_RAY_COMPILED_DAG": "VLLM_USE_RAY_COMPILED_DAG":
lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
# If the env var is set, it uses NCCL for communication in
# Ray's compiled DAG. This flag is ignored if
# VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
),
# Use dedicated multiprocess context for workers. # Use dedicated multiprocess context for workers.
# Both spawn and fork work # Both spawn and fork work
...@@ -312,6 +337,15 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -312,6 +337,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vllm will skip the deprecation warnings. # If set, vllm will skip the deprecation warnings.
"VLLM_NO_DEPRECATION_WARNING": "VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
# To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
"VLLM_ALLOW_LONG_MAX_MODEL_LEN":
lambda:
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
("1", "true")),
} }
# end-env-vars-definition # end-env-vars-definition
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment