Unverified Commit 7eb4a51c authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Support serving encoder/decoder models (#7258)

parent 0fa14907
...@@ -25,7 +25,7 @@ jobs: ...@@ -25,7 +25,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install mypy==1.9.0 pip install mypy==1.11.1
pip install types-setuptools pip install types-setuptools
pip install types-PyYAML pip install types-PyYAML
pip install types-requests pip install types-requests
......
...@@ -4,8 +4,8 @@ encoder/decoder models, specifically BART ...@@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
''' '''
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from vllm.utils import zip_enc_dec_prompt_lists TokensPrompt, zip_enc_dec_prompts)
dtype = "float" dtype = "float"
...@@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( ...@@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
) )
# - Finally, here's a useful helper function for zipping encoder and # - Finally, here's a useful helper function for zipping encoder and
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt # decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances # instances
zipped_prompt_list = zip_enc_dec_prompt_lists( zipped_prompt_list = zip_enc_dec_prompts(
['An encoder prompt', 'Another encoder prompt'], ['An encoder prompt', 'Another encoder prompt'],
['A decoder prompt', 'Another decoder prompt']) ['A decoder prompt', 'Another decoder prompt'])
......
...@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 ...@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3 lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq pyzmq
gguf == 0.9.1 gguf == 0.9.1
...@@ -8,7 +8,7 @@ isort==5.13.2 ...@@ -8,7 +8,7 @@ isort==5.13.2
clang-format==18.1.5 clang-format==18.1.5
# type checking # type checking
mypy==1.9.0 mypy==1.11.1
types-PyYAML types-PyYAML
types-requests types-requests
types-setuptools types-setuptools
...@@ -3,6 +3,7 @@ import gc ...@@ -3,6 +3,7 @@ import gc
import os import os
import sys import sys
from collections import UserList from collections import UserList
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
import pytest import pytest
...@@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, ...@@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding, AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
BatchFeature) BatchFeature)
from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment, from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel) destroy_model_parallel)
from vllm.inputs import TextPrompt from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu, to_enc_dec_tuple_list, is_cpu)
zip_enc_dec_prompt_lists)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -124,10 +124,16 @@ def example_prompts() -> List[str]: ...@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
return prompts return prompts
class DecoderPromptType(Enum):
"""For encoder/decoder models only."""
CUSTOM = 1
NONE = 2
EMPTY_STR = 3
@pytest.fixture @pytest.fixture
def example_encoder_decoder_prompts() \ def example_encoder_decoder_prompts(
-> Dict[DecoderPromptType, ) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
Tuple[List[str], List[Optional[str]]]]:
''' '''
Returns an encoder prompt list and a decoder prompt list, wherein each pair Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt, of same-index entries in both lists corresponds to an (encoder prompt,
...@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \ ...@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
# NONE decoder prompt type # NONE decoder prompt type
return { return {
DecoderPromptType.NONE: DecoderPromptType.NONE:
zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts), zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
DecoderPromptType.EMPTY_STR: DecoderPromptType.EMPTY_STR:
zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts), zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
DecoderPromptType.CUSTOM: DecoderPromptType.CUSTOM:
zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts), zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
} }
...@@ -444,7 +450,7 @@ class HfRunner: ...@@ -444,7 +450,7 @@ class HfRunner:
def generate_encoder_decoder_greedy_logprobs_limit( def generate_encoder_decoder_greedy_logprobs_limit(
self, self,
encoder_decoder_prompts: Tuple[List[str], List[str]], encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
**kwargs: Any, **kwargs: Any,
...@@ -608,7 +614,7 @@ class VllmRunner: ...@@ -608,7 +614,7 @@ class VllmRunner:
def generate_encoder_decoder_w_logprobs( def generate_encoder_decoder_w_logprobs(
self, self,
encoder_decoder_prompts: Tuple[List[str], List[str]], encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams, sampling_params: SamplingParams,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
''' '''
...@@ -653,7 +659,7 @@ class VllmRunner: ...@@ -653,7 +659,7 @@ class VllmRunner:
def generate_encoder_decoder_greedy_logprobs( def generate_encoder_decoder_greedy_logprobs(
self, self,
encoder_decoder_prompts: Tuple[List[str], List[str]], encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
......
...@@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py ...@@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
import pytest import pytest
from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test from ..utils import fork_new_process_for_each_test
......
import openai
import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "facebook/bart-base"
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--enforce-eager",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=2, total_tokens=7)
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 1
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
Run `pytest tests/models/test_bart.py`. Run `pytest tests/models/test_bart.py`.
""" """
from typing import List, Optional, Tuple
from vllm.utils import is_cpu from vllm.utils import is_cpu
if not is_cpu(): if not is_cpu():
...@@ -11,22 +13,31 @@ if not is_cpu(): ...@@ -11,22 +13,31 @@ if not is_cpu():
import pytest import pytest
from tests.models.utils import DecoderPromptType from vllm.sequence import SampleLogprobs
from ..conftest import DecoderPromptType
from .utils import check_logprobs_close from .utils import check_logprobs_close
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
DECODER_PROMPT_TYPES = ([ def vllm_to_hf_output(
DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR, vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
DecoderPromptType.NONE decoder_prompt_type: DecoderPromptType,
]) ):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "</s>"
if decoder_prompt_type == DecoderPromptType.NONE:
hf_output_str = "<s>" + hf_output_str
return output_ids, hf_output_str, out_logprobs
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models( def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
...@@ -146,8 +157,13 @@ if not is_cpu(): ...@@ -146,8 +157,13 @@ if not is_cpu():
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0) else 0)
check_logprobs_close(outputs_0_lst=hf_outputs, check_logprobs_close(
outputs_1_lst=vllm_outputs, outputs_0_lst=hf_outputs,
name_0="hf", outputs_1_lst=[
name_1="vllm", vllm_to_hf_output(vllm_output, decoder_prompt_type)
num_outputs_0_skip_tokens=hf_skip_tokens) for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
import warnings import warnings
from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
...@@ -136,13 +135,3 @@ def check_logprobs_close( ...@@ -136,13 +135,3 @@ def check_logprobs_close(
warnings.simplefilter("always") warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2) warnings.warn(fail_msg, stacklevel=2)
class DecoderPromptType(Enum):
'''
For encoder/decoder models only -
'''
CUSTOM = 1
NONE = 2
EMPTY_STR = 3
...@@ -2,7 +2,7 @@ from typing import List ...@@ -2,7 +2,7 @@ from typing import List
import pytest import pytest
from vllm.inputs import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
STRING_INPUTS = [ STRING_INPUTS = [
'', '',
......
...@@ -464,6 +464,16 @@ class ModelConfig: ...@@ -464,6 +464,16 @@ class ModelConfig:
if t != "attention" if t != "attention"
]) ])
@property
def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
return getattr(self.hf_config, "is_encoder_decoder", False)
@property
def is_embedding_model(self) -> bool:
"""Extract the embedding model flag."""
return self.embedding_mode
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache. """Configuration for the KV cache.
......
...@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, ...@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union) Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
...@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ...@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
from vllm.engine.metrics import StatLoggerBase from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
...@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async( async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
async def _extract_prompt_components_async(
self, self,
inputs: SingletonPromptInputs,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)
return prompt, prompt_token_ids, multi_modal_data
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs, inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
decoder_task = self._extract_prompt_components_async(
decoder_input,
request_id=request_id,
)
encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
if isinstance(inputs, str): """Async version of :meth:`_process_decoder_only_prompt`."""
inputs = {"prompt": inputs} prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
if "prompt_token_ids" not in inputs: return self._build_decoder_only_llm_inputs(
tokenizer = self.get_tokenizer_group("prompts must be None if " prompt_comps,
"skip_tokenizer_init is True") prompt_adapter_request=prompt_adapter_request,
)
prompt_token_ids = await tokenizer.encode_async( async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id, request_id=request_id,
prompt=inputs["prompt"], )
lora_request=lora_request)
else: else:
prompt_token_ids = inputs["prompt_token_ids"] if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
if prompt_adapter_request: # Decoder-only operation
prompt_token_ids = [ model_inputs = await self._process_decoder_only_prompt_async(
0 inputs,
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ request_id=request_id,
prompt_token_ids lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, )
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs) return self.input_processor(model_inputs)
async def add_request_async( async def add_request_async(
self, self,
...@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
"""Async version of :meth:`add_request`."""
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
...@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time() arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async( processed_inputs = await self.process_model_inputs_async(
inputs,
request_id=request_id, request_id=request_id,
inputs=inputs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
......
...@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, ...@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union from typing import Set, Tuple, Type, TypeVar, Union
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
...@@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker ...@@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs, from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
get_prompt_type) PromptInputs, SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import ( ...@@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import (Counter, is_embedding_model_config, from vllm.utils import Counter
is_encoder_decoder_model_config)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: ...@@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
...@@ -524,7 +532,7 @@ class LLMEngine: ...@@ -524,7 +532,7 @@ class LLMEngine:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _get_decoder_start_token_id(self, ) -> Optional[int]: def _get_decoder_start_token_id(self) -> Optional[int]:
''' '''
Obtain the decoder start token id employed by an encoder/decoder Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the model. Returns None for non-encoder/decoder models or if the
...@@ -553,7 +561,7 @@ class LLMEngine: ...@@ -553,7 +561,7 @@ class LLMEngine:
def _add_processed_request( def _add_processed_request(
self, self,
request_id: str, request_id: str,
processed_inputs: LLMInputs, processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
...@@ -613,11 +621,11 @@ class LLMEngine: ...@@ -613,11 +621,11 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
_LLMInputComponentsType = Tuple[str, List[int], ] _LLMInputComponentsType = Tuple[str, List[int]]
def _prepare_decoder_input_ids_for_generation( def _prepare_decoder_input_ids_for_generation(
self, self,
decoder_input_ids: Optional[List[int]] = None, decoder_input_ids: Optional[List[int]],
) -> List[int]: ) -> List[int]:
""" """
Prepares `decoder_input_ids` for generation with encoder-decoder models. Prepares `decoder_input_ids` for generation with encoder-decoder models.
...@@ -639,14 +647,13 @@ class LLMEngine: ...@@ -639,14 +647,13 @@ class LLMEngine:
* Processed token list * Processed token list
""" """
decoder_start_token_id: Optional[int] = ( decoder_start_token_id = self._get_decoder_start_token_id()
self._get_decoder_start_token_id())
assert decoder_start_token_id is not None assert decoder_start_token_id is not None
if decoder_input_ids is None: if decoder_input_ids is None:
# no decoder prompt input -> # no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids # use decoder_start_token_id as decoder_input_ids
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt() decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0 if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id): or decoder_input_ids[0] != decoder_start_token_id):
...@@ -657,12 +664,11 @@ class LLMEngine: ...@@ -657,12 +664,11 @@ class LLMEngine:
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: str,
lora_request: Optional[str] = None, lora_request: Optional[LoRARequest],
) -> List[int]: ) -> List[int]:
''' '''
Wrapper around application of the model's Wrapper around application of the model's tokenizer.
tokenizer.
Arguments: Arguments:
...@@ -678,87 +684,72 @@ class LLMEngine: ...@@ -678,87 +684,72 @@ class LLMEngine:
tokenizer = self.get_tokenizer_group("prompts must be None if " tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id, return tokenizer.encode(request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request) lora_request=lora_request)
return prompt_token_ids
def _extract_single_prompt_for_enc_dec_input( def _extract_prompt_components(
self, self,
inputs: Optional[PromptInputs], inputs: SingletonPromptInputs,
request_id: Optional[str] = None, request_id: str,
ptype: Optional[str] = None, lora_request: Optional[LoRARequest] = None,
is_encoder_prompt: bool = False, ) -> PromptComponents:
) -> Tuple[Optional[str], List[int]]:
''' '''
Only for encoder/decoder models: Extract the components of any single encoder or decoder input prompt.
Extract prompt & prompt_token_ids from any single
encoder or decoder input prompt. For encoder input prompts
in particular, also extract multi-modal data.
This function handles the following scenarios:
1. The user supplied a singleton encoder prompt
& the prompt/prompt-token-ids must be extracted.
2. The user supplied an explicit encoder/decoder
prompt & the prompt/prompt-token-ids must be
extracted from either the encoder and decoder prompts.
For decoder prompts in particular (scenario 2), special
processing is applied to the returned decoder token ids.
Arguments: Arguments:
* request_id * request_id
* ptype: str representation of the input prompt type.
If `ptype` is `None`, assume that the prompt
type is unknown and must be inferred. This is the
case for ExplicitEncoderDecoder sub-prompts.
* inputs: single encoder or decoder input prompt * inputs: single encoder or decoder input prompt
* is_encoder_prompt: True if encoder input prompt. * lora_request: this is only valid for decoder prompts
If False, decoder prompt tokens
are preprocessed.
Returns: Returns:
* prompt * prompt
* prompt_token_ids * prompt_token_ids
* multi_modal_data
''' '''
prompt_token_ids = None
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
if inputs is None: if isinstance(inputs, str):
prompt = None
elif ptype == 'str':
prompt = inputs prompt = inputs
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request,
) )
elif ptype == 'TokensPrompt': multi_modal_data = None
prompt = None elif isinstance(inputs, dict):
prompt_token_ids = inputs['prompt_token_ids'] if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = self._tokenize_prompt(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else: else:
prompt = inputs['prompt'] assert_never(inputs)
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
)
if not is_encoder_prompt: return prompt, prompt_token_ids, multi_modal_data
# Apply special pre-processing to
# decoder prompts
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
prompt_token_ids, ))
assert prompt_token_ids is not None def _apply_prompt_adapter(
self,
prompt_token_ids: List[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return ( return prompt_token_ids
prompt,
prompt_token_ids,
)
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]: def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
''' '''
Specifically for encoder/decoder models: Specifically for encoder/decoder models:
generate a default decoder prompt for when generate a default decoder prompt for when
...@@ -792,18 +783,39 @@ class LLMEngine: ...@@ -792,18 +783,39 @@ class LLMEngine:
bos_token_id = self._get_bos_token_id() bos_token_id = self._get_bos_token_id()
assert bos_token_id is not None assert bos_token_id is not None
prompt_token_ids: List[int] = [bos_token_id] return [bos_token_id]
return prompt_token_ids
def _build_enc_dec_llm_inputs(
self,
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
)
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
request_id: Optional[str] = None, request_id: str,
) -> LLMInputs: ) -> EncoderDecoderLLMInputs:
''' '''
For encoder/decoder models only: For encoder/decoder models only:
Process an input prompt Process an input prompt into an
into an `LLMInputs` instance. :class:`EncoderDecoderLLMInputs` instance.
There are two types of input prompts: There are two types of input prompts:
singleton prompts which carry only the singleton prompts which carry only the
...@@ -830,136 +842,103 @@ class LLMEngine: ...@@ -830,136 +842,103 @@ class LLMEngine:
Returns: Returns:
* `LLMInputs` instance * :class:`EncoderDecoderLLMInputs` instance
''' '''
ptype = get_prompt_type(inputs) encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
# Obtain encoder and decoder prompt tokens. Note
# that, no matter what, the decoder if is_explicit_encoder_decoder_prompt(inputs):
# prompt type is unknown. encoder_comps = self._extract_prompt_components(
if ptype == "ExplicitEncoderDecoder": inputs["encoder_prompt"],
# If input is explicit encoder/decoder prompt, request_id=request_id,
# then it remains to be determined what type )
# of encoder prompt we have
extracted_encoder_prompt = inputs.get('encoder_prompt') if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_ptype = None decoder_comps = None, None, None
# Extract decoder prompt from explicit else:
# encoder/decoder prompt decoder_comps = self._extract_prompt_components(
extracted_decoder_prompt = inputs.get('decoder_prompt') decoder_input,
request_id=request_id,
)
else: else:
# If input is singleton encoder prompt, then encoder_comps = self._extract_prompt_components(
# we know the encoder prompt type inputs,
extracted_encoder_prompt = inputs request_id=request_id,
encoder_ptype = ptype )
# Decoder prompt is always unknown if
# encoder/decoder prompt is not explicit
extracted_decoder_prompt = None
# Invoke helper function to obtain encoder
# prompt and prompt token ids, either from
# singleton encoder prompt or from the
# encoder sub-prompt of an explicit
# encoder/decode scenario 2), special
# processing is applied to the returned decoder token ids
(
encoder_prompt,
encoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_encoder_prompt,
request_id=request_id,
ptype=encoder_ptype,
is_encoder_prompt=True,
)
# Invoke helper method to obtain decoder_comps = None, None, None
# decoder prompt and prompt token ids.
#
# The helper method will detect the decoder
# prompt type.
#
# Helper method will also apply special
# preprocessing unique to decoder prompts.
(
decoder_prompt,
decoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_decoder_prompt,
request_id=request_id,
ptype=None,
is_encoder_prompt=False,
)
return LLMInputs( return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
prompt_token_ids=decoder_prompt_token_ids,
prompt=decoder_prompt, def _build_decoder_only_llm_inputs(
encoder_prompt_token_ids=encoder_prompt_token_ids, self,
encoder_prompt=encoder_prompt, prompt_comps: PromptComponents,
) prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
prompt, prompt_token_ids, multi_modal_data = prompt_comps
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data)
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
inputs: PromptInputs, inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
request_id: Optional[str] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
''' '''
For decoder-only models: For decoder-only models:
Process an input prompt Process an input prompt into an :class:`LLMInputs` instance.
into an `LLMInputs` instance.
Arguments: Arguments:
* inputs: input prompt * inputs: input prompt
* lora_request
* request_id * request_id
* lora_request
* prompt_adapter_request * prompt_adapter_request
Returns: Returns:
* `LLMInputs` instance * :class:`LLMInputs` instance
''' '''
if isinstance(inputs, str): prompt_comps = self._extract_prompt_components(
inputs = {"prompt": inputs} inputs,
prompt = inputs.get("prompt") request_id=request_id,
lora_request=lora_request,
if "prompt_token_ids" not in inputs: )
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return LLMInputs(prompt_token_ids=prompt_token_ids, return self._build_decoder_only_llm_inputs(
prompt=prompt, prompt_comps,
multi_modal_data=inputs.get("multi_modal_data")) prompt_adapter_request=prompt_adapter_request,
)
def process_model_inputs( def process_model_inputs(
self, self,
request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
if self.is_encoder_decoder_model(): if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
model_inputs = self._process_encoder_decoder_prompt( model_inputs = self._process_encoder_decoder_prompt(
inputs, inputs,
request_id=request_id, request_id=request_id,
) )
else: else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation # Decoder-only operation
model_inputs = self._process_decoder_only_prompt( model_inputs = self._process_decoder_only_prompt(
inputs, inputs,
...@@ -1029,10 +1008,11 @@ class LLMEngine: ...@@ -1029,10 +1008,11 @@ class LLMEngine:
arrival_time = time.time() arrival_time = time.time()
processed_inputs = self.process_model_inputs( processed_inputs = self.process_model_inputs(
inputs,
request_id=request_id, request_id=request_id,
inputs=inputs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -1597,7 +1577,7 @@ class LLMEngine: ...@@ -1597,7 +1577,7 @@ class LLMEngine:
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
def is_encoder_decoder_model(self): def is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config) return self.model_config.is_encoder_decoder_model
def is_embedding_model(self): def is_embedding_model(self):
return is_embedding_model_config(self.model_config) return self.model_config.is_embedding_model
...@@ -2,8 +2,7 @@ import codecs ...@@ -2,8 +2,7 @@ import codecs
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union, from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
cast, final)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, ...@@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam] CustomChatCompletionMessageParam]
@final # So that it should be compatible with Dict[str, str] # TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict): class ConversationMessage(TypedDict):
role: str role: str
content: str content: str
......
...@@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast ...@@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
parse_and_batch_prompt) from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
......
...@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor( ...@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return AllowedTokenIdsLogitsProcessor(allowed_token_ids) return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[str, def logit_bias_logits_processor(
float], token_ids: List[int], logit_bias: Dict[int, float],
logits: torch.Tensor) -> torch.Tensor: token_ids: List[int],
logits: torch.Tensor,
) -> torch.Tensor:
for token_id, bias in logit_bias.items(): for token_id, bias in logit_bias.items():
logits[token_id] += bias logits[token_id] += bias
return logits return logits
......
...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeRequest) TokenizeRequest)
# yapf: enable # yapf: enable
from vllm.inputs import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
......
from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText, from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
ParsedTokens, PromptInputs, SingletonPromptInputs, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TextPrompt, TokensPrompt, get_prompt_type, TokensPrompt, build_explicit_enc_dec_prompt,
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt) to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -14,18 +14,17 @@ See also: ...@@ -14,18 +14,17 @@ See also:
""" """
__all__ = [ __all__ = [
"ParsedText",
"ParsedTokens",
"parse_and_batch_prompt",
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"PromptInputs", "PromptInputs",
"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs", "LLMInputs",
"EncoderDecoderLLMInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY", "INPUT_REGISTRY",
"InputContext", "InputContext",
"InputRegistry", "InputRegistry",
"get_prompt_type",
"is_valid_encoder_decoder_llm_inputs",
"ExplicitEncoderDecoderPrompt",
"SingletonPromptInputs",
] ]
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
TypedDict, Union, cast, overload) Union)
from typing_extensions import NotRequired from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict): class TextPrompt(TypedDict):
"""Schema for a text prompt.""" """Schema for a text prompt."""
...@@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure ...@@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type SingletonPromptInputs may be employed A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating (3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
""" """
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
class ExplicitEncoderDecoderPrompt(TypedDict):
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt, """Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a comprising an explicit encoder prompt and a
decoder prompt. decoder prompt.
The encoder and decoder prompts, respectively, The encoder and decoder prompts, respectively,
may formatted according to any of the may formatted according to any of the
SingletonPromptInputs schemas, and are not :class:`SingletonPromptInputs` schemas, and are not
required to have the same schema. required to have the same schema.
Only the encoder prompt may have multi-modal data. Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model, be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt` and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure may not themselves fields of this data structure themselves must be
must be SingletonPromptInputs instances. :class:`SingletonPromptInputs` instances.
""" """
encoder_prompt: SingletonPromptInputs encoder_prompt: _T1_co
decoder_prompt: SingletonPromptInputs decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
...@@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types: ...@@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
""" """
def _has_required_keys(
d: dict,
required_keys: set,
) -> bool:
return required_keys.issubset(d.keys())
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
"""
Get the type-name of the prompt argument instance, given that
isinstance() cannot apply to TypedDict subclasses directly.
If the prompt is None, return 'None' as the type name.
Arguments:
* prompt: LLM input prompt or None
Returns:
* String representation of prompt type
"""
if prompt is None:
return 'None'
required_keys_dict = {
'TextPrompt': {'prompt'},
'TokensPrompt': {'prompt_token_ids'},
'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
}
if isinstance(prompt, dict):
for (ptype, required_keys) in required_keys_dict.items():
# Ignore type checking in the conditional below because type
# checker does not understand that is_dict(prompt) narrows
# down the possible types
if _has_required_keys(
prompt, # type: ignore
required_keys):
return ptype
raise ValueError(f"Invalid prompt {prompt}, valid types are "
"required_keys_dict={required_keys_dict}")
if isinstance(prompt, str):
return "str"
raise ValueError(f"Invalid prompt {prompt}")
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor. passed to the model executor.
This specifies the data required for decoder-only models.
""" """
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
...@@ -213,7 +116,21 @@ class LLMInputs(TypedDict): ...@@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available. The original prompt text corresponding to the token IDs, if available.
""" """
encoder_prompt_token_ids: NotRequired[List[int]] multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class EncoderDecoderLLMInputs(LLMInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt.""" """The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]] encoder_prompt: NotRequired[Optional[str]]
...@@ -222,20 +139,40 @@ class LLMInputs(TypedDict): ...@@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
available. available.
""" """
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
""" _T1 = TypeVar("_T1",
Optional multi-modal data to pass to the model, bound=SingletonPromptInputs,
if the model supports it. default=SingletonPromptInputs)
""" _T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool: def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: Optional[_T2],
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[Optional[_T2]],
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
""" """
Return True if the LLMInputs instance has the correct configuration Zip encoder and decoder prompts together into a list of
for encoder/decoder. :class:`ExplicitEncoderDecoderPrompt` instances.
""" """
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
]
# True if encoder prompt token ids field exists & def to_enc_dec_tuple_list(
# is not None enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
return ('encoder_prompt_token_ids' in inputs ) -> List[Tuple[_T1, Optional[_T2]]]:
and inputs['encoder_prompt_token_ids'] is not None) return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs)
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt, str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False) for elem in prompt
]
if is_list_of(prompt, int):
# case 3: array of tokens
return [ParsedTokens(content=prompt, is_tokens=True)]
if is_list_of(prompt, list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if is_list_of(prompt[0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in prompt
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
return "encoder_prompt_token_ids" in inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment