Unverified commit 7eb4a51c, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Support serving encoder/decoder models (#7258)

parent 0fa14907
......@@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.9.0
pip install mypy==1.11.1
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
......
......@@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
'''
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
TokensPrompt, zip_enc_dec_prompts)
dtype = "float"
......@@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
)
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
zipped_prompt_list = zip_enc_dec_prompts(
['An encoder prompt', 'Another encoder prompt'],
['A decoder prompt', 'Another decoder prompt'])
......
......@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
gguf == 0.9.1
......@@ -8,7 +8,7 @@ isort==5.13.2
clang-format==18.1.5
# type checking
mypy==1.9.0
mypy==1.11.1
types-PyYAML
types-requests
types-setuptools
......@@ -3,6 +3,7 @@ import gc
import os
import sys
from collections import UserList
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
import pytest
......@@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
BatchFeature)
from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu, to_enc_dec_tuple_list,
zip_enc_dec_prompt_lists)
is_cpu)
logger = init_logger(__name__)
......@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
return prompts
class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Selects which kind of decoder prompt accompanies each encoder prompt in
    the encoder/decoder test fixtures.
    """
    # Key for the fixture entry built from the custom decoder prompts.
    CUSTOM = 1
    # Key for the fixture entry built from the "none" decoder prompts
    # (presumably: no decoder prompt supplied -- confirm against the fixture).
    NONE = 2
    # Key for the fixture entry built from the empty-string decoder prompts.
    EMPTY_STR = 3
@pytest.fixture
def example_encoder_decoder_prompts() \
-> Dict[DecoderPromptType,
Tuple[List[str], List[Optional[str]]]]:
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
'''
Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt,
......@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
# NONE decoder prompt type
return {
DecoderPromptType.NONE:
zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
DecoderPromptType.EMPTY_STR:
zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
DecoderPromptType.CUSTOM:
zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
}
......@@ -444,7 +450,7 @@ class HfRunner:
def generate_encoder_decoder_greedy_logprobs_limit(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
......@@ -608,7 +614,7 @@ class VllmRunner:
def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
'''
......@@ -653,7 +659,7 @@ class VllmRunner:
def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: Tuple[List[str], List[str]],
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
......
......@@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
import pytest
from tests.models.utils import DecoderPromptType
from vllm.utils import cuda_device_count_stateless
from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close
from ..utils import fork_new_process_for_each_test
......
import openai
import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "facebook/bart-base"
@pytest.fixture(scope="module")
def server():
    """Start one ``RemoteOpenAIServer`` for ``MODEL_NAME`` per test module.

    The server is launched with ``--dtype bfloat16 --enforce-eager`` and is
    yielded from inside its context manager, so teardown is handled when the
    ``with`` block exits after the fixture is finalized.
    """
    cli_args = ["--dtype", "bfloat16", "--enforce-eager"]

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
@pytest.fixture(scope="module")
def client(server):
    """Return the async client obtained from the module-scoped ``server``."""
    async_client = server.get_async_client()
    return async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
    """Check the completions endpoint with a text prompt and a token-ID prompt.

    The text request must yield exactly one choice that stops on the length
    limit with the expected usage accounting; the token-ID request only needs
    to produce some text.
    """
    # --- text prompt ---
    text_completion = await client.completions.create(
        model=model_name,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )

    assert text_completion.id is not None
    assert (text_completion.choices is not None
            and len(text_completion.choices) == 1)

    first_choice = text_completion.choices[0]
    assert len(first_choice.text) >= 5
    assert first_choice.finish_reason == "length"
    assert text_completion.usage == openai.types.CompletionUsage(
        completion_tokens=5,
        prompt_tokens=2,
        total_tokens=7,
    )

    # --- token-ID prompt ---
    token_completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(token_completion.choices[0].text) >= 1
......@@ -2,6 +2,8 @@
Run `pytest tests/models/test_bart.py`.
"""
from typing import List, Optional, Tuple
from vllm.utils import is_cpu
if not is_cpu():
......@@ -11,22 +13,31 @@ if not is_cpu():
import pytest
from tests.models.utils import DecoderPromptType
from vllm.sequence import SampleLogprobs
from ..conftest import DecoderPromptType
from .utils import check_logprobs_close
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
DECODER_PROMPT_TYPES = ([
DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
DecoderPromptType.NONE
])
def vllm_to_hf_output(
    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
    decoder_prompt_type: DecoderPromptType,
):
    """Rewrite a vLLM output tuple so its text can be compared with HF output.

    The text always gets a trailing ``</s>``; when the decoder prompt type is
    ``DecoderPromptType.NONE`` it additionally gets a leading ``<s>``. Token
    ids and logprobs are passed through unchanged.
    """
    token_ids, text, logprobs = vllm_output
    is_none_prompt = decoder_prompt_type == DecoderPromptType.NONE
    leading = "<s>" if is_none_prompt else ""
    return token_ids, leading + text + "</s>", logprobs
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models(
hf_runner,
vllm_runner,
......@@ -146,8 +157,13 @@ if not is_cpu():
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0)
check_logprobs_close(outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
import warnings
from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs
......@@ -136,13 +135,3 @@ def check_logprobs_close(
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
class DecoderPromptType(Enum):
'''
For encoder/decoder models only -
'''
CUSTOM = 1
NONE = 2
EMPTY_STR = 3
......@@ -2,7 +2,7 @@ from typing import List
import pytest
from vllm.inputs import parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
STRING_INPUTS = [
'',
......
......@@ -464,6 +464,16 @@ class ModelConfig:
if t != "attention"
])
@property
def is_encoder_decoder_model(self) -> bool:
    """Report the HF config's encoder/decoder flag.

    Returns ``False`` when ``hf_config`` carries no ``is_encoder_decoder``
    attribute.
    """
    hf_config = self.hf_config
    return getattr(hf_config, "is_encoder_decoder", False)
@property
def is_embedding_model(self) -> bool:
    """Extract the embedding model flag.

    Returns the value of ``self.embedding_mode`` unchanged.
    """
    return self.embedding_mode
class CacheConfig:
"""Configuration for the KV cache.
......
......@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
......@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents)
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
......@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async(
async def _tokenize_prompt_async(
    self,
    prompt: str,
    request_id: str,
    lora_request: Optional[LoRARequest],
) -> List[int]:
    """Async version of :meth:`_tokenize_prompt`.

    Fetches the tokenizer group and awaits its ``encode_async`` on the given
    prompt, forwarding ``request_id`` and ``lora_request``.
    """
    tokenizer_group = self.get_tokenizer_group("prompts must be None if "
                                               "skip_tokenizer_init is True")
    token_ids = await tokenizer_group.encode_async(
        request_id=request_id,
        prompt=prompt,
        lora_request=lora_request,
    )
    return token_ids
async def _extract_prompt_components_async(
    self,
    inputs: SingletonPromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
    """Async version of :meth:`_extract_prompt_components`.

    Returns a ``(prompt, prompt_token_ids, multi_modal_data)`` tuple:

    * a plain ``str`` is tokenized and carries no multi-modal data;
    * a dict with ``"prompt_token_ids"`` is used as-is, with no prompt text;
    * any other dict has its ``"prompt"`` entry tokenized.

    For dict inputs, ``multi_modal_data`` is whatever
    ``inputs.get("multi_modal_data")`` yields.
    """
    # Bare string: the text itself is the prompt.
    if isinstance(inputs, str):
        token_ids = await self._tokenize_prompt_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return inputs, token_ids, None

    if isinstance(inputs, dict):
        # Pre-tokenized prompt: nothing to tokenize, no prompt text.
        if "prompt_token_ids" in inputs:
            token_ids = inputs["prompt_token_ids"]
            return None, token_ids, inputs.get("multi_modal_data")

        # Text prompt wrapped in a dict.
        text = inputs["prompt"]
        token_ids = await self._tokenize_prompt_async(
            text,
            request_id=request_id,
            lora_request=lora_request,
        )
        return text, token_ids, inputs.get("multi_modal_data")

    # Neither a str nor a dict: not a supported singleton prompt.
    assert_never(inputs)
async def _process_encoder_decoder_prompt_async(
    self,
    inputs: PromptInputs,
    request_id: str,
) -> EncoderDecoderLLMInputs:
    """Async version of :meth:`_process_encoder_decoder_prompt`.

    An explicit encoder/decoder prompt has its encoder and (if present)
    decoder sub-prompts extracted; the two extractions are awaited together
    via ``asyncio.gather``. Any other input is treated as the encoder prompt
    with no decoder prompt. The resulting components are handed to
    :meth:`_build_enc_dec_llm_inputs`.
    """
    encoder_comps: PromptComponents
    decoder_comps: DecoderPromptComponents

    if is_explicit_encoder_decoder_prompt(inputs):
        encoder_coro = self._extract_prompt_components_async(
            inputs["encoder_prompt"],
            request_id=request_id,
        )
        decoder_input = inputs["decoder_prompt"]

        if decoder_input is not None:
            decoder_coro = self._extract_prompt_components_async(
                decoder_input,
                request_id=request_id,
            )
            encoder_comps, decoder_comps = await asyncio.gather(
                encoder_coro, decoder_coro)
        else:
            # No decoder prompt given: only the encoder side needs work.
            encoder_comps = await encoder_coro
            decoder_comps = None, None, None
    else:
        # Singleton prompt: it is the encoder prompt.
        encoder_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
        )
        decoder_comps = None, None, None

    return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
prompt_token_ids = await tokenizer.encode_async(
async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
# Decoder-only operation
model_inputs = await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
return self.input_processor(llm_inputs)
return self.input_processor(model_inputs)
async def add_request_async(
self,
......@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Async version of :meth:`add_request`."""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
......@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async(
inputs,
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
request_id=request_id,
......
......@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
......@@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
get_prompt_type)
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
PromptInputs, SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
......@@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (Counter, is_embedding_model_config,
is_encoder_decoder_model_config)
from vllm.utils import Counter
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
......@@ -524,7 +532,7 @@ class LLMEngine:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _get_decoder_start_token_id(self, ) -> Optional[int]:
def _get_decoder_start_token_id(self) -> Optional[int]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
......@@ -553,7 +561,7 @@ class LLMEngine:
def _add_processed_request(
self,
request_id: str,
processed_inputs: LLMInputs,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
......@@ -613,11 +621,11 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
_LLMInputComponentsType = Tuple[str, List[int], ]
_LLMInputComponentsType = Tuple[str, List[int]]
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]] = None,
decoder_input_ids: Optional[List[int]],
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
......@@ -639,14 +647,13 @@ class LLMEngine:
* Processed token list
"""
decoder_start_token_id: Optional[int] = (
self._get_decoder_start_token_id())
decoder_start_token_id = self._get_decoder_start_token_id()
assert decoder_start_token_id is not None
if decoder_input_ids is None:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
......@@ -657,12 +664,11 @@ class LLMEngine:
def _tokenize_prompt(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[str] = None,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
'''
Wrapper around application of the model's
tokenizer.
Wrapper around application of the model's tokenizer.
Arguments:
......@@ -678,87 +684,72 @@ class LLMEngine:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
return tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
def _extract_single_prompt_for_enc_dec_input(
def _extract_prompt_components(
self,
inputs: Optional[PromptInputs],
request_id: Optional[str] = None,
ptype: Optional[str] = None,
is_encoder_prompt: bool = False,
) -> Tuple[Optional[str], List[int]]:
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
'''
Only for encoder/decoder models:
Extract prompt & prompt_token_ids from any single
encoder or decoder input prompt. For encoder input prompts
in particular, also extract multi-modal data.
This function handles the following scenarios:
1. The user supplied a singleton encoder prompt
& the prompt/prompt-token-ids must be extracted.
2. The user supplied an explicit encoder/decoder
prompt & the prompt/prompt-token-ids must be
extracted from either the encoder and decoder prompts.
For decoder prompts in particular (scenario 2), special
processing is applied to the returned decoder token ids.
Extract the components of any single encoder or decoder input prompt.
Arguments:
* request_id
* ptype: str representation of the input prompt type.
If `ptype` is `None`, assume that the prompt
type is unknown and must be inferred. This is the
case for ExplicitEncoderDecoder sub-prompts.
* inputs: single encoder or decoder input prompt
* is_encoder_prompt: True if encoder input prompt.
If False, decoder prompt tokens
are preprocessed.
* lora_request: this is only valid for decoder prompts
Returns:
* prompt
* prompt_token_ids
* multi_modal_data
'''
prompt_token_ids = None
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
if inputs is None:
prompt = None
elif ptype == 'str':
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
elif ptype == 'TokensPrompt':
prompt = None
prompt_token_ids = inputs['prompt_token_ids']
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = self._tokenize_prompt(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
prompt = inputs['prompt']
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
)
assert_never(inputs)
if not is_encoder_prompt:
# Apply special pre-processing to
# decoder prompts
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
prompt_token_ids, ))
return prompt, prompt_token_ids, multi_modal_data
assert prompt_token_ids is not None
def _apply_prompt_adapter(
self,
prompt_token_ids: List[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return (
prompt,
prompt_token_ids,
)
return prompt_token_ids
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
......@@ -792,18 +783,39 @@ class LLMEngine:
bos_token_id = self._get_bos_token_id()
assert bos_token_id is not None
prompt_token_ids: List[int] = [bos_token_id]
return prompt_token_ids
return [bos_token_id]
def _build_enc_dec_llm_inputs(
    self,
    encoder_comps: PromptComponents,
    decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
    """Combine encoder and decoder prompt components into one input object.

    Raises:
        ValueError: if either side carries multi-modal data.

    The decoder token ids are passed through
    :meth:`_prepare_decoder_input_ids_for_generation` before being stored.
    """
    enc_text, enc_token_ids, enc_mm_data = encoder_comps
    dec_text, dec_token_ids, dec_mm_data = decoder_comps

    no_mm_data = enc_mm_data is None and dec_mm_data is None
    if not no_mm_data:
        raise ValueError("Multi-modal encoder-decoder models are "
                         "not supported yet")

    prepared_dec_token_ids = (
        self._prepare_decoder_input_ids_for_generation(dec_token_ids))

    return EncoderDecoderLLMInputs(
        prompt_token_ids=prepared_dec_token_ids,
        prompt=dec_text,
        encoder_prompt_token_ids=enc_token_ids,
        encoder_prompt=enc_text,
    )
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
request_id: Optional[str] = None,
) -> LLMInputs:
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
For encoder/decoder models only:
Process an input prompt
into an `LLMInputs` instance.
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
......@@ -830,136 +842,103 @@ class LLMEngine:
Returns:
* `LLMInputs` instance
* :class:`EncoderDecoderLLMInputs` instance
'''
ptype = get_prompt_type(inputs)
# Obtain encoder and decoder prompt tokens. Note
# that, no matter what, the decoder
# prompt type is unknown.
if ptype == "ExplicitEncoderDecoder":
# If input is explicit encoder/decoder prompt,
# then it remains to be determined what type
# of encoder prompt we have
extracted_encoder_prompt = inputs.get('encoder_prompt')
encoder_ptype = None
# Extract decoder prompt from explicit
# encoder/decoder prompt
extracted_decoder_prompt = inputs.get('decoder_prompt')
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
decoder_input,
request_id=request_id,
)
else:
# If input is singleton encoder prompt, then
# we know the encoder prompt type
extracted_encoder_prompt = inputs
encoder_ptype = ptype
# Decoder prompt is always unknown if
# encoder/decoder prompt is not explicit
extracted_decoder_prompt = None
# Invoke helper function to obtain encoder
# prompt and prompt token ids, either from
# singleton encoder prompt or from the
# encoder sub-prompt of an explicit
# encoder/decode scenario 2), special
# processing is applied to the returned decoder token ids
(
encoder_prompt,
encoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_encoder_prompt,
request_id=request_id,
ptype=encoder_ptype,
is_encoder_prompt=True,
)
encoder_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
)
# Invoke helper method to obtain
# decoder prompt and prompt token ids.
#
# The helper method will detect the decoder
# prompt type.
#
# Helper method will also apply special
# preprocessing unique to decoder prompts.
(
decoder_prompt,
decoder_prompt_token_ids,
) = self._extract_single_prompt_for_enc_dec_input(
extracted_decoder_prompt,
request_id=request_id,
ptype=None,
is_encoder_prompt=False,
)
decoder_comps = None, None, None
return LLMInputs(
prompt_token_ids=decoder_prompt_token_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids,
encoder_prompt=encoder_prompt,
)
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
def _build_decoder_only_llm_inputs(
    self,
    prompt_comps: PromptComponents,
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
    """Assemble an :class:`LLMInputs` from extracted prompt components.

    The token ids are first run through :meth:`_apply_prompt_adapter` with
    the given ``prompt_adapter_request``; prompt text and multi-modal data
    are stored unchanged.
    """
    text, token_ids, mm_data = prompt_comps

    adapted_token_ids = self._apply_prompt_adapter(
        token_ids,
        prompt_adapter_request=prompt_adapter_request,
    )

    return LLMInputs(
        prompt_token_ids=adapted_token_ids,
        prompt=text,
        multi_modal_data=mm_data,
    )
def _process_decoder_only_prompt(
self,
inputs: PromptInputs,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
request_id: Optional[str] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
'''
For decoder-only models:
Process an input prompt
into an `LLMInputs` instance.
Process an input prompt into an :class:`LLMInputs` instance.
Arguments:
* inputs: input prompt
* lora_request
* request_id
* lora_request
* prompt_adapter_request
Returns:
* `LLMInputs` instance
* :class:`LLMInputs` instance
'''
if isinstance(inputs, str):
inputs = {"prompt": inputs}
prompt = inputs.get("prompt")
if "prompt_token_ids" not in inputs:
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
prompt_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=inputs.get("multi_modal_data"))
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = self._process_encoder_decoder_prompt(
inputs,
request_id=request_id,
)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
model_inputs = self._process_decoder_only_prompt(
inputs,
......@@ -1029,10 +1008,11 @@ class LLMEngine:
arrival_time = time.time()
processed_inputs = self.process_model_inputs(
inputs,
request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
)
self._add_processed_request(
request_id=request_id,
......@@ -1597,7 +1577,7 @@ class LLMEngine:
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
def is_encoder_decoder_model(self):
return is_encoder_decoder_model_config(self.model_config)
return self.model_config.is_encoder_decoder_model
def is_embedding_model(self):
return is_embedding_model_config(self.model_config)
return self.model_config.is_embedding_model
......@@ -2,8 +2,7 @@ import codecs
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
cast, final)
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
# yapf conflicts with isort for this block
# yapf: disable
......@@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
@final # So that it should be compatible with Dict[str, str]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
role: str
content: str
......
......@@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
......
......@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[str,
float], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
def logit_bias_logits_processor(
    logit_bias: Dict[int, float],
    token_ids: List[int],
    logits: torch.Tensor,
) -> torch.Tensor:
    """Add each configured bias to the logit at its token id.

    ``logits`` is modified in place and also returned. ``token_ids`` is
    accepted but not read by this processor.
    """
    for biased_token_id, bias_value in logit_bias.items():
        logits[biased_token_id] += bias_value

    return logits
......
......@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
......
from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
ParsedTokens, PromptInputs, SingletonPromptInputs,
TextPrompt, TokensPrompt, get_prompt_type,
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
......@@ -14,18 +14,17 @@ See also:
"""
# Names exported by this package.
# NOTE(review): the list mixes entries that look like the before/after sides
# of a rendered diff, so "ExplicitEncoderDecoderPrompt" and
# "SingletonPromptInputs" each appear twice -- confirm the intended final list.
__all__ = [
"ParsedText",
"ParsedTokens",
"parse_and_batch_prompt",
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"INPUT_REGISTRY",
"InputContext",
"InputRegistry",
"get_prompt_type",
"is_valid_encoder_decoder_llm_inputs",
"ExplicitEncoderDecoderPrompt",
"SingletonPromptInputs",
]
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
Union)
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
    """One parsed prompt whose payload is a text string."""

    # The prompt text.
    content: str
    # Discriminator tag; fixed to ``False`` for text payloads.
    is_tokens: Literal[False]
class ParsedTokens(TypedDict):
    """One parsed prompt whose payload is a list of token IDs."""

    # The prompt's token IDs.
    content: List[int]
    # Discriminator tag; fixed to ``True`` for token payloads.
    is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of parsed prompts.

    Four input shapes are recognized: a single string, a list of strings,
    a list of token IDs, and a list of token-ID lists. For list inputs the
    shape is decided by inspecting the first element only.

    Raises:
        ValueError: if the list (or its first nested list) is empty, or if
            the input matches none of the four recognized shapes.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        head = prompt[0]

        if isinstance(head, str):
            # case 2: array of strings
            texts = cast(List[str], prompt)
            return [ParsedText(content=text, is_tokens=False) for text in texts]

        if isinstance(head, int):
            # case 3: array of tokens (the whole list is one prompt)
            token_ids = cast(List[int], prompt)
            return [ParsedTokens(content=token_ids, is_tokens=True)]

        if isinstance(head, list):
            if not head:
                raise ValueError("please provide at least one prompt")

            if isinstance(head[0], int):
                # case 4: array of token arrays
                batches = cast(List[List[int]], prompt)
                return [
                    ParsedTokens(content=batch, is_tokens=True)
                    for batch in batches
                ]

    # Anything that fell through above is not a supported prompt shape.
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
......@@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type SingletonPromptInputs may be employed
A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
# Covariant type parameters for the encoder-prompt and decoder-prompt slots
# of the generic prompt class that follows. Each is bounded by, and defaults
# to, SingletonPromptInputs.
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
# NOTE(review): the two class headers, the paired docstring lines, and the
# paired field declarations below look like the before/after sides of a
# rendered diff -- confirm which version of each is current.
class ExplicitEncoderDecoderPrompt(TypedDict):
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may be formatted according to any of the
SingletonPromptInputs schemas, and are not
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure may not themselves
must be SingletonPromptInputs instances.
fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances.
"""
# Encoder-side prompt.
encoder_prompt: SingletonPromptInputs
encoder_prompt: _T1_co
# Decoder-side prompt; the generic variant also allows None.
decoder_prompt: SingletonPromptInputs
decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
......@@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
"""
def _has_required_keys(
    d: dict,
    required_keys: set,
) -> bool:
    """Return True when every key in ``required_keys`` is present in ``d``.

    An empty ``required_keys`` is trivially satisfied.
    """
    return all(key in d for key in required_keys)
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
    """
    Get the type-name of the prompt argument instance, given that
    isinstance() cannot apply to TypedDict subclasses directly.
    If the prompt is None, return the string 'None' as the type name.

    Arguments:

    * prompt: LLM input prompt or None

    Returns:

    * String representation of prompt type: 'None', 'str', or the first
      entry of the key table below whose required keys are all present

    Raises:

    * ValueError: if a dict prompt matches none of the known key sets, or
      if the prompt is neither None, a dict, nor a str
    """
    if prompt is None:
        return 'None'

    # Checked in insertion order; the first entry whose keys are all
    # present in the prompt wins.
    required_keys_dict = {
        'TextPrompt': {'prompt'},
        'TokensPrompt': {'prompt_token_ids'},
        'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
    }

    if isinstance(prompt, dict):
        for (ptype, required_keys) in required_keys_dict.items():
            # Ignore type checking in the conditional below because type
            # checker does not understand that is_dict(prompt) narrows
            # down the possible types
            if _has_required_keys(
                    prompt,  # type: ignore
                    required_keys):
                return ptype

        # Both literals need the f-prefix: adjacent string literals are
        # concatenated, but each one is formatted independently.
        raise ValueError(f"Invalid prompt {prompt}, valid types are "
                         f"required_keys_dict={required_keys_dict}")

    if isinstance(prompt, str):
        return "str"

    raise ValueError(f"Invalid prompt {prompt}")
class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
......@@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
encoder_prompt_token_ids: NotRequired[List[int]]
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class EncoderDecoderLLMInputs(LLMInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt."""
encoder_prompt: NotRequired[Optional[str]]
......@@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
available.
"""
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
# Type parameters for the encoder-prompt (_T1) and decoder-prompt (_T2)
# arguments of the helper functions that follow. No variance is specified;
# each is bounded by, and defaults to, SingletonPromptInputs.
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """Bundle an encoder prompt and an optional decoder prompt.

    Both arguments are stored unchanged under the ``encoder_prompt`` and
    ``decoder_prompt`` keys of the returned
    :class:`ExplicitEncoderDecoderPrompt`.
    """
    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
    )
    return enc_dec_prompt
def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.

    Prompts are paired positionally; pairing stops as soon as either
    iterable is exhausted.
    """
    # map() over two iterables walks them in lockstep, like zip().
    return list(map(build_explicit_enc_dec_prompt, enc_prompts, dec_prompts))
# True if encoder prompt token ids field exists &
# is not None
return ('encoder_prompt_token_ids' in inputs
and inputs['encoder_prompt_token_ids'] is not None)
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
    """Unpack each prompt into an ``(encoder_prompt, decoder_prompt)`` tuple.

    The output list preserves the iteration order of ``enc_dec_prompts``.
    """
    return [(pair["encoder_prompt"], pair["decoder_prompt"])
            for pair in enc_dec_prompts]
from typing import List, Literal, Sequence, TypedDict, Union, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs)
class ParsedText(TypedDict):
    """A single parsed prompt carrying text content."""

    # Text of the prompt.
    content: str
    # Tag distinguishing text from tokens; always ``False`` here.
    is_tokens: Literal[False]
class ParsedTokens(TypedDict):
    """A single parsed prompt carrying token IDs."""

    # Token IDs of the prompt.
    content: List[int]
    # Tag distinguishing tokens from text; always ``True`` here.
    is_tokens: Literal[True]
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of parsed prompts.

    Four input shapes are recognized: a single string, a list of strings,
    a list of token IDs, and a list of token-ID lists. List shapes are
    classified with ``is_list_of``.

    Raises:
        ValueError: if the list (or its first nested list) is empty, or if
            the input matches none of the four recognized shapes.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            return [ParsedText(content=text, is_tokens=False) for text in prompt]

        if is_list_of(prompt, int):
            # case 3: array of tokens (the whole list is one prompt)
            return [ParsedTokens(content=prompt, is_tokens=True)]

        if is_list_of(prompt, list):
            first_item = prompt[0]
            if not first_item:
                raise ValueError("please provide at least one prompt")

            if is_list_of(first_item, int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=token_ids, is_tokens=True)
                    for token_ids in prompt
                ]

    # Anything that fell through above is not a supported prompt shape.
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
def is_explicit_encoder_decoder_prompt(
        inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return True only for dict inputs that contain an ``encoder_prompt`` key.

    Serves as a type guard narrowing ``inputs`` to
    :class:`ExplicitEncoderDecoderPrompt`.
    """
    if not isinstance(inputs, dict):
        return False
    return "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
    inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
    """Return True when ``inputs`` has an ``encoder_prompt_token_ids`` key.

    Only the key's presence is checked, not its value. Serves as a type
    guard narrowing ``inputs`` to :class:`EncoderDecoderLLMInputs`.
    """
    has_encoder_token_ids = "encoder_prompt_token_ids" in inputs
    return has_encoder_token_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment