Unverified Commit baaedfdb authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable following imports for entrypoints (#7248)


Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarFei <dfdfcai4@gmail.com>
parent 45066412
......@@ -38,7 +38,6 @@ jobs:
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
......
......@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
msgspec
# packages to install to build the documentation
pydantic
pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo
......
......@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
......
......@@ -56,6 +56,7 @@ files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
......
......@@ -11,7 +11,7 @@ fastapi
aiohttp
openai >= 1.0 # Ensure modern openai package (ensure types module present)
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.8 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
......
# imports for guided decoding tests
import json
import re
from typing import List
from typing import Dict, List, Optional
import jsonschema
import openai # use the official client for correctness check
......@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: Optional[int]):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.chat.completions.create(**params)
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs is not None:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}
completion_1 = await client.chat.completions.create(**params)
params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)
assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
......
......@@ -3,7 +3,7 @@ import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Dict, List
from typing import Dict, List, Optional
import jsonschema
import openai # use the official client for correctness check
......@@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}
completion_1 = await client.chat.completions.create(**params)
params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)
assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
......@@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
prompt_logprobs: Optional[int]):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
......@@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
if prompt_logprobs is not None:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0
......
......@@ -6,7 +6,6 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
import vllm.envs as envs
......@@ -31,6 +30,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
......@@ -427,8 +427,8 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
......@@ -771,7 +771,7 @@ class AsyncLLMEngine:
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)
......
......@@ -3,9 +3,9 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union
from typing import Set, Tuple, Type, Union
from typing_extensions import assert_never
from typing_extensions import TypeVar, assert_never
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
......@@ -43,8 +43,9 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device
......@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return config.to_diff_dict()
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
......@@ -494,11 +496,20 @@ class LLMEngine:
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
raise ValueError(missing_msg)
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return self.tokenizer
return tokenizer_group
def get_tokenizer(
self,
......@@ -693,8 +704,8 @@ class LLMEngine:
* prompt token ids
'''
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return tokenizer.encode(request_id=request_id,
prompt=prompt,
......
from abc import ABC, abstractmethod
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
......@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
......
import functools
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
......@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
......@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer
......
from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
......@@ -15,8 +14,7 @@ class StopChecker:
"""
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
......
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
......@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
......@@ -40,6 +39,7 @@ class AsyncEngineClient(Protocol):
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
...
def encode(
self,
......@@ -50,6 +50,7 @@ class AsyncEngineClient(Protocol):
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...
async def abort(self, request_id: str) -> None:
"""Abort a request.
......@@ -60,25 +61,29 @@ class AsyncEngineClient(Protocol):
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...
async def get_decoding_config(self) -> DecodingConfig:
...
"""Get the decoding configuration of the vLLM engine."""
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
...
async def is_tracing_enabled(self) -> bool:
pass
...
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
...
async def check_health(self) -> None:
"""Raise if unhealthy"""
...
......@@ -61,6 +61,7 @@ async def generate(request: Request) -> Response:
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
]
......@@ -80,6 +81,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
......@@ -115,6 +117,7 @@ async def run_server(args: Namespace,
logger.info("args: %s", args)
app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http(
app,
......
......@@ -3,7 +3,7 @@ from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
Union)
# yapf conflicts with isort for this block
# yapf: disable
......@@ -15,9 +15,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer
from typing_extensions import Required, TypedDict
from pydantic import ConfigDict, TypeAdapter
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
......@@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]
class CustomChatCompletionMessageParam(TypedDict, total=False):
......@@ -114,7 +113,7 @@ def load_chat_template(
@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
......@@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str,
return f"{placeholder_token_str}\n{text_prompt}"
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
......@@ -164,7 +168,7 @@ def _parse_chat_message_content_parts(
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
......@@ -172,8 +176,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
image_url = _ImageParser.validate_python(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
......@@ -188,8 +191,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:
......@@ -219,7 +221,7 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
......@@ -230,14 +232,18 @@ def _parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
return _parse_chat_message_content_parts(
role,
content, # type: ignore
model_config,
tokenizer,
)
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
......
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
......@@ -20,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
......@@ -122,7 +123,7 @@ class LLM:
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
......@@ -175,22 +176,19 @@ class LLM:
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer.tokenizer = tokenizer
tokenizer_group.tokenizer = tokenizer
else:
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
@overload # LEGACY: single (prompt + optional token ids)
def generate(
......@@ -578,6 +576,8 @@ class LLM:
inputs: List[PromptInputs] = []
for i in range(num_requests):
item: PromptInputs
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
......@@ -635,7 +635,7 @@ class LLM:
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
request_id = str(next(self.request_counter))
......
......@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
......@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
......@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
......@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
logger.info("Started engine process with PID %d",
rpc_server_process.pid)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(rpc_path)
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
try:
while True:
try:
await async_engine_client.setup()
await rpc_client.setup()
break
except TimeoutError as e:
if not rpc_server_process.is_alive():
......@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
rpc_client.close()
# Wait for server process to join
rpc_server_process.join()
......@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
......@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models():
......@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
......@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
......@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
......
......@@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
......@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: List[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
......@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):
class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
adapter_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
......
......@@ -2,9 +2,9 @@ from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer
from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
class AllowedTokenIdsLogitsProcessor:
......@@ -53,8 +53,9 @@ def logit_bias_logits_processor(
def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
logits_processors = []
tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias:
try:
# Convert token_id to integer
......
......@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
......@@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]
try:
from sphinx.ext.autodoc.mock import _MockModule
......@@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
......@@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
......@@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
(self.top_logprobs if self.echo else None),
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
......@@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"stream_options can only be set if stream is true")
return values
"Stream options can only be defined when `stream=True`.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")
if not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
return data
@model_validator(mode="before")
@classmethod
......@@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a value a positive value.")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
......@@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors(
......@@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
......@@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.prompt_logprobs
if self.prompt_logprobs else self.logprobs if self.echo else None,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
......@@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not data["logprobs"] >= 0:
raise ValueError("if passed, `logprobs` must be a positive value.")
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")
return data
@model_validator(mode="before")
......@@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is true.")
"Stream options can only be defined when `stream=True`.")
return data
......@@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment