Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
......@@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class Metrics:
labelname_finish_reason = "finished_reason"
_base_library = prometheus_client
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
self._unregister_vllm_metrics()
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
self._create_info_cache_config()
# System stats
# Scheduler State
self.gauge_scheduler_running = self._base_library.Gauge(
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = self._base_library.Gauge(
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_scheduler_swapped = self._base_library.Gauge(
self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = self._base_library.Gauge(
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = self._base_library.Gauge(
self.gauge_cpu_cache_usage = self._gauge_cls(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Iteration stats
self.counter_num_preemption = self._base_library.Counter(
self.counter_num_preemption = self._counter_cls(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = self._base_library.Counter(
self.counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = self._base_library.Counter(
self.counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = self._base_library.Histogram(
self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
......@@ -86,7 +86,7 @@ class Metrics:
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = self._base_library.Histogram(
self.histogram_time_per_output_token = self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
......@@ -97,59 +97,157 @@ class Metrics:
# Request stats
# Latency
self.histogram_e2e_time_request = self._base_library.Histogram(
self.histogram_e2e_time_request = self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
self.histogram_num_prompt_tokens_request = self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = \
self._base_library.Histogram(
self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = self._base_library.Histogram(
self.histogram_best_of_request = self._histogram_cls(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self._base_library.Histogram(
self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = self._base_library.Counter(
self.counter_request_success = self._counter_cls(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculatie decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
self.gauge_avg_prompt_throughput = self._gauge_cls(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._base_library.Gauge(
self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
def _create_info_cache_config(self) -> None:
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
def _unregister_vllm_metrics(self) -> None:
for collector in list(self._base_library.REGISTRY._collector_to_names):
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
self._base_library.REGISTRY.unregister(collector)
prometheus_client.REGISTRY.unregister(collector)
# end-metrics-definitions
class _RayGaugeWrapper:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._gauge = ray_metrics.Gauge(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def labels(self, **labels):
self._gauge.set_default_tags(labels)
return self
def set(self, value: Union[int, float]):
return self._gauge.set(value)
class _RayCounterWrapper:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._counter = ray_metrics.Counter(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def labels(self, **labels):
self._counter.set_default_tags(labels)
return self
def inc(self, value: Union[int, float] = 1.0):
if value == 0:
return
return self._counter.inc(value)
class _RayHistogramWrapper:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=buckets)
def labels(self, **labels):
self._histogram.set_default_tags(labels)
return self
def observe(self, value: Union[int, float]):
return self._histogram.observe(value)
class RayMetrics(Metrics):
......@@ -157,7 +255,9 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library = ray_metrics
_gauge_cls = _RayGaugeWrapper
_counter_cls = _RayCounterWrapper
_histogram_cls = _RayHistogramWrapper
def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
......@@ -168,8 +268,9 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
# end-metrics-definitions
def _create_info_cache_config(self) -> None:
# No-op on purpose
pass
def build_1_2_5_buckets(max_value: int) -> List[int]:
......@@ -325,8 +426,8 @@ class LoggingStatLogger(StatLoggerBase):
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
f"Number of draft tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens: {metrics.emitted_tokens}.")
class PrometheusStatLogger(StatLoggerBase):
......@@ -454,7 +555,26 @@ class PrometheusStatLogger(StatLoggerBase):
self.num_generation_tokens = []
self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens)
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics
\ No newline at end of file
_metrics_cls = RayMetrics
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
return None
......@@ -90,7 +90,11 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# Guard against a KeyError which can occur if the request was
# aborted while the output was generated
if (child_list :=
parent_child_dict.get(sample.parent_seq_id)) is not None:
child_list.append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
......
import codecs
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, Union, cast, final
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import ChatCompletionContentPartTextParam
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer
from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
logger = init_logger(__name__)
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
type: Required[str]
"""The type of the content part."""
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
@final # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
role: str
content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]] = field(
default_factory=list)
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
return resolved_chat_template
@lru_cache(maxsize=None)
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type == "chameleon":
return "<image>"
raise TypeError("Unknown model type: {model_type}")
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return f"{image_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_futures:
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
image_token_str=image_token_str,
text_prompt=text_prompt,
)
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[], mm_futures=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
......@@ -6,8 +6,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -69,6 +68,10 @@ class LLM:
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
transfer for every forward pass.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
......@@ -114,6 +117,7 @@ class LLM:
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
......@@ -141,6 +145,7 @@ class LLM:
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
......@@ -232,7 +237,7 @@ class LLM:
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
......@@ -249,7 +254,7 @@ class LLM:
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
......@@ -296,9 +301,7 @@ class LLM:
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if sampling_params is None:
# Use default sampling params.
......@@ -377,7 +380,7 @@ class LLM:
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
......@@ -394,7 +397,7 @@ class LLM:
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
......@@ -411,7 +414,7 @@ class LLM:
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
......@@ -440,9 +443,7 @@ class LLM:
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if pooling_params is None:
# Use default pooling params.
......@@ -490,17 +491,11 @@ class LLM:
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
raise AssertionError
inputs.append(item)
......@@ -508,7 +503,7 @@ class LLM:
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
......
from typing import List, Optional, Union
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: Optional[int]) -> None:
super().__init__()
self.max_log_len = max_log_len
def log_inputs(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]
if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids, lora_request,
prompt_adapter_request)
......@@ -18,6 +18,7 @@ from starlette.routing import Mount
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
......@@ -33,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
......@@ -40,12 +43,12 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
engine: AsyncLLMEngine
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
logger = init_logger('vllm.entrypoints.openai.api_server')
......@@ -70,11 +73,13 @@ async def lifespan(app: fastapi.FastAPI):
router = APIRouter()
# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
router.routes.append(route)
def mount_metrics(app: fastapi.FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(metrics_route)
@router.get("/health")
......@@ -86,7 +91,7 @@ async def health() -> Response:
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_completion.create_tokenize(request)
generator = await openai_serving_tokenization.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
......@@ -97,7 +102,7 @@ async def tokenize(request: TokenizeRequest):
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_completion.create_detokenize(request)
generator = await openai_serving_tokenization.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
......@@ -164,6 +169,8 @@ def build_app(args):
app.include_router(router)
app.root_path = args.root_path
mount_metrics(app)
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
......@@ -238,20 +245,48 @@ def run_server(args, llm_engine=None):
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
)
openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules,
args.prompt_adapters)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
engine,
model_config,
served_model_names,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
engine,
model_config,
served_model_names,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path
logger.info("Available routes are:")
......
......@@ -130,6 +130,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser
......
......@@ -3,50 +3,16 @@
import time
from typing import Any, Dict, List, Literal, Optional, Union
import openai.types.chat
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Annotated, Required, TypedDict
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
type: Required[str]
"""The type of the content part."""
ChatCompletionContentPartParam = Union[
openai.types.chat.ChatCompletionContentPartParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam = Union[
openai.types.chat.ChatCompletionMessageParam,
CustomChatCompletionMessageParam]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
......@@ -155,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field(
echo: bool = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
add_special_tokens: Optional[bool] = Field(
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to False (as is the "
"special tokens so this should be set to false (as is the "
"default)."),
)
documents: Optional[List[Dict[str, str]]] = Field(
......@@ -212,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
......@@ -278,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
......@@ -301,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode='before')
......@@ -382,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output: Optional[bool] = Field(
default=False,
add_special_tokens: bool = Field(
default=True,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
......@@ -481,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
......@@ -523,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is True.")
"Stream options can only be defined when stream is true.")
return data
class EmbeddingRequest(BaseModel):
class EmbeddingRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
......@@ -599,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(BaseModel):
class EmbeddingResponseData(OpenAIBaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]
class EmbeddingResponse(BaseModel):
class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
......@@ -704,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
# /v1/chat/completions is supported.
url: str
# The parameteters of the request.
body: Union[ChatCompletionRequest, ]
# The parameters of the request.
body: ChatCompletionRequest
class BatchResponseData(OpenAIBaseModel):
......@@ -737,16 +701,28 @@ class BatchRequestOutput(OpenAIBaseModel):
error: Optional[Any]
class TokenizeRequest(OpenAIBaseModel):
class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
add_special_tokens: bool = Field(default=False)
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
class TokenizeResponse(OpenAIBaseModel):
tokens: List[int]
count: int
max_model_len: int
tokens: List[int]
class DetokenizeRequest(OpenAIBaseModel):
......
......@@ -6,6 +6,7 @@ import aiohttp
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
......@@ -44,9 +45,17 @@ def parse_args():
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
"`request.add_generation_prompt=True`.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser.parse_args()
......@@ -114,11 +123,20 @@ async def main(args):
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
)
# Submit all requests in the file to the engine "concurrently".
......
import codecs
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
List, Optional)
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
from typing import TypedDict, Union, cast, final
from typing import Union
from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb,
ChatCompletionLogProbs, ChatCompletionLogProbsContent,
ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
ChatCompletionLogProb, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
OpenAIServing,
PromptAdapterPath)
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
......@@ -38,159 +37,31 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
@final # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
role: str
content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]] = field(
default_factory=list)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
self.response_role = response_role
self._load_chat_template(chat_template)
def _load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer
if chat_template is not None:
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning(
"No chat template provided. Chat API will not work.")
@cached_property
def image_token_str(self) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = self.model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
"paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return self.tokenizer.decode(
self.model_config.hf_config.image_token_index)
else:
raise TypeError("Unknown model type: {model_type}")
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(self, image_token_str: str,
text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return f"{image_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts(
self,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported."
)
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_futures:
image_token_str = self.image_token_str
if image_token_str is not None:
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = self._get_full_image_text_prompt(
image_token_str=image_token_str,
text_prompt=text_prompt,
)
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def _parse_chat_message_content(
self,
message: ChatCompletionMessageParam,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[], mm_futures=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return self._parse_chat_message_content_parts(role, content)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_chat_completion(
self,
......@@ -212,11 +83,20 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for msg in request.messages:
chat_parsed_result = self._parse_chat_message_content(msg)
chat_parsed_result = parse_chat_message_content(
msg, model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
......@@ -225,13 +105,13 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools
]
prompt = self.tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
......@@ -250,61 +130,71 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
request_id = f"chat-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
self._log_inputs(request_id,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if (not is_tracing_enabled and raw_request
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
inputs: PromptInputs = {
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if mm_data:
inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if not is_tracing_enabled and raw_request and contains_trace_headers(
raw_request.headers):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
inputs,
sampling_params,
request_id,
lora_request,
trace_headers=trace_headers,
)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation)
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation)
conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -316,9 +206,12 @@ class OpenAIServingChat(OpenAIServing):
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
......@@ -326,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = True
# Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
try:
async for res in result_generator:
# We need to do it here, because if there are exceptions in
......@@ -339,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
for i in range(request.n):
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
......@@ -367,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
for i in range(request.n):
for i in range(num_choices):
choice_data = (
ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=last_msg_content),
logprobs=None,
finish_reason=None))
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
......@@ -405,6 +299,7 @@ class OpenAIServingChat(OpenAIServing):
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
......@@ -493,9 +388,13 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
......@@ -523,6 +422,7 @@ class OpenAIServingChat(OpenAIServing):
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
tokenizer=tokenizer,
)
else:
logprobs = None
......@@ -577,16 +477,14 @@ class OpenAIServingChat(OpenAIServing):
return response
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
bytes=list(token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
......@@ -595,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
......@@ -604,12 +503,11 @@ class OpenAIServingChat(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
......@@ -620,6 +518,7 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
step_top_logprobs, num_output_top_logprobs,
tokenizer)))
return ChatCompletionLogProbs(content=logprobs_content)
......@@ -2,12 +2,14 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
from typing import Tuple, cast
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
......@@ -16,10 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DetokenizeRequest,
DetokenizeResponse,
TokenizeRequest,
TokenizeResponse, UsageInfo)
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
......@@ -41,40 +40,24 @@ TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
return prompt_is_tokens, prompts
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]]):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters)
prompt_adapters=prompt_adapters,
request_logger=request_logger)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
......@@ -103,41 +86,45 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
sampling_params = request.to_sampling_params()
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
......@@ -148,12 +135,9 @@ class OpenAIServingCompletion(OpenAIServing):
log_tracing_disabled_warning()
generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
f"{request_id}-{i}",
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
......@@ -182,7 +166,8 @@ class OpenAIServingCompletion(OpenAIServing):
request_id,
created_time,
model_name,
num_prompts=len(prompts))
num_prompts=len(prompts),
tokenizer=tokenizer)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
......@@ -193,8 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = prompts[i]["prompt"]
final_res_batch_checked = cast(List[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -221,11 +225,12 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
try:
async for prompt_idx, res in result_generator:
......@@ -236,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
raise StopAsyncIteration()
for output in res.outputs:
i = output.index + prompt_idx * request.n
i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
......@@ -271,6 +276,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
)
else:
......@@ -339,12 +345,13 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
......@@ -370,6 +377,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
)
else:
......@@ -407,6 +415,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
......@@ -420,13 +429,13 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_id, tokenizer)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
......@@ -439,7 +448,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
......@@ -457,29 +466,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
async def create_tokenize(self,
request: TokenizeRequest) -> TokenizeResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
(input_ids, input_text) = self._validate_prompt_and_tokenize(
request,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
max_model_len=self.max_model_len)
async def create_detokenize(
self, request: DetokenizeRequest) -> DetokenizeResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
(input_ids, input_text) = self._validate_prompt_and_tokenize(
request, prompt_ids=request.tokens)
return DetokenizeResponse(prompt=input_text)
import base64
import time
from typing import AsyncIterator, List, Optional, Tuple
from typing import AsyncIterator, List, Optional, Tuple, cast
import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
......@@ -28,11 +28,11 @@ def request_output_to_embedding_response(
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
if encoding_format == "base64":
embedding = base64.b64encode(np.array(embedding))
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
......@@ -54,12 +54,20 @@ def request_output_to_embedding_response(
class OpenAIServingEmbedding(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str]):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
......@@ -80,32 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
"dimensions is currently not supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
try:
prompt_is_tokens, prompts = parse_prompt_format(request.input)
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
pooling_params = request.to_pooling_params()
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt)
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.input,
))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_ids, prompt_text = prompt_formats
if prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported "
"for embedding models")
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
f"{request_id}-{i}",
request_id_item,
lora_request=lora_request,
)
generators.append(generator)
......@@ -124,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for final_res in final_res_batch:
assert final_res is not None
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name,
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
......
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission, TokenizeRequest)
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer
logger = init_logger(__name__)
......@@ -32,7 +43,18 @@ class PromptAdapterPath:
@dataclass
class LoRAModulePath:
name: str
local_path: str
path: str
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: List[int]
class OpenAIServing:
......@@ -42,8 +64,10 @@ class OpenAIServing:
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
):
super().__init__()
......@@ -51,14 +75,6 @@ class OpenAIServing:
self.model_config = model_config
self.max_model_len = model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
truncation_side="left")
self.served_model_names = served_model_names
self.lora_requests = []
......@@ -67,15 +83,15 @@ class OpenAIServing:
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
lora_path=lora.path,
) for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with open(f"./{prompt_adapter.local_path}"
f"/adapter_config.json") as f:
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
......@@ -85,6 +101,8 @@ class OpenAIServing:
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
......@@ -133,9 +151,8 @@ class OpenAIServing:
return json_str
async def _check_model(
self, request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest]
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
......@@ -151,62 +168,65 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
def _maybe_get_adapters(
self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]:
if request.model in self.served_model_names:
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return 'LoRA', lora
return lora, None
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return 'PromptAdapter', prompt_adapter
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
Field(ge=1)]] = None,
add_special_tokens: Optional[bool] = True
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None:
# When using OpenAIServingChat for chat completions, for
# most models the special tokens (e.g., BOS) have already
# been added by the chat template. Therefore, we do not
# need to add them again.
# Set add_special_tokens to False (by default) to avoid
# adding the BOS tokens again.
tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens
}
if truncate_prompt_tokens is not None:
tokenizer_kwargs.update({
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt: str,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
else:
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_ids: List[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
input_ids = prompt_ids
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
input_text = tokenizer.decode(input_ids)
input_text = prompt if prompt is not None else self.tokenizer.decode(
prompt_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: AnyRequest,
input_ids: List[int],
input_text: str,
) -> TextTokensPrompt:
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
......@@ -216,13 +236,16 @@ class OpenAIServing:
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", )
return input_ids, input_text
f"generation. Please reduce the length of the input.")
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
return input_ids, input_text
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest)):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
if request.max_tokens is None:
if token_num >= self.max_model_len:
......@@ -230,7 +253,7 @@ class OpenAIServing:
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.", )
f"Please reduce the length of the messages.")
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
......@@ -240,11 +263,132 @@ class OpenAIServing:
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
f"Please reduce the length of the messages or completion.")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _tokenize_prompt_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, List[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes single input.
"""
return next(
self._tokenize_prompt_inputs(
request,
tokenizer,
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
))
def _tokenize_prompt_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes multiple inputs.
"""
for text in prompt_inputs:
if isinstance(text, str):
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=text,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=text,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _tokenize_prompt_input_or_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _log_inputs(
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if self.request_logger is None:
return
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = None
elif isinstance(inputs, list):
prompt = None
prompt_token_ids = inputs
else:
return input_ids, input_text
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
@staticmethod
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)
return tokenizer.decode(token_id)
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.utils import random_uuid
class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_tokenize(
self,
request: TokenizeRequest,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation: List[ConversationMessage] = []
for message in request.messages:
result = parse_chat_message_content(message, model_config,
tokenizer)
conversation.extend(result.messages)
prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
else:
prompt = request.prompt
self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
max_model_len=self.max_model_len)
async def create_detokenize(
self,
request: DetokenizeRequest,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text)
......@@ -17,7 +17,8 @@ if TYPE_CHECKING:
S3_ACCESS_KEY_ID: Optional[str] = None
S3_SECRET_ACCESS_KEY: Optional[str] = None
S3_ENDPOINT_URL: Optional[str] = None
VLLM_CONFIG_ROOT: str = ""
VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm")
VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm")
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
VLLM_NO_USAGE_STATS: bool = False
VLLM_DO_NOT_TRACK: bool = False
......@@ -31,10 +32,12 @@ if TYPE_CHECKING:
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
......@@ -45,6 +48,21 @@ if TYPE_CHECKING:
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache"),
)
def get_default_config_root():
return os.getenv(
"XDG_CONFIG_HOME",
os.path.join(os.path.expanduser("~"), ".config"),
)
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
......@@ -89,15 +107,28 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv('VERBOSE', '0'))),
# Root directory for VLLM configuration files
# Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set
# Note that this not only affects how vllm finds its configuration files
# during runtime, but also affects how vllm installs its configuration
# files during **installation**.
"VLLM_CONFIG_ROOT":
lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv(
"XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"),
lambda: os.path.expanduser(
os.getenv(
"VLLM_CONFIG_ROOT",
os.path.join(get_default_config_root(), "vllm"),
)),
# ================== Runtime Env Vars ==================
# Root directory for VLLM cache files
# Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set
"VLLM_CACHE_ROOT":
lambda: os.path.expanduser(
os.getenv(
"VLLM_CACHE_ROOT",
os.path.join(get_default_cache_root(), "vllm"),
)),
# used in distributed environment to determine the master address
'VLLM_HOST_IP':
lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),
......@@ -231,6 +262,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),
# If the env var is set, then all workers will execute as separate
# processes from the engine, and we use the same mechanism to trigger
# execution on all workers.
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
"VLLM_USE_RAY_SPMD_WORKER":
lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
......@@ -242,6 +280,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "fork"),
# Path to the cache for storing downloaded assets
"VLLM_ASSETS_CACHE":
lambda: os.path.expanduser(
os.getenv(
"VLLM_ASSETS_CACHE",
os.path.join(get_default_cache_root(), "vllm", "assets"),
)),
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT":
......@@ -250,7 +296,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
lambda: os.path.expanduser(
os.getenv(
"VLLM_ASSETS_CACHE",
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
)),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
......@@ -262,7 +312,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# end-env-vars-definition
def __getattr__(name):
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
......
......@@ -17,6 +17,8 @@ logger = init_logger(__name__)
class CPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"
......
......@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
......@@ -73,7 +73,9 @@ class DistributedGPUExecutor(GPUExecutor):
**self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
......
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
......@@ -19,6 +18,8 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices.
"""
uses_ray: bool # whether the executor uses Ray for orchestration.
def __init__(
self,
model_config: ModelConfig,
......@@ -132,26 +133,6 @@ class ExecutorBase(ABC):
class ExecutorAsyncBase(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None:
self.pp_locks: Optional[List[asyncio.Lock]] = None
super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config,
lora_config, multimodal_config, speculative_config,
prompt_adapter_config)
@abstractmethod
async def execute_model_async(
self,
......
......@@ -12,8 +12,19 @@ from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
def create_worker(worker_module_name, worker_class_name, **kwargs):
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**kwargs)
return wrapper.worker
class GPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
......@@ -51,25 +62,30 @@ class GPUExecutor(ExecutorBase):
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
def _get_create_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
if self.speculative_config is None:
worker_kwargs.update(worker_module_name="vllm.worker.worker",
worker_class_name="Worker")
else:
worker_kwargs.update(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker")
return worker_kwargs
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
return create_worker(**self._get_create_worker_kwargs(
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method))
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
......
import asyncio
import os
import signal
import weakref
from functools import partial
from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (cuda_device_count_stateless,
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
error_on_invalid_device_count_status,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async,
......@@ -22,9 +25,12 @@ logger = init_logger(__name__)
class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""
uses_ray: bool = False
def _init_executor(self) -> None:
# Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
......@@ -47,8 +53,15 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if world_size > 1:
maybe_set_triton_cache_manager()
assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")
assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
error_on_invalid_device_count_status()
......@@ -58,26 +71,53 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
self.workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[ProcessWorkerWrapper] = []
if world_size == 1:
self.workers = []
self.worker_monitor = None
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
for rank in range(1, world_size):
worker = ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)) for rank in range(1, world_size)
]
create_worker,
**self._get_create_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
self.worker_monitor.start()
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)
def shutdown(signum, frame):
if executor := ref():
executor.shutdown()
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")
......@@ -121,16 +161,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
for worker in self.non_driver_workers
]
# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return worker_outputs
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
......@@ -157,16 +200,45 @@ class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)
if not self.tp_driver_workers:
return await self.driver_exec_model(execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method_async,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment