Unverified Commit 5ae5ed1e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Consolidate prompt arguments to LLM engines (#4328)


Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 290f4ada
......@@ -71,7 +71,7 @@ TEST_CHOICE = [
"Swift", "Kotlin"
]
pytestmark = pytest.mark.asyncio
pytestmark = pytest.mark.openai
@pytest.fixture(scope="session")
......@@ -91,6 +91,8 @@ def server(zephyr_lora_files):
"--max-model-len",
"8192",
"--enforce-eager",
"--gpu-memory-utilization",
"0.75",
# lora config below
"--enable-lora",
"--lora-modules",
......@@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files):
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--enforce-eager",
"--gpu-memory-utilization",
"0.75",
"--max-model-len",
"8192",
"--enforce-eager",
])
ray.get(server_runner.ready.remote())
yield server_runner
......@@ -136,6 +140,7 @@ def client():
yield client
@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
......@@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
......@@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
......@@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
assert choice.logprobs.top_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
......@@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
......@@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
......@@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
......@@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
......@@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
assert texts[0] == texts[1]
@pytest.mark.asyncio
async def test_logits_bias(server, client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 5
......@@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
......@@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
......@@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
assert json1["age"] != json2["age"]
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
......@@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
......@@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
assert ip1 != ip2
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
......@@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
assert completion.choices[i].text in TEST_CHOICE
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
......@@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
assert choice1 != choice2
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
......@@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
......@@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token, logprob in token_dict.items())
@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):
resp = await client.chat.completions.create(
......@@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded
@pytest.mark.asyncio
async def test_extra_fields(server, client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info:
await client.chat.completions.create(
......@@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
assert "extra_forbidden" in exc_info.value.message
@pytest.mark.asyncio
async def test_complex_message_content(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
......@@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
assert content == "2"
@pytest.mark.asyncio
async def test_custom_role(server, client: openai.AsyncOpenAI):
# Not sure how the model handles custom roles so we just check that
# both string and complex message content are handled in the same way
......@@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
assert content1 == content2
@pytest.mark.asyncio
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
......@@ -847,6 +873,7 @@ number: "1" | "2"
assert content.strip() == ground_truth
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
......@@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
assert len(logprobs.tokens) > 5
@pytest.mark.asyncio
async def test_long_seed(server, client: openai.AsyncOpenAI):
for seed in [
torch.iinfo(torch.long).min - 1,
......@@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
......@@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 5
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
......
import multiprocessing
import sys
import time
import pytest
import torch
from openai import OpenAI, OpenAIError
......@@ -10,6 +10,8 @@ from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port
pytestmark = pytest.mark.openai
class MyOPTForCausalLM(OPTForCausalLM):
......@@ -26,15 +28,16 @@ def server_function(port):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --dtype"
f" float32 --api-key token-abc123 --port {port}").split()
("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
f"--dtype float32 --api-key token-abc123 --port {port}").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
def test_oot_registration_for_api_server():
port = get_open_port()
server = multiprocessing.Process(target=server_function, args=(port, ))
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
......
......@@ -86,20 +86,18 @@ def generate(
def batched_generate(
llm,
llm: vllm.LLM,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
):
for input in inputs:
prompt, sampling_param, lora_req = input
requests_data = llm._validate_and_prepare_requests(
# Add requests to the engine and run the engine
llm._validate_and_add_requests(
prompt,
sampling_param,
lora_request=lora_req,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
llm._add_request(**request_data)
outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
......
......@@ -35,28 +35,25 @@ def test_logits_processor_force_generate(
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[0],
example_prompts[0],
params=params_with_logprobs,
prompt_token_ids=None,
)
# test prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[1],
example_prompts[1],
params=SamplingParams(
prompt_logprobs=3,
max_tokens=max_tokens,
),
prompt_token_ids=None,
)
# test grouped requests
vllm_model.model._add_request(
prompt=example_prompts[2],
example_prompts[2],
params=SamplingParams(max_tokens=max_tokens),
prompt_token_ids=None,
)
outputs = vllm_model.model._run_engine(False)
outputs = vllm_model.model._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times
......@@ -57,11 +57,7 @@ def test_random_sample_with_seed(
sampling_params_seed_1,
sampling_params_seed_2,
):
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
params=params,
)
llm._add_request(prompt, params=params)
results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs]
......
......@@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
for prompt in prompts:
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
tokenizer.tokenizer.eos_token_id, lora_request)
seq = Sequence(seq_id,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
lora_request=lora_request)
num_blocks = len(prompt_token_ids) // block_size
for idx in range(num_blocks):
......
from typing import List
import pytest
from vllm.inputs import parse_and_batch_prompt
STRING_INPUTS = [
'',
'foo',
'foo bar',
'foo baz bar',
'foo bar qux baz',
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
def test_parse_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([[]])
@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
assert parse_and_batch_prompt(string_input) \
== parse_and_batch_prompt([string_input])
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
assert parse_and_batch_prompt(token_input) \
== parse_and_batch_prompt([token_input])
@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
import pytest
from vllm.utils import deprecate_kwargs
from .utils import error_on_warning
def test_deprecate_kwargs_always():
@deprecate_kwargs("old_arg", is_deprecated=True)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_never():
@deprecate_kwargs("old_arg", is_deprecated=False)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_dynamic():
is_deprecated = True
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
is_deprecated = False
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_additional_message():
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
......@@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
prompt="<s>",
prompt_token_ids=prompt_token_ids,
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)
......
......@@ -2,6 +2,8 @@ import os
import subprocess
import sys
import time
import warnings
from contextlib import contextmanager
import ray
import requests
......@@ -87,3 +89,15 @@ def multi_process_tensor_parallel(
ray.get(refs)
ray.shutdown()
@contextmanager
def error_on_warning():
"""
Within the scope of this context manager, tests will fail if any warning
is emitted.
"""
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
......@@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
......@@ -16,6 +17,9 @@ __version__ = "0.4.2"
__all__ = [
"LLM",
"ModelRegistry",
"PromptStrictInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
......
......@@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
......@@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
return request_outputs
async def encode_request_async(
async def process_model_inputs_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
prompt=inputs["prompt"],
lora_request=lora_request)
return prompt_token_ids
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(request_id,
prompt=prompt,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
async def check_health_async(self) -> None:
self.model_executor.check_health()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
This class is used to wrap the :class:`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The :class:`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
......@@ -315,8 +321,8 @@ class AsyncLLMEngine:
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for LLMEngine.
*kwargs: Arguments for LLMEngine.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
......@@ -526,22 +532,26 @@ class AsyncLLMEngine:
async def add_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
......@@ -562,39 +572,33 @@ class AsyncLLMEngine:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
processed_inputs = await self.engine.process_model_inputs_async \
.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
inputs=inputs,
lora_request=lora_request)
else:
prompt_token_ids = await self.engine.encode_request_async(
processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
inputs=inputs,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
inputs=processed_inputs,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
return stream
async def generate(
self,
prompt: Optional[str],
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
......@@ -603,14 +607,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine
......@@ -659,24 +661,20 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
lora_request=lora_request,
):
yield output
yield LLMEngine.validate_output(output, RequestOutput)
async def encode(
self,
prompt: Optional[str],
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
......@@ -685,14 +683,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
......@@ -739,24 +735,21 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
lora_request=lora_request,
):
yield output
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def process_request(
async def _process_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
*,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
......@@ -764,12 +757,10 @@ class AsyncLLMEngine:
stream = await self.add_request(
request_id,
prompt,
inputs,
params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
try:
......
import time
from typing import Iterable, List, Optional, Type, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer
......@@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
......@@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
MultiModalData, PoolerOutput, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
......@@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return {}
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
......@@ -60,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
The :class:`~vllm.LLM` class wraps this class for offline batched inference
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
class. For the comprehensive list of arguments, see :ref:`engine_args`.
Args:
model_config: The configuration related to the LLM model.
......@@ -81,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
usage_context: Specified entry point, used for usage info collection.
"""
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
"""A flag to toggle whether to validate the type of request output."""
@classmethod
@contextmanager
def enable_output_validation(cls):
cls.DO_VALIDATE_OUTPUT = True
yield
cls.DO_VALIDATE_OUTPUT = False
@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT
if ((TYPE_CHECKING or do_validate)
and not isinstance(output, output_type)):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT
outputs_: List[_O]
if TYPE_CHECKING or do_validate:
outputs_ = []
for output in outputs:
if not isinstance(output, output_type):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
outputs_.append(output)
else:
outputs_ = outputs
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
def __init__(
self,
model_config: ModelConfig,
......@@ -151,12 +209,11 @@ class LLMEngine:
self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
......@@ -318,14 +375,26 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs):
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
......@@ -335,8 +404,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -346,29 +416,85 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def encode_request(
def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _add_processed_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
processed_inputs: LLMInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
def add_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
"""Add a request to the engine's request pool.
......@@ -378,15 +504,14 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
params: Parameters for sampling or pooling. SamplingParams
for text generation. PoolingParams for pooling.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
multi_modal_data: Multi modal data per request.
Details:
- Set arrival_time to the current time if it is None.
......@@ -417,59 +542,26 @@ class LLMEngine:
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
......@@ -495,8 +587,7 @@ class LLMEngine:
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
lora_request=lora_request)
return seq_group
......@@ -505,9 +596,8 @@ class LLMEngine:
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
......@@ -517,7 +607,6 @@ class LLMEngine:
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params)
return seq_group
......@@ -570,7 +659,7 @@ class LLMEngine:
def _process_model_outputs(
self,
output: List[Union[SamplerOutput, PoolerOutput]],
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
......@@ -585,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
......
from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import SamplerOutput, SequenceGroupOutput
from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput],
outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group: List[List[SamplerOutput]] = [
output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for step in outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
......
This diff is collapsed.
......@@ -176,9 +176,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, prompt_ids,
lora_request)
result_generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
......
......@@ -119,12 +119,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......
import time
from typing import AsyncIterator, List, Tuple
from typing import AsyncIterator, List, Optional, Tuple
from fastapi import Request
......@@ -100,11 +100,16 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
pooling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids))
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params,
f"{request_id}-{i}",
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -113,16 +118,21 @@ class OpenAIServingEmbedding(OpenAIServing):
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
......
......@@ -143,7 +143,8 @@ class OpenAIServing:
return json_str
async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
......@@ -155,7 +156,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None
......
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.sequence import MultiModalData
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment