Unverified Commit 5ae5ed1e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Consolidate prompt arguments to LLM engines (#4328)


Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 290f4ada
...@@ -71,7 +71,7 @@ TEST_CHOICE = [ ...@@ -71,7 +71,7 @@ TEST_CHOICE = [
"Swift", "Kotlin" "Swift", "Kotlin"
] ]
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.openai
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
...@@ -91,6 +91,8 @@ def server(zephyr_lora_files): ...@@ -91,6 +91,8 @@ def server(zephyr_lora_files):
"--max-model-len", "--max-model-len",
"8192", "8192",
"--enforce-eager", "--enforce-eager",
"--gpu-memory-utilization",
"0.75",
# lora config below # lora config below
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
...@@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files): ...@@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files):
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", "bfloat16",
"--enforce-eager",
"--gpu-memory-utilization",
"0.75",
"--max-model-len", "--max-model-len",
"8192", "8192",
"--enforce-eager",
]) ])
ray.get(server_runner.ready.remote()) ray.get(server_runner.ready.remote())
yield server_runner yield server_runner
...@@ -136,6 +140,7 @@ def client(): ...@@ -136,6 +140,7 @@ def client():
yield client yield client
@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI): async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list() models = await client.models.list()
models = models.data models = models.data
...@@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): ...@@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
assert lora_models[1].id == "zephyr-lora2" assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
...@@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, ...@@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5 completion.choices[0].text) >= 5
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
...@@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, ...@@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
assert choice.logprobs.top_logprobs is None assert choice.logprobs.top_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter # just test 1 lora hereafter
"model_name", "model_name",
...@@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, ...@@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, ...@@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter # just test 1 lora hereafter
"model_name", "model_name",
...@@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, ...@@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter # just test 1 lora hereafter
"model_name", "model_name",
...@@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, ...@@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == output assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter # just test 1 lora hereafter
"model_name", "model_name",
...@@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, ...@@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
assert texts[0] == texts[1] assert texts[0] == texts[1]
@pytest.mark.asyncio
async def test_logits_bias(server, client: openai.AsyncOpenAI): async def test_logits_bias(server, client: openai.AsyncOpenAI):
prompt = "Hello, my name is" prompt = "Hello, my name is"
max_tokens = 5 max_tokens = 5
...@@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): ...@@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text assert first_response != completion.choices[0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI, async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
...@@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, ...@@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI, async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
...@@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI, ...@@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
assert json1["age"] != json2["age"] assert json1["age"] != json2["age"]
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
...@@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, ...@@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
...@@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, ...@@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
assert ip1 != ip2 assert ip1 != ip2
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
...@@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, ...@@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
assert completion.choices[i].text in TEST_CHOICE assert completion.choices[i].text in TEST_CHOICE
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
...@@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, ...@@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
assert choice1 != choice2 assert choice1 != choice2
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
...@@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, ...@@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
...@@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, ...@@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token, logprob in token_dict.items()) for token, logprob in token_dict.items())
@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI): async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2): for _ in range(2):
resp = await client.chat.completions.create( resp = await client.chat.completions.create(
...@@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): ...@@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded assert loaded == {"result": 2}, loaded
@pytest.mark.asyncio
async def test_extra_fields(server, client: openai.AsyncOpenAI): async def test_extra_fields(server, client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info: with pytest.raises(BadRequestError) as exc_info:
await client.chat.completions.create( await client.chat.completions.create(
...@@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): ...@@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
assert "extra_forbidden" in exc_info.value.message assert "extra_forbidden" in exc_info.value.message
@pytest.mark.asyncio
async def test_complex_message_content(server, client: openai.AsyncOpenAI): async def test_complex_message_content(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create( resp = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
...@@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): ...@@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
assert content == "2" assert content == "2"
@pytest.mark.asyncio
async def test_custom_role(server, client: openai.AsyncOpenAI): async def test_custom_role(server, client: openai.AsyncOpenAI):
# Not sure how the model handles custom roles so we just check that # Not sure how the model handles custom roles so we just check that
# both string and complex message content are handled in the same way # both string and complex message content are handled in the same way
...@@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI): ...@@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
assert content1 == content2 assert content1 == content2
@pytest.mark.asyncio
async def test_guided_grammar(server, client: openai.AsyncOpenAI): async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """ simple_sql_grammar = """
start: select_statement start: select_statement
...@@ -847,6 +873,7 @@ number: "1" | "2" ...@@ -847,6 +873,7 @@ number: "1" | "2"
assert content.strip() == ground_truth assert content.strip() == ground_truth
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
...@@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, ...@@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
assert len(logprobs.tokens) > 5 assert len(logprobs.tokens) > 5
@pytest.mark.asyncio
async def test_long_seed(server, client: openai.AsyncOpenAI): async def test_long_seed(server, client: openai.AsyncOpenAI):
for seed in [ for seed in [
torch.iinfo(torch.long).min - 1, torch.iinfo(torch.long).min - 1,
...@@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): ...@@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message) or "less_than_equal" in exc_info.value.message)
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[EMBEDDING_MODEL_NAME], [EMBEDDING_MODEL_NAME],
...@@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, ...@@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 5 assert embeddings.usage.total_tokens == 5
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[EMBEDDING_MODEL_NAME], [EMBEDDING_MODEL_NAME],
......
import multiprocessing
import sys import sys
import time import time
import pytest
import torch import torch
from openai import OpenAI, OpenAIError from openai import OpenAI, OpenAIError
...@@ -10,6 +10,8 @@ from vllm.model_executor.models.opt import OPTForCausalLM ...@@ -10,6 +10,8 @@ from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port from vllm.utils import get_open_port
pytestmark = pytest.mark.openai
class MyOPTForCausalLM(OPTForCausalLM): class MyOPTForCausalLM(OPTForCausalLM):
...@@ -26,15 +28,16 @@ def server_function(port): ...@@ -26,15 +28,16 @@ def server_function(port):
# register our dummy model # register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \ sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --dtype" ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
f" float32 --api-key token-abc123 --port {port}").split() f"--dtype float32 --api-key token-abc123 --port {port}").split()
import runpy import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
def test_oot_registration_for_api_server(): def test_oot_registration_for_api_server():
port = get_open_port() port = get_open_port()
server = multiprocessing.Process(target=server_function, args=(port, )) ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start() server.start()
client = OpenAI( client = OpenAI(
base_url=f"http://localhost:{port}/v1", base_url=f"http://localhost:{port}/v1",
......
...@@ -86,20 +86,18 @@ def generate( ...@@ -86,20 +86,18 @@ def generate(
def batched_generate( def batched_generate(
llm, llm: vllm.LLM,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
): ):
for input in inputs: for input in inputs:
prompt, sampling_param, lora_req = input prompt, sampling_param, lora_req = input
requests_data = llm._validate_and_prepare_requests( # Add requests to the engine and run the engine
llm._validate_and_add_requests(
prompt, prompt,
sampling_param, sampling_param,
lora_request=lora_req, lora_request=lora_req,
) )
# Add requests to the engine and run the engine
for request_data in requests_data:
llm._add_request(**request_data)
outputs = llm._run_engine(use_tqdm=True) outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
......
...@@ -35,28 +35,25 @@ def test_logits_processor_force_generate( ...@@ -35,28 +35,25 @@ def test_logits_processor_force_generate(
# test logits_processors when prompt_logprobs is not None # test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request( vllm_model.model._add_request(
prompt=example_prompts[0], example_prompts[0],
params=params_with_logprobs, params=params_with_logprobs,
prompt_token_ids=None,
) )
# test prompt_logprobs is not None # test prompt_logprobs is not None
vllm_model.model._add_request( vllm_model.model._add_request(
prompt=example_prompts[1], example_prompts[1],
params=SamplingParams( params=SamplingParams(
prompt_logprobs=3, prompt_logprobs=3,
max_tokens=max_tokens, max_tokens=max_tokens,
), ),
prompt_token_ids=None,
) )
# test grouped requests # test grouped requests
vllm_model.model._add_request( vllm_model.model._add_request(
prompt=example_prompts[2], example_prompts[2],
params=SamplingParams(max_tokens=max_tokens), params=SamplingParams(max_tokens=max_tokens),
prompt_token_ids=None,
) )
outputs = vllm_model.model._run_engine(False) outputs = vllm_model.model._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times assert outputs[0].outputs[0].text == enforced_answers * repeat_times
...@@ -57,11 +57,7 @@ def test_random_sample_with_seed( ...@@ -57,11 +57,7 @@ def test_random_sample_with_seed(
sampling_params_seed_1, sampling_params_seed_1,
sampling_params_seed_2, sampling_params_seed_2,
): ):
llm._add_request( llm._add_request(prompt, params=params)
prompt=prompt,
prompt_token_ids=None,
params=params,
)
results = llm._run_engine(use_tqdm=False) results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs] all_outputs = [[out.token_ids for out in output.outputs]
......
...@@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, ...@@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
for prompt in prompts: for prompt in prompts:
hashes[-1].append([]) hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, seq = Sequence(seq_id,
tokenizer.tokenizer.eos_token_id, lora_request) inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
lora_request=lora_request)
num_blocks = len(prompt_token_ids) // block_size num_blocks = len(prompt_token_ids) // block_size
for idx in range(num_blocks): for idx in range(num_blocks):
......
from typing import List
import pytest
from vllm.inputs import parse_and_batch_prompt
STRING_INPUTS = [
'',
'foo',
'foo bar',
'foo baz bar',
'foo bar qux baz',
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
def test_parse_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([[]])
@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
assert parse_and_batch_prompt(string_input) \
== parse_and_batch_prompt([string_input])
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
assert parse_and_batch_prompt(token_input) \
== parse_and_batch_prompt([token_input])
@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
import pytest
from vllm.utils import deprecate_kwargs
from .utils import error_on_warning
def test_deprecate_kwargs_always():
@deprecate_kwargs("old_arg", is_deprecated=True)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_never():
@deprecate_kwargs("old_arg", is_deprecated=False)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_dynamic():
is_deprecated = True
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
is_deprecated = False
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_additional_message():
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
...@@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None): ...@@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1] prompt_token_ids = prompt_token_ids or [1]
return Sequence( return Sequence(
seq_id=0, seq_id=0,
prompt="<s>", inputs={
prompt_token_ids=prompt_token_ids, "prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16, block_size=16,
) )
......
...@@ -2,6 +2,8 @@ import os ...@@ -2,6 +2,8 @@ import os
import subprocess import subprocess
import sys import sys
import time import time
import warnings
from contextlib import contextmanager
import ray import ray
import requests import requests
...@@ -87,3 +89,15 @@ def multi_process_tensor_parallel( ...@@ -87,3 +89,15 @@ def multi_process_tensor_parallel(
ray.get(refs) ray.get(refs)
ray.shutdown() ray.shutdown()
@contextmanager
def error_on_warning():
"""
Within the scope of this context manager, tests will fail if any warning
is emitted.
"""
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
...@@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine ...@@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput, from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput) EmbeddingRequestOutput, RequestOutput)
...@@ -16,6 +17,9 @@ __version__ = "0.4.2" ...@@ -16,6 +17,9 @@ __version__ = "0.4.2"
__all__ = [ __all__ = [
"LLM", "LLM",
"ModelRegistry", "ModelRegistry",
"PromptStrictInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams", "SamplingParams",
"RequestOutput", "RequestOutput",
"CompletionOutput", "CompletionOutput",
......
...@@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs ...@@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
return request_outputs return request_outputs
async def encode_request_async( async def process_model_inputs_async(
self, self,
request_id: str, # pylint: disable=unused-argument request_id: str,
prompt: Optional[str], inputs: PromptInputs,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
): ) -> LLMInputs:
if prompt_token_ids is None: if isinstance(inputs, str):
assert prompt is not None inputs = {"prompt": inputs}
prompt_token_ids = await self.tokenizer.encode_async(
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = await tokenizer.encode_async(
request_id=request_id, request_id=request_id,
prompt=prompt, prompt=inputs["prompt"],
lora_request=lora_request) lora_request=lora_request)
return prompt_token_ids else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
self._add_processed_request(
request_id=request_id, request_id=request_id,
prompt=prompt, processed_inputs=processed_inputs,
prompt_token_ids=prompt_token_ids, params=params,
lora_request=lora_request) arrival_time=arrival_time,
lora_request=lora_request,
return self.add_request(request_id, )
prompt=prompt,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
self.model_executor.check_health() self.model_executor.check_health()
class AsyncLLMEngine: class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine. """An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the LLMEngine class to make it asynchronous. It This class is used to wrap the :class:`LLMEngine` class to make it
uses asyncio to create a background loop that keeps processing incoming asynchronous. It uses asyncio to create a background loop that keeps
requests. The LLMEngine is kicked by the generate method when there processing incoming requests. The :class:`LLMEngine` is kicked by the
are requests in the waiting queue. The generate method yields the outputs generate method when there are requests in the waiting queue. The generate
from the LLMEngine to the caller. method yields the outputs from the :class:`LLMEngine` to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`. NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
Args: Args:
worker_use_ray: Whether to use Ray for model workers. Required for worker_use_ray: Whether to use Ray for model workers. Required for
...@@ -315,8 +321,8 @@ class AsyncLLMEngine: ...@@ -315,8 +321,8 @@ class AsyncLLMEngine:
being printed in log. being printed in log.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
*args: Arguments for LLMEngine. *args: Arguments for :class:`LLMEngine`.
*kwargs: Arguments for LLMEngine. **kwargs: Arguments for :class:`LLMEngine`.
""" """
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
...@@ -526,22 +532,26 @@ class AsyncLLMEngine: ...@@ -526,22 +532,26 @@ class AsyncLLMEngine:
async def add_request( async def add_request(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
shortened_prompt = prompt if isinstance(inputs, str):
shortened_token_ids = prompt_token_ids shortened_prompt = inputs
if self.max_log_len is not None: shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None: if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len] shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None: if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self. shortened_token_ids = shortened_token_ids[:max_log_len]
max_log_len]
logger.info( logger.info(
"Received request %s: prompt: %r, " "Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, " "params: %s, prompt_token_ids: %s, "
...@@ -562,39 +572,33 @@ class AsyncLLMEngine: ...@@ -562,39 +572,33 @@ class AsyncLLMEngine:
arrival_time = time.time() arrival_time = time.time()
if self.engine_use_ray: if self.engine_use_ray:
prompt_token_ids = await ( processed_inputs = await self.engine.process_model_inputs_async \
self.engine.encode_request_async.remote( # type: ignore .remote( # type: ignore
request_id=request_id, request_id=request_id,
prompt=prompt, inputs=inputs,
prompt_token_ids=prompt_token_ids, lora_request=lora_request)
lora_request=lora_request))
else: else:
prompt_token_ids = await self.engine.encode_request_async( processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id, request_id=request_id,
prompt=prompt, inputs=inputs,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request) lora_request=lora_request)
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
prompt=prompt, inputs=processed_inputs,
params=params, params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data,
) )
return stream return stream
async def generate( async def generate(
self, self,
prompt: Optional[str], inputs: PromptInputs,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -603,14 +607,12 @@ class AsyncLLMEngine: ...@@ -603,14 +607,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
prompt: The prompt string. Can be None if prompt_token_ids is inputs: The inputs to the LLM. See
provided. :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine The output `RequestOutput` objects from the LLMEngine
...@@ -659,24 +661,20 @@ class AsyncLLMEngine: ...@@ -659,24 +661,20 @@ class AsyncLLMEngine:
>>> # Process and return the final output >>> # Process and return the final output
>>> ... >>> ...
""" """
async for output in self.process_request( async for output in self._process_request(
request_id, request_id,
prompt, inputs,
sampling_params, sampling_params,
prompt_token_ids, lora_request=lora_request,
lora_request,
multi_modal_data,
): ):
yield output yield LLMEngine.validate_output(output, RequestOutput)
async def encode( async def encode(
self, self,
prompt: Optional[str], inputs: PromptInputs,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[EmbeddingRequestOutput]: ) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
...@@ -685,14 +683,12 @@ class AsyncLLMEngine: ...@@ -685,14 +683,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
prompt: The prompt string. Can be None if prompt_token_ids is inputs: The inputs to the LLM. See
provided. :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
...@@ -739,24 +735,21 @@ class AsyncLLMEngine: ...@@ -739,24 +735,21 @@ class AsyncLLMEngine:
>>> # Process and return the final output >>> # Process and return the final output
>>> ... >>> ...
""" """
async for output in self.process_request( async for output in self._process_request(
request_id, request_id,
prompt, inputs,
pooling_params, pooling_params,
prompt_token_ids, lora_request=lora_request,
lora_request,
multi_modal_data,
): ):
yield output yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def process_request( async def _process_request(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None, *,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or """Common logic to process requests with SamplingParams or
PoolingParams.""" PoolingParams."""
...@@ -764,12 +757,10 @@ class AsyncLLMEngine: ...@@ -764,12 +757,10 @@ class AsyncLLMEngine:
stream = await self.add_request( stream = await self.add_request(
request_id, request_id,
prompt, inputs,
params, params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data,
) )
try: try:
......
import time import time
from typing import Iterable, List, Optional, Type, Union from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
...@@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker ...@@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
...@@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, ...@@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
MultiModalData, PoolerOutput, SamplerOutput, PoolerOutput, SamplerOutput, Sequence,
Sequence, SequenceGroup, SequenceGroupMetadata, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
...@@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): ...@@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return {} return {}
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
...@@ -60,11 +67,11 @@ class LLMEngine: ...@@ -60,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the iteration-level scheduling and efficient memory management to maximize the
serving throughput. serving throughput.
The `LLM` class wraps this class for offline batched inference and the The :class:`~vllm.LLM` class wraps this class for offline batched inference
`AsyncLLMEngine` class wraps this class for online serving. and the :class:`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
comprehensive list of arguments, see `EngineArgs`. class. For the comprehensive list of arguments, see :ref:`engine_args`.
Args: Args:
model_config: The configuration related to the LLM model. model_config: The configuration related to the LLM model.
...@@ -81,9 +88,60 @@ class LLMEngine: ...@@ -81,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection usage_context: Specified entry point, used for usage info collection.
""" """
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
"""A flag to toggle whether to validate the type of request output."""
@classmethod
@contextmanager
def enable_output_validation(cls):
cls.DO_VALIDATE_OUTPUT = True
yield
cls.DO_VALIDATE_OUTPUT = False
@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT
if ((TYPE_CHECKING or do_validate)
and not isinstance(output, output_type)):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT
outputs_: List[_O]
if TYPE_CHECKING or do_validate:
outputs_ = []
for output in outputs:
if not isinstance(output, output_type):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
outputs_.append(output)
else:
outputs_ = outputs
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -151,12 +209,11 @@ class LLMEngine: ...@@ -151,12 +209,11 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup self.tokenizer = self._init_tokenizer()
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
else: else:
self.detokenizer = None
self.tokenizer = None self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
...@@ -318,14 +375,26 @@ class LLMEngine: ...@@ -318,14 +375,26 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None) return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self, def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer": sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request) return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs): def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict( init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer, tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config), enable_lora=bool(self.lora_config),
...@@ -335,8 +404,9 @@ class LLMEngine: ...@@ -335,8 +404,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision) revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs) init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs) return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
...@@ -346,29 +416,85 @@ class LLMEngine: ...@@ -346,29 +416,85 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
def encode_request( def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _add_processed_request(
self, self,
request_id: str, # pylint: disable=unused-argument request_id: str,
prompt: Optional[str], processed_inputs: LLMInputs,
prompt_token_ids: Optional[List[int]] = None, params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
): ) -> LLMInputs:
if prompt_token_ids is None: if isinstance(inputs, str):
assert prompt is not None inputs = {"prompt": inputs}
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt, if "prompt_token_ids" not in inputs:
lora_request=lora_request) tokenizer = self.get_tokenizer_group("prompts must be None if "
return prompt_token_ids "skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -378,15 +504,14 @@ class LLMEngine: ...@@ -378,15 +504,14 @@ class LLMEngine:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is inputs: The inputs to the LLM. See
provided. :class:`~vllm.inputs.PromptInputs`
params: Parameters for sampling or pooling. SamplingParams for more details about the format of each input.
for text generation. PoolingParams for pooling. params: Parameters for sampling or pooling.
prompt_token_ids: The token IDs of the prompt. If None, we :class:`~vllm.SamplingParams` for text generation.
use the tokenizer to convert the prompts to token IDs. :class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
multi_modal_data: Multi modal data per request.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
...@@ -417,59 +542,26 @@ class LLMEngine: ...@@ -417,59 +542,26 @@ class LLMEngine:
"not enabled!") "not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
prompt_token_ids = self.encode_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams processed_inputs = self.process_model_inputs(request_id=request_id,
if isinstance(params, SamplingParams): inputs=inputs,
seq_group = self._create_sequence_group_with_sampling( lora_request=lora_request)
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler. self._add_processed_request(
self.scheduler.add_seq_group(seq_group) request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
def _create_sequence_group_with_sampling( def _create_sequence_group_with_sampling(
self, self,
request_id: str, request_id: str,
seq: Sequence, seq: Sequence,
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: Optional[float] = None, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest],
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
...@@ -495,8 +587,7 @@ class LLMEngine: ...@@ -495,8 +587,7 @@ class LLMEngine:
seqs=[seq], seqs=[seq],
arrival_time=arrival_time, arrival_time=arrival_time,
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=lora_request, lora_request=lora_request)
multi_modal_data=multi_modal_data)
return seq_group return seq_group
...@@ -505,9 +596,8 @@ class LLMEngine: ...@@ -505,9 +596,8 @@ class LLMEngine:
request_id: str, request_id: str,
seq: Sequence, seq: Sequence,
pooling_params: PoolingParams, pooling_params: PoolingParams,
arrival_time: Optional[float] = None, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest],
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams.""" """Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
...@@ -517,7 +607,6 @@ class LLMEngine: ...@@ -517,7 +607,6 @@ class LLMEngine:
seqs=[seq], seqs=[seq],
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params) pooling_params=pooling_params)
return seq_group return seq_group
...@@ -570,7 +659,7 @@ class LLMEngine: ...@@ -570,7 +659,7 @@ class LLMEngine:
def _process_model_outputs( def _process_model_outputs(
self, self,
output: List[Union[SamplerOutput, PoolerOutput]], output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup], scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -585,7 +674,7 @@ class LLMEngine: ...@@ -585,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of # Organize outputs by [sequence group][step] instead of
# [step][sequence group]. # [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group( output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip( for scheduled_seq_group, outputs, seq_group_meta in zip(
......
from typing import List from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import SamplerOutput, SequenceGroupOutput from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group( def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput], outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]: num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by """Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step]. [step][sequence group] into [sequence group][step].
""" """
output_by_sequence_group: List[List[SamplerOutput]] = [ output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups) [] for _ in range(num_seq_groups)
] ]
for step in sampler_outputs: for step in outputs:
for i, sequence_group_output in enumerate(step): for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output) output_by_sequence_group[i].append(sequence_group_output)
......
from typing import List, Optional, Union from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
...@@ -13,7 +16,7 @@ from vllm.pooling_params import PoolingParams ...@@ -13,7 +16,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter from vllm.utils import Counter, deprecate_kwargs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,8 +31,10 @@ class LLM: ...@@ -28,8 +31,10 @@ class LLM:
mechanism and efficient memory management. mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead. serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
NOTE: For the comprehensive list of arguments, see
:class:`~vllm.EngineArgs`.
Args: Args:
model: The name or path of a HuggingFace Transformers model. model: The name or path of a HuggingFace Transformers model.
...@@ -81,6 +86,18 @@ class LLM: ...@@ -81,6 +86,18 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
""" """
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
def __init__( def __init__(
self, self,
model: str, model: str,
...@@ -138,15 +155,101 @@ class LLM: ...@@ -138,15 +155,101 @@ class LLM:
) -> None: ) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer self.llm_engine.tokenizer.tokenizer = tokenizer
@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: List[str],
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None, List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -155,49 +258,138 @@ class LLM: ...@@ -155,49 +258,138 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
prompts: A list of prompts to generate completions for. inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt. prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns: Returns:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts. generated completions in the same order as the input prompts.
""" """
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
requests_data = self._validate_and_prepare_requests( self._validate_and_add_requests(
prompts, inputs=inputs,
sampling_params, params=sampling_params,
prompt_token_ids, lora_request=lora_request,
lora_request,
multi_modal_data,
) )
# Add requests to the engine and run the engine outputs = self._run_engine(use_tqdm=use_tqdm)
for request_data in requests_data: return LLMEngine.validate_outputs(outputs, RequestOutput)
self._add_request(**request_data)
return self._run_engine(use_tqdm) @overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def encode( def encode(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: List[str],
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
List[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -206,124 +398,133 @@ class LLM: ...@@ -206,124 +398,133 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
prompts: A list of prompts to generate completions for. inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts. generated embeddings in the same order as the input prompts.
""" """
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
requests_data = self._validate_and_prepare_requests( self._validate_and_add_requests(
prompts, inputs=inputs,
pooling_params, params=pooling_params,
prompt_token_ids, lora_request=lora_request,
lora_request,
multi_modal_data,
) )
# Add requests to the engine and run the engine outputs = self._run_engine(use_tqdm=use_tqdm)
for request_data in requests_data: return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
self._add_request(**request_data)
return self._run_engine(use_tqdm) # LEGACY
def _convert_v1_inputs(
def _validate_and_prepare_requests(
self, self,
prompts: Optional[Union[str, List[str]]], prompts: Optional[Union[str, List[str]]],
params: Union[Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
List[Union[SamplingParams, multi_modal_data: Optional[MultiModalData],
PoolingParams]]], # Unified parameter ):
prompt_token_ids: Optional[List[List[int]]] = None, # skip_tokenizer_init is now checked in engine
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[dict]:
"""Validates and prepares request data for adding to the engine.
Ensures prompts and token IDs are consistent, and returns a list of if prompts is not None:
dictionaries with request data for further processing. prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
""" if prompt_token_ids is not None:
if prompts is None and prompt_token_ids is None: prompt_token_ids = [
raise ValueError("Either prompts or prompt_token_ids must be " p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
"provided.") ]
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = None
if prompts is not None: if prompts is not None:
num_requests = len(prompts) num_requests = len(prompts)
else: if prompt_token_ids is not None:
assert prompt_token_ids is not None if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = len(prompt_token_ids) num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item)
return inputs
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests: if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params " raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
requests_data = [] for i, request_inputs in enumerate(inputs):
for i in range(num_requests): self._add_request(
prompt = prompts[i] if prompts is not None else None request_inputs,
token_ids = None if prompt_token_ids is None else prompt_token_ids[ params[i] if isinstance(params, Sequence) else params,
i] lora_request=lora_request,
)
multi_modal_item = MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0),
) if multi_modal_data else None
requests_data.append({
"prompt":
prompt,
"params":
params[i] if isinstance(params, list) else params,
"prompt_token_ids":
token_ids,
"lora_request":
lora_request,
"multi_modal_data":
multi_modal_item,
})
return requests_data
def _add_request( def _add_request(
self, self,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(request_id,
prompt, inputs,
params, params,
prompt_token_ids, lora_request=lora_request)
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine( def _run_engine(
self, use_tqdm: bool self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:
...@@ -355,5 +556,4 @@ class LLM: ...@@ -355,5 +556,4 @@ class LLM:
# Sort the outputs by request ID. # Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
return outputs
...@@ -176,9 +176,15 @@ class OpenAIServingChat(OpenAIServing): ...@@ -176,9 +176,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params, result_generator = self.engine.generate(
request_id, prompt_ids, {
lora_request) "prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
......
...@@ -119,12 +119,17 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -119,12 +119,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens) truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats prompt_ids, prompt_text = prompt_formats
generators.append( generator = self.engine.generate(
self.engine.generate(prompt_text, {
sampling_params, "prompt": prompt_text,
f"{request_id}-{i}", "prompt_token_ids": prompt_ids
prompt_token_ids=prompt_ids, },
lora_request=lora_request)) sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
)
generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
import time import time
from typing import AsyncIterator, List, Tuple from typing import AsyncIterator, List, Optional, Tuple
from fastapi import Request from fastapi import Request
...@@ -100,11 +100,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -100,11 +100,16 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_ids, prompt_text = prompt_formats prompt_ids, prompt_text = prompt_formats
generators.append( generator = self.engine.encode(
self.engine.generate(prompt_text, {
pooling_params, "prompt": prompt_text,
f"{request_id}-{i}", "prompt_token_ids": prompt_ids
prompt_token_ids=prompt_ids)) },
pooling_params,
f"{request_id}-{i}",
)
generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -113,16 +118,21 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -113,16 +118,21 @@ class OpenAIServingEmbedding(OpenAIServing):
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response # Non-streaming response
final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) final_res_batch: List[Optional[EmbeddingRequestOutput]]
async for i, res in result_generator: final_res_batch = [None] * len(prompts)
if await raw_request.is_disconnected(): try:
# Abort the request if the client disconnects. async for i, res in result_generator:
await self.engine.abort(f"{request_id}-{i}") if await raw_request.is_disconnected():
# TODO: Use a vllm-specific Validation Error # Abort the request if the client disconnects.
return self.create_error_response("Client disconnected") await self.engine.abort(f"{request_id}-{i}")
final_res_batch[i] = res # TODO: Use a vllm-specific Validation Error
response = request_output_to_embedding_response( return self.create_error_response("Client disconnected")
final_res_batch, request_id, created_time, model_name) final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response return response
......
...@@ -143,7 +143,8 @@ class OpenAIServing: ...@@ -143,7 +143,8 @@ class OpenAIServing:
return json_str return json_str
async def _check_model( async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest] self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
...@@ -155,7 +156,8 @@ class OpenAIServing: ...@@ -155,7 +156,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora( def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest] self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]: ) -> Optional[LoRARequest]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
......
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.sequence import MultiModalData
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment