Unverified commit baaedfdb, authored by Cyrus Leung and committed by GitHub
Browse files

[mypy] Enable following imports for entrypoints (#7248)


Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Fei <dfdfcai4@gmail.com>
parent 45066412
...@@ -38,7 +38,6 @@ jobs: ...@@ -38,7 +38,6 @@ jobs:
mypy vllm/core --follow-imports skip mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip mypy vllm/model_executor --follow-imports skip
......
...@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0 ...@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
msgspec msgspec
# packages to install to build the documentation # packages to install to build the documentation
pydantic pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu -f https://download.pytorch.org/whl/cpu
torch torch
py-cpuinfo py-cpuinfo
......
...@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip ...@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip mypy vllm/model_executor --follow-imports skip
......
...@@ -56,6 +56,7 @@ files = [ ...@@ -56,6 +56,7 @@ files = [
"vllm/*.py", "vllm/*.py",
"vllm/adapter_commons", "vllm/adapter_commons",
"vllm/assets", "vllm/assets",
"vllm/entrypoints",
"vllm/inputs", "vllm/inputs",
"vllm/logging", "vllm/logging",
"vllm/multimodal", "vllm/multimodal",
......
...@@ -11,7 +11,7 @@ fastapi ...@@ -11,7 +11,7 @@ fastapi
aiohttp aiohttp
openai >= 1.0 # Ensure modern openai package (ensure types module present) openai >= 1.0 # Ensure modern openai package (ensure types module present)
uvicorn[standard] uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server. pydantic >= 2.8 # Required for OpenAI server.
pillow # Required for image processing pillow # Required for image processing
prometheus_client >= 0.18.0 prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0 prometheus-fastapi-instrumentator >= 7.0.0
......
# imports for guided decoding tests # imports for guided decoding tests
import json import json
import re import re
from typing import List from typing import Dict, List, Optional
import jsonschema import jsonschema
import openai # use the official client for correctness check import openai # use the official client for correctness check
...@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, ...@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name, prompt_logprobs",
    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                    model_name: str,
                                    prompt_logprobs: Optional[int]):
    """Exercise the ``prompt_logprobs`` extra option on chat completions.

    Three outcomes are checked, depending on the parametrized value:

    * ``None``: the option is not sent and the response is expected to
      carry no ``prompt_logprobs``.
    * negative: the request is expected to raise ``BadRequestError``.
    * any other value (including ``0``): the response is expected to carry
      a non-empty ``prompt_logprobs``.
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Who won the world series in 2020?"
        },
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020."
        },
        {
            "role": "user",
            "content": "Where was it played?"
        },
    ]
    request_kwargs: Dict = {"messages": conversation, "model": model_name}

    # Option omitted entirely: nothing about prompt logprobs should come back.
    if prompt_logprobs is None:
        completion = await client.chat.completions.create(**request_kwargs)
        assert completion.prompt_logprobs is None
        return

    # The option is vLLM-specific, so it travels through ``extra_body``.
    request_kwargs["extra_body"] = {"prompt_logprobs": prompt_logprobs}

    # Negative values must be rejected by the server.
    if prompt_logprobs < 0:
        with pytest.raises(BadRequestError):
            await client.chat.completions.create(**request_kwargs)
        return

    completion = await client.chat.completions.create(**request_kwargs)
    assert completion.prompt_logprobs is not None
    assert len(completion.prompt_logprobs) > 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                                  model_name: str):
    """Check that the requested ``prompt_logprobs`` count is honoured.

    The same conversation is sent twice, asking for 1 and then 2 prompt
    logprobs; the entry at index 3 of each response's ``prompt_logprobs``
    must contain exactly that many items.
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Who won the world series in 2020?"
        },
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020."
        },
        {
            "role": "user",
            "content": "Where was it played?"
        },
    ]
    shared_kwargs: Dict = {"messages": conversation, "model": model_name}

    # Both requests are issued before any assertion runs.
    with_one = await client.chat.completions.create(
        **shared_kwargs, extra_body={"prompt_logprobs": 1})
    with_two = await client.chat.completions.create(
        **shared_kwargs, extra_body={"prompt_logprobs": 2})

    # NOTE(review): index 3 presumably picks an arbitrary prompt position
    # past the first one -- confirm why this position was chosen.
    assert len(with_one.prompt_logprobs[3]) == 1
    assert len(with_two.prompt_logprobs[3]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
......
...@@ -3,7 +3,7 @@ import json ...@@ -3,7 +3,7 @@ import json
import re import re
import shutil import shutil
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Dict, List from typing import Dict, List, Optional
import jsonschema import jsonschema
import openai # use the official client for correctness check import openai # use the official client for correctness check
...@@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ...@@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0 assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name, prompt_logprobs",
    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                    model_name: str, prompt_logprobs: int):
    """Exercise the ``prompt_logprobs`` extra option on chat completions.

    * negative value: the request must raise ``BadRequestError`` whose
      string form matches the expected 400 message exactly.
    * positive value: the response must carry a non-empty
      ``prompt_logprobs``.
    * ``0`` or ``None``: the response must carry no ``prompt_logprobs``
      (``0`` is still sent to the server, ``None`` is not sent at all).
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Who won the world series in 2020?"
        },
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020."
        },
        {
            "role": "user",
            "content": "Where was it played?"
        },
    ]
    request_kwargs: Dict = {"messages": conversation, "model": model_name}

    # Only forward the option when the parametrization supplies a value.
    if prompt_logprobs is not None:
        request_kwargs["extra_body"] = {"prompt_logprobs": prompt_logprobs}

    if prompt_logprobs is not None and prompt_logprobs < 0:
        with pytest.raises(BadRequestError) as err_info:
            await client.chat.completions.create(**request_kwargs)
        # The whole rendered error is compared, not just the message field.
        expected_err_string = (
            "Error code: 400 - {'object': 'error', 'message': "
            "'Prompt_logprobs set to invalid negative value: -1',"
            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
        assert str(err_info.value) == expected_err_string
        return

    completion = await client.chat.completions.create(**request_kwargs)
    if prompt_logprobs is not None and prompt_logprobs > 0:
        assert completion.prompt_logprobs is not None
        assert len(completion.prompt_logprobs) > 0
    else:
        assert completion.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                                  model_name: str):
    """Verify that asking for 1 vs. 2 prompt logprobs changes the response.

    Two chat requests with the same messages are made, differing only in the
    ``prompt_logprobs`` value passed through ``extra_body``. The entry at
    index 3 of the returned ``prompt_logprobs`` must have length 1 for the
    first request and length 2 for the second.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Who won the world series in 2020?"
        },
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020."
        },
        {
            "role": "user",
            "content": "Where was it played?"
        },
    ]

    # Issue both requests first, in order, then assert on the results.
    responses = []
    for n_requested in (1, 2):
        request: Dict = {
            "messages": messages,
            "model": model_name,
            "extra_body": {
                "prompt_logprobs": n_requested
            },
        }
        responses.append(await client.chat.completions.create(**request))

    completion_1, completion_2 = responses
    assert len(completion_1.prompt_logprobs[3]) == 1
    assert len(completion_2.prompt_logprobs[3]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), @pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0), (MODEL_NAME, 0),
...@@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, ...@@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
(MODEL_NAME, None)]) (MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str, model_name: str,
prompt_logprobs: int): prompt_logprobs: Optional[int]):
params: Dict = { params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"], "prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name, "model": model_name,
...@@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, ...@@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
if prompt_logprobs is not None: if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs} params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0: if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info: with pytest.raises(BadRequestError):
await client.completions.create(**params) await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else: else:
completion = await client.completions.create(**params) completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0: if prompt_logprobs is not None:
assert completion.choices[0].prompt_logprobs is not None assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0 assert len(completion.choices[0].prompt_logprobs) > 0
......
...@@ -6,7 +6,6 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, ...@@ -6,7 +6,6 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union) Optional, Set, Tuple, Type, Union)
import torch import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
...@@ -31,6 +30,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -31,6 +30,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
...@@ -427,8 +427,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -427,8 +427,8 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
) -> List[int]: ) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`.""" """Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if " tokenizer = self.get_tokenizer_group(
"skip_tokenizer_init is True") missing_msg="prompts must be None if skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id, return await tokenizer.encode_async(request_id=request_id,
prompt=prompt, prompt=prompt,
...@@ -771,7 +771,7 @@ class AsyncLLMEngine: ...@@ -771,7 +771,7 @@ class AsyncLLMEngine:
async def get_tokenizer( async def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
if self.engine_use_ray: if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore return await self.engine.get_tokenizer.remote( # type: ignore
lora_request) lora_request)
......
...@@ -3,9 +3,9 @@ from contextlib import contextmanager ...@@ -3,9 +3,9 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional) Mapping, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union from typing import Set, Tuple, Type, Union
from typing_extensions import assert_never from typing_extensions import TypeVar, assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
...@@ -43,8 +43,9 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, ...@@ -43,8 +43,9 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device from vllm.utils import Counter, Device
...@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: ...@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return config.to_diff_dict() return config.to_diff_dict()
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int], PromptComponents = Tuple[Optional[str], List[int],
...@@ -494,11 +496,20 @@ class LLMEngine: ...@@ -494,11 +496,20 @@ class LLMEngine:
def get_tokenizer_group( def get_tokenizer_group(
self, self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: group_type: Type[_G] = BaseTokenizerGroup,
if self.tokenizer is None: *,
raise ValueError(fail_msg) missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
raise ValueError(missing_msg)
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return self.tokenizer return tokenizer_group
def get_tokenizer( def get_tokenizer(
self, self,
...@@ -693,8 +704,8 @@ class LLMEngine: ...@@ -693,8 +704,8 @@ class LLMEngine:
* prompt token ids * prompt token ids
''' '''
tokenizer = self.get_tokenizer_group("prompts must be None if " tokenizer = self.get_tokenizer_group(
"skip_tokenizer_init is True") missing_msg="prompts must be None if skip_tokenizer_init is True")
return tokenizer.encode(request_id=request_id, return tokenizer.encode(request_id=request_id,
prompt=prompt, prompt=prompt,
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: List[Scheduler], scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",
): ):
"""Create an output processor. """Create an output processor.
......
import functools import functools
from typing import Callable, List from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
...@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams ...@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus) SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter from vllm.utils import Counter
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: List[Scheduler], scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: StopChecker, stop_checker: StopChecker,
): ):
self.detokenizer = detokenizer self.detokenizer = detokenizer
......
from typing import Callable, Optional from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker: class StopChecker:
...@@ -15,8 +14,7 @@ class StopChecker: ...@@ -15,8 +14,7 @@ class StopChecker:
""" """
def __init__(self, max_model_len: int, def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
PreTrainedTokenizer]):
# Do not use it directly, but use `self._get_max_model_len`. # Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq self.get_tokenizer_for_seq = get_tokenizer_for_seq
......
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable) runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs from vllm.inputs.data import PromptInputs
...@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams ...@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable @runtime_checkable
...@@ -40,6 +39,7 @@ class AsyncEngineClient(Protocol): ...@@ -40,6 +39,7 @@ class AsyncEngineClient(Protocol):
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request""" """Generates outputs for a request"""
...
def encode( def encode(
self, self,
...@@ -50,6 +50,7 @@ class AsyncEngineClient(Protocol): ...@@ -50,6 +50,7 @@ class AsyncEngineClient(Protocol):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.""" """Generate outputs for a request from an embedding model."""
...
async def abort(self, request_id: str) -> None: async def abort(self, request_id: str) -> None:
"""Abort a request. """Abort a request.
...@@ -60,25 +61,29 @@ class AsyncEngineClient(Protocol): ...@@ -60,25 +61,29 @@ class AsyncEngineClient(Protocol):
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
...
async def get_decoding_config(self) -> DecodingConfig:
    """Get the decoding configuration of the vLLM engine."""
    # The docstring must be the first statement in the body; placing ``...``
    # ahead of it turns the string into a plain no-op expression and leaves
    # ``__doc__`` empty, unlike the other protocol methods.
    ...
async def get_tokenizer( async def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer: ) -> AnyTokenizer:
"""Get the appropriate Tokenizer for the request""" """Get the appropriate tokenizer for the request"""
...
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
pass ...
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None, model_output: Optional[List[SamplerOutput]] = None,
) -> None: ) -> None:
pass ...
async def check_health(self) -> None: async def check_health(self) -> None:
"""Raise if unhealthy""" """Raise if unhealthy"""
...
...@@ -61,6 +61,7 @@ async def generate(request: Request) -> Response: ...@@ -61,6 +61,7 @@ async def generate(request: Request) -> Response:
async def stream_results() -> AsyncGenerator[bytes, None]: async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator: async for request_output in results_generator:
prompt = request_output.prompt prompt = request_output.prompt
assert prompt is not None
text_outputs = [ text_outputs = [
prompt + output.text for output in request_output.outputs prompt + output.text for output in request_output.outputs
] ]
...@@ -80,6 +81,7 @@ async def generate(request: Request) -> Response: ...@@ -80,6 +81,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None assert final_output is not None
prompt = final_output.prompt prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs] text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs} ret = {"text": text_outputs}
return JSONResponse(ret) return JSONResponse(ret)
...@@ -115,6 +117,7 @@ async def run_server(args: Namespace, ...@@ -115,6 +117,7 @@ async def run_server(args: Namespace,
logger.info("args: %s", args) logger.info("args: %s", args)
app = await init_app(args, llm_engine) app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
......
...@@ -3,7 +3,7 @@ from dataclasses import dataclass ...@@ -3,7 +3,7 @@ from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast) Union)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -15,9 +15,8 @@ from openai.types.chat import ( ...@@ -15,9 +15,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict from pydantic import ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer from typing_extensions import Required, TypeAlias, TypedDict
from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ...@@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part.""" """The type of the content part."""
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, ChatCompletionContentPartParam: TypeAlias = Union[
ChatCompletionContentPartAudioParam, OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam] CustomChatCompletionContentPartParam, ]
class CustomChatCompletionMessageParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False):
...@@ -114,7 +113,7 @@ def load_chat_template( ...@@ -114,7 +113,7 @@ def load_chat_template(
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer, def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]: modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt # TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template) # (similar to chat template)
...@@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str, ...@@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str,
return f"{placeholder_token_str}\n{text_prompt}" return f"{placeholder_token_str}\n{text_prompt}"
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
role: str, role: str,
parts: Iterable[ChatCompletionContentPartParam], parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> ChatMessageParseResult: ) -> ChatMessageParseResult:
texts: List[str] = [] texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = []
...@@ -164,7 +168,7 @@ def _parse_chat_message_content_parts( ...@@ -164,7 +168,7 @@ def _parse_chat_message_content_parts(
for part in parts: for part in parts:
part_type = part["type"] part_type = part["type"]
if part_type == "text": if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"] text = _TextParser.validate_python(part)["text"]
texts.append(text) texts.append(text)
elif part_type == "image_url": elif part_type == "image_url":
modality = "image" modality = "image"
...@@ -172,8 +176,7 @@ def _parse_chat_message_content_parts( ...@@ -172,8 +176,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError( raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.") "Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam, image_url = _ImageParser.validate_python(part)["image_url"]
part)["image_url"]
if image_url.get("detail", "auto") != "auto": if image_url.get("detail", "auto") != "auto":
logger.warning( logger.warning(
...@@ -188,8 +191,7 @@ def _parse_chat_message_content_parts( ...@@ -188,8 +191,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError( raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.") "Multiple multimodal inputs is currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam, audio_url = _AudioParser.validate_python(part)["audio_url"]
part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"]) audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future) mm_futures.append(audio_future)
else: else:
...@@ -219,7 +221,7 @@ def _parse_chat_message_content_parts( ...@@ -219,7 +221,7 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content( def _parse_chat_message_content(
message: ChatCompletionMessageParam, message: ChatCompletionMessageParam,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> ChatMessageParseResult: ) -> ChatMessageParseResult:
role = message["role"] role = message["role"]
content = message.get("content") content = message.get("content")
...@@ -230,14 +232,18 @@ def _parse_chat_message_content( ...@@ -230,14 +232,18 @@ def _parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)] messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[]) return ChatMessageParseResult(messages=messages, mm_futures=[])
return _parse_chat_message_content_parts(role, content, model_config, return _parse_chat_message_content_parts(
tokenizer) role,
content, # type: ignore
model_config,
tokenizer,
)
def parse_chat_messages( def parse_chat_messages(
messages: List[ChatCompletionMessageParam], messages: List[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = [] conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = []
......
from contextlib import contextmanager from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from tqdm.auto import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
...@@ -20,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput ...@@ -20,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs from vllm.utils import Counter, deprecate_kwargs
...@@ -122,7 +123,7 @@ class LLM: ...@@ -122,7 +123,7 @@ class LLM:
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
seed: int = 0, seed: int = 0,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: int = 4, swap_space: float = 4,
cpu_offload_gb: float = 0, cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None, enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
...@@ -175,22 +176,19 @@ class LLM: ...@@ -175,22 +176,19 @@ class LLM:
engine_args, usage_context=UsageContext.LLM_CLASS) engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter() self.request_counter = Counter()
def get_tokenizer( def get_tokenizer(self) -> AnyTokenizer:
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
return self.llm_engine.tokenizer.tokenizer
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
# While CachedTokenizer is dynamic, we have no choice but # While CachedTokenizer is dynamic, we have no choice but
# to compare class names. Misjudgment will arise from # to compare class names. Misjudgment will arise from
# user-defined tokenizers whose name starts with 'Cached' # user-defined tokenizers whose name starts with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"): if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer.tokenizer = tokenizer tokenizer_group.tokenizer = tokenizer
else: else:
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
tokenizer)
@overload # LEGACY: single (prompt + optional token ids) @overload # LEGACY: single (prompt + optional token ids)
def generate( def generate(
...@@ -578,6 +576,8 @@ class LLM: ...@@ -578,6 +576,8 @@ class LLM:
inputs: List[PromptInputs] = [] inputs: List[PromptInputs] = []
for i in range(num_requests): for i in range(num_requests):
item: PromptInputs
if prompts is not None: if prompts is not None:
item = TextPrompt(prompt=prompts[i]) item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None: elif prompt_token_ids is not None:
...@@ -635,7 +635,7 @@ class LLM: ...@@ -635,7 +635,7 @@ class LLM:
self, self,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
......
...@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError ...@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser ...@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
CompletionRequest, CompletionRequest,
CompletionResponse,
DetokenizeRequest, DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
EmbeddingRequest, ErrorResponse, EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
TokenizeRequest, TokenizeRequest,
TokenizeResponse) TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
...@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI): ...@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: async def build_async_engine_client(
args: Namespace) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle # Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
global engine_args global engine_args
...@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: ...@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
logger.info("Started engine process with PID %d", logger.info("Started engine process with PID %d",
rpc_server_process.pid) rpc_server_process.pid)
# Build RPCClient, which conforms to AsyncEngineClient Protocol. # Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(rpc_path) # NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
try: try:
while True: while True:
try: try:
await async_engine_client.setup() await rpc_client.setup()
break break
except TimeoutError as e: except TimeoutError as e:
if not rpc_server_process.is_alive(): if not rpc_server_process.is_alive():
...@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: ...@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
rpc_server_process.terminate() rpc_server_process.terminate()
# Close all open connections to the backend # Close all open connections to the backend
async_engine_client.close() rpc_client.close()
# Wait for server process to join # Wait for server process to join
rpc_server_process.join() rpc_server_process.join()
...@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest): ...@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
else: elif isinstance(generator, TokenizeResponse):
assert isinstance(generator, TokenizeResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize") @router.post("/detokenize")
async def detokenize(request: DetokenizeRequest): async def detokenize(request: DetokenizeRequest):
...@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest): ...@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
else: elif isinstance(generator, DetokenizeResponse):
assert isinstance(generator, DetokenizeResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models") @router.get("/v1/models")
async def show_available_models(): async def show_available_models():
...@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
if request.stream: elif isinstance(generator, ChatCompletionResponse):
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions") @router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request): async def create_completion(request: CompletionRequest, raw_request: Request):
...@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
if request.stream: elif isinstance(generator, CompletionResponse):
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings") @router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request):
...@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ...@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
else: elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
assert_never(generator)
def build_app(args: Namespace) -> FastAPI: def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
......
...@@ -7,6 +7,7 @@ purposes. ...@@ -7,6 +7,7 @@ purposes.
import argparse import argparse
import json import json
import ssl import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
...@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser ...@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action): class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(
lora_list = [] self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: List[LoRAModulePath] = []
for item in values: for item in values:
name, path = item.split('=') name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path)) lora_list.append(LoRAModulePath(name, path))
...@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action): ...@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):
class PromptAdapterParserAction(argparse.Action): class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(
adapter_list = [] self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: List[PromptAdapterPath] = []
for item in values: for item in values:
name, path = item.split('=') name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path)) adapter_list.append(PromptAdapterPath(name, path))
......
...@@ -2,9 +2,9 @@ from functools import lru_cache, partial ...@@ -2,9 +2,9 @@ from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch import torch
from transformers import PreTrainedTokenizer
from vllm.sampling_params import LogitsProcessor from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
class AllowedTokenIdsLogitsProcessor: class AllowedTokenIdsLogitsProcessor:
...@@ -53,8 +53,9 @@ def logit_bias_logits_processor( ...@@ -53,8 +53,9 @@ def logit_bias_logits_processor(
def get_logits_processors( def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]], allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]: tokenizer: AnyTokenizer,
logits_processors = [] ) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias: if logit_bias:
try: try:
# Convert token_id to integer # Convert token_id to integer
......
...@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union ...@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
...@@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors ...@@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
# torch is mocked during docs generation, # torch is mocked during docs generation,
# so we have to provide the values as literals # so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]
try: try:
from sphinx.ext.autodoc.mock import _MockModule from sphinx.ext.autodoc.mock import _MockModule
...@@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
def to_sampling_params( def to_sampling_params(
self, tokenizer: PreTrainedTokenizer, self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor], guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams: default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logprobs. # We now allow logprobs being true without top_logprobs.
logits_processors = get_logits_processors( logits_processors = get_logits_processors(
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
...@@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor: if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor) logits_processors.append(guided_decode_logits_processor)
return SamplingParams( return SamplingParams.from_optional(
n=self.n, n=self.n,
best_of=self.best_of, best_of=self.best_of,
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
...@@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None, logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else prompt_logprobs=prompt_logprobs,
(self.top_logprobs if self.echo else None),
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=max_tokens, max_tokens=max_tokens,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
...@@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
) )
@model_validator(mode='before') @model_validator(mode="before")
@classmethod @classmethod
def validate_stream_options(cls, values): def validate_stream_options(cls, data):
if (values.get('stream_options') is not None if data.get("stream_options") and not data.get("stream"):
and not values.get('stream')):
raise ValueError( raise ValueError(
"stream_options can only be set if stream is true") "Stream options can only be defined when `stream=True`.")
return values
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")
if not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
...@@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.") "When using `tool_choice`, `tools` must be set.")
return data return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a value a positive value.")
return data
class CompletionRequest(OpenAIBaseModel): class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
...@@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params # doc: end-completion-extra-params
def to_sampling_params( def to_sampling_params(
self, tokenizer: PreTrainedTokenizer, self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor], guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams: default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors( logits_processors = get_logits_processors(
...@@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor: if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor) logits_processors.append(guided_decode_logits_processor)
return SamplingParams( return SamplingParams.from_optional(
n=self.n, n=self.n,
best_of=self.best_of, best_of=self.best_of,
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
...@@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping, early_stopping=self.early_stopping,
prompt_logprobs=self.prompt_logprobs prompt_logprobs=prompt_logprobs,
if self.prompt_logprobs else self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
...@@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_logprobs(cls, data): def check_logprobs(cls, data):
if "logprobs" in data and data[ if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
"logprobs"] is not None and not data["logprobs"] >= 0: if data.get("stream") and prompt_logprobs > 0:
raise ValueError("if passed, `logprobs` must be a positive value.") raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")
return data return data
@model_validator(mode="before") @model_validator(mode="before")
...@@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data): def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError( raise ValueError(
"Stream options can only be defined when stream is true.") "Stream options can only be defined when `stream=True`.")
return data return data
...@@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel): ...@@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/embeddings # https://platform.openai.com/docs/api-reference/embeddings
model: str model: str
input: Union[List[int], List[List[int]], str, List[str]] input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None dimensions: Optional[int] = None
user: Optional[str] = None user: Optional[str] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment