"examples/offline_inference/encoder_decoder.py" did not exist on "fd95e026e0f9f50bacf1a63ef419df8bacfc99c0"
Unverified Commit e19bce40 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove AsyncLLMEngine (#25025)


Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 505805b6
...@@ -28,11 +28,9 @@ def monkeypatch_module(): ...@@ -28,11 +28,9 @@ def monkeypatch_module():
mpatch.undo() mpatch.undo()
@pytest.fixture(scope="module", params=[False, True]) @pytest.fixture(scope="module")
def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811 def server(monkeypatch_module, zephyr_lora_files): #noqa: F811
monkeypatch_module.setenv('VLLM_USE_V1', '1')
use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
...@@ -57,13 +55,6 @@ def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811 ...@@ -57,13 +55,6 @@ def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811
yield remote_server yield remote_server
@pytest.fixture
def is_v1_server(server):
import os
assert os.environ['VLLM_USE_V1'] in ['0', '1']
return os.environ['VLLM_USE_V1'] == '1'
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def client(server): async def client(server):
async with server.get_async_client() as async_client: async with server.get_async_client() as async_client:
...@@ -481,10 +472,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -481,10 +472,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_structured_outputs_choice_chat( async def test_structured_outputs_choice_chat(
client: openai.AsyncOpenAI, sample_structured_outputs_choices, client: openai.AsyncOpenAI,
is_v1_server: bool): sample_structured_outputs_choices,
if not is_v1_server: ):
pytest.skip("Structured outputs is only supported in v1 engine")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -522,12 +512,10 @@ async def test_structured_outputs_choice_chat( ...@@ -522,12 +512,10 @@ async def test_structured_outputs_choice_chat(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_structured_outputs_json_chat(client: openai.AsyncOpenAI, async def test_structured_outputs_json_chat(
sample_json_schema, client: openai.AsyncOpenAI,
is_v1_server: bool): sample_json_schema,
if not is_v1_server: ):
pytest.skip("Structured outputs is only supported in v1 engine")
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -569,10 +557,10 @@ async def test_structured_outputs_json_chat(client: openai.AsyncOpenAI, ...@@ -569,10 +557,10 @@ async def test_structured_outputs_json_chat(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_structured_outputs_regex_chat(client: openai.AsyncOpenAI, async def test_structured_outputs_regex_chat(
sample_regex, is_v1_server: bool): client: openai.AsyncOpenAI,
if not is_v1_server: sample_regex,
pytest.skip("Structured outputs is only supported in v1 engine") ):
messages = [{ messages = [{
"role": "system", "role": "system",
...@@ -660,10 +648,10 @@ async def test_structured_outputs_choice_chat_logprobs( ...@@ -660,10 +648,10 @@ async def test_structured_outputs_choice_chat_logprobs(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, async def test_named_tool_use(
is_v1_server: bool): client: openai.AsyncOpenAI,
if not is_v1_server: sample_json_schema,
pytest.skip("Tool use is only supported in v1 engine") ):
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
...@@ -821,11 +809,7 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): ...@@ -821,11 +809,7 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI, async def test_response_format_json_schema(client: openai.AsyncOpenAI):
is_v1_server: bool):
if not is_v1_server:
pytest.skip(
"JSON schema response format is only supported in v1 engine")
prompt = 'what is 1+1? The format is "result": 2' prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema # Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2): for _ in range(2):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for structured outputs tests
import json
import os
from typing import Optional
import jsonschema
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import regex as re
import requests
# downloading lora to test lora requests
from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model,
# but we're not testing generation quality here
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--max-num-seqs",
"128",
"--enforce-eager",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
]
@pytest.fixture(scope="module",
params=["", "--disable-frontend-multiprocessing"])
def server(default_server_args, request):
if request.param:
default_server_args.append(request.param)
original_value = os.environ.get('VLLM_USE_V1')
os.environ['VLLM_USE_V1'] = '0'
try:
with RemoteOpenAIServer(MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
finally:
# Restore original env value
if original_value is None:
os.environ.pop('VLLM_USE_V1', None)
else:
os.environ['VLLM_USE_V1'] = original_value
@pytest.fixture
def is_v1_server(server):
import os
# For completion tests, we assume v0 since there's no explicit v1 setup
return os.environ.get('VLLM_USE_V1', '0') == '1'
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
# Added tokens should be rejected by the base model
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=None,
)
choice = completion.choices[0]
assert choice.logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=0,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=5,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
# vLLM has higher default max_logprobs (20 instead of 5) to support
# both Completion API and Chat Completion API
logprobs=21,
)
...
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
# vLLM has higher default max_logprobs (20 instead of 5) to support
# both Completion API and Chat Completion API
logprobs=30,
stream=True,
)
async for chunk in stream:
...
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
(MODEL_NAME, 1),
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: Optional[int]):
params: dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.completions.create(**params)
else:
completion = await client.completions.create(**params)
if prompt_logprobs is not None:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0
assert completion.choices[1].prompt_logprobs is not None
assert len(completion.choices[1].prompt_logprobs) > 0
else:
assert completion.choices[0].prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is an LLM?"
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt = "What is an LLM?"
n = 3
max_tokens = 5
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
stream=True)
chunks: list[list[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == n
for chunk in chunks:
assert len(chunk) == max_tokens
print("".join(chunk))
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"
# Test stream=True, stream_options=
# {"include_usage": False, "continuous_usage_stats": False}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": False,
"continuous_usage_stats":
False,
})
async for chunk in stream:
assert chunk.usage is None
# Test stream=True, stream_options=
# {"include_usage": False, "continuous_usage_stats": True}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": False,
"continuous_usage_stats":
True,
})
async for chunk in stream:
assert chunk.usage is None
# Test stream=True, stream_options=
# {"include_usage": True, "continuous_usage_stats": False}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats":
False,
})
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
else:
assert chunk.usage is None
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []
# Test stream=True, stream_options=
# {"include_usage": True, "continuous_usage_stats": True}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats":
True,
})
async for chunk in stream:
assert chunk.usage is not None
assert chunk.usage.prompt_tokens > 0
assert chunk.usage.completion_tokens > 0
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
chunk.usage.completion_tokens)
if chunk.choices[0].finish_reason is not None:
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []
# Test stream=False, stream_options=
# {"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": None})
# Test stream=False, stream_options=
# {"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": True})
# Test stream=False, stream_options=
# {"continuous_usage_stats": None}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"continuous_usage_stats": None})
# Test stream=False, stream_options=
# {"continuous_usage_stats": True}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"continuous_usage_stats": True})
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs
for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
# test simple list
batch = await client.completions.create(
model=model_name,
prompt=prompts,
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=prompts,
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but
# not necessary for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming
batch = await client.completions.create(
model=model_name,
prompt=prompts,
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
@pytest.mark.asyncio
async def test_logits_bias(client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 5
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Test exclusive selection
token_id = 1000
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token_id): 100},
seed=42,
)
assert len(completion.choices[0].text) >= 5
response_tokens = tokenizer(completion.choices[0].text,
add_special_tokens=False)["input_ids"]
expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
add_special_tokens=False)["input_ids"]
assert all([
response == expected
for response, expected in zip(response_tokens, expected_tokens)
])
# Test ban
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
)
response_tokens = tokenizer(completion.choices[0].text,
add_special_tokens=False)["input_ids"]
first_response = completion.choices[0].text
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token): -100
for token in response_tokens},
)
assert first_response != completion.choices[0].text
@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 1
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Test exclusive selection
allowed_ids = [21555, 21557, 21558]
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
seed=42,
extra_body=dict(allowed_token_ids=allowed_ids),
logprobs=1,
)
response_tokens = completion.choices[0].logprobs.tokens
assert len(response_tokens) == 1
assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
@pytest.mark.asyncio
async def test_structured_outputs_json_completion(
client: openai.AsyncOpenAI,
sample_json_schema,
is_v1_server: bool,
):
if not is_v1_server:
pytest.skip("structured outputs is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}",
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(structured_outputs=dict(json=sample_json_schema)))
assert completion.id is not None
assert len(completion.choices) == 3
for i in range(3):
output_json = json.loads(completion.choices[i].text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.asyncio
async def test_structured_outputs_regex_completion(
client: openai.AsyncOpenAI,
sample_regex,
is_v1_server: bool,
):
if not is_v1_server:
pytest.skip("structured outputs is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(structured_outputs=dict(regex=sample_regex)))
assert completion.id is not None
assert len(completion.choices) == 3
for i in range(3):
assert re.fullmatch(sample_regex,
completion.choices[i].text) is not None
@pytest.mark.asyncio
async def test_structured_outputs_choice_completion(
client: openai.AsyncOpenAI,
sample_structured_outputs_choices,
is_v1_server: bool,
):
if not is_v1_server:
pytest.skip("structured outputs is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(structured_outputs=dict(
choice=sample_structured_outputs_choices)))
assert completion.id is not None
assert len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in sample_structured_outputs_choices
@pytest.mark.asyncio
async def test_structured_outputs_grammar(client: openai.AsyncOpenAI,
sample_sql_statements,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("grammar is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
temperature=1.0,
max_tokens=500,
extra_body=dict(
structured_outputs=dict(grammar=sample_sql_statements), ))
content = completion.choices[0].text
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_statements)
parser.parse(content)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
assert content.strip() == ground_truth
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
model_name: str, logprobs_arg: int):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# test using text and token IDs
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=logprobs_arg)
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
list) else prompt
assert re.search(r"^" + prompt_text, completion.choices[0].text)
logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) > 5
assert (len(logprobs.token_logprobs) > 5
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) > 5
@pytest.mark.asyncio
async def test_structured_outputs_type_error(client: openai.AsyncOpenAI,
sample_json_schema, sample_regex,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("structured outputs is only supported in v1 engine")
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(structured_outputs=dict(json=42)))
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example string that fits this regex",
extra_body=dict(structured_outputs=dict(
regex=sample_regex,
json=sample_json_schema,
)))
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,stream,echo",
[
(MODEL_NAME, False, False),
(MODEL_NAME, False, True),
(MODEL_NAME, True, False),
(MODEL_NAME, True, True) # should not raise BadRequestError error
],
)
async def test_echo_stream_completion(client: openai.AsyncOpenAI,
model_name: str, stream: bool,
echo: bool):
saying: str = "Hello, my name is"
result = await client.completions.create(model=model_name,
prompt=saying,
max_tokens=10,
temperature=0.0,
echo=echo,
stream=stream)
stop_reason = "length"
if not stream:
completion = result
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == stop_reason
if echo:
assert choice.text is not None and saying in choice.text
else:
assert choice.text is not None and saying not in choice.text
else:
chunks: list[str] = []
final_finish_reason = None
async for chunk in result:
if chunk.choices and chunk.choices[0].text:
chunks.append(chunk.choices[0].text)
if chunk.choices and chunk.choices[0].finish_reason:
final_finish_reason = chunk.choices[0].finish_reason
assert final_finish_reason == stop_reason
content = "".join(chunks)
if echo:
assert content is not None and saying in content
else:
assert content is not None and saying not in content
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI):
request_args = {
"model": MODEL_NAME,
"prompt": "Hello, my name is",
"max_tokens": 5,
"temperature": 0.0,
"logprobs": None,
}
completion = await client.completions.create(**request_args)
invocation_response = requests.post(server.url_for("invocations"),
json=request_args)
invocation_response.raise_for_status()
completion_output = completion.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
assert completion_output["choices"] == invocation_output["choices"]
...@@ -14,6 +14,9 @@ from transformers import AutoConfig ...@@ -14,6 +14,9 @@ from transformers import AutoConfig
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
pytest.skip("Skipping prompt_embeds test until V1 supports it.",
allow_module_level=True)
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
......
...@@ -53,12 +53,13 @@ def monkeypatch_module(): ...@@ -53,12 +53,13 @@ def monkeypatch_module():
mpatch.undo() mpatch.undo()
@pytest.fixture(scope="module", params=[False, True]) @pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, monkeypatch_module, def server_with_lora_modules_json(request, monkeypatch_module,
zephyr_lora_files): zephyr_lora_files):
use_v1 = request.param use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') assert use_v1
monkeypatch_module.setenv('VLLM_USE_V1', '1')
# Define the json format LoRA module configurations # Define the json format LoRA module configurations
lora_module_1 = { lora_module_1 = {
......
...@@ -22,7 +22,7 @@ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" ...@@ -22,7 +22,7 @@ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PREV_MINOR_VERSION = version._prev_minor_version() PREV_MINOR_VERSION = version._prev_minor_version()
@pytest.fixture(scope="module", params=[True, False]) @pytest.fixture(scope="module", params=[True])
def use_v1(request): def use_v1(request):
# Module-scoped variant of run_with_both_engines # Module-scoped variant of run_with_both_engines
# #
......
...@@ -10,8 +10,30 @@ import pytest ...@@ -10,8 +10,30 @@ import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args # noqa: F401
from .test_completion import MODEL_NAME MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--max-num-seqs",
"128",
"--enforce-eager",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
......
...@@ -15,14 +15,6 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" ...@@ -15,14 +15,6 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16" DTYPE = "float16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
......
...@@ -7,7 +7,6 @@ import pytest ...@@ -7,7 +7,6 @@ import pytest
import vllm.envs as envs import vllm.envs as envs
from vllm import LLM from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
MODEL = "meta-llama/Llama-3.2-1B-Instruct" MODEL = "meta-llama/Llama-3.2-1B-Instruct"
...@@ -96,20 +95,3 @@ def test_v1_attn_backend(monkeypatch): ...@@ -96,20 +95,3 @@ def test_v1_attn_backend(monkeypatch):
_ = AsyncEngineArgs(model=MODEL).create_engine_config() _ = AsyncEngineArgs(model=MODEL).create_engine_config()
assert envs.VLLM_USE_V1 assert envs.VLLM_USE_V1
m.delenv("VLLM_USE_V1") m.delenv("VLLM_USE_V1")
def test_reject_using_constructor_directly(monkeypatch):
with monkeypatch.context() as m:
if os.getenv("VLLM_USE_V1", None):
m.delenv("VLLM_USE_V1")
# Sets VLLM_USE_V1=1.
vllm_config = AsyncEngineArgs(model=MODEL).create_engine_config()
# This uses the V0 constructor directly.
with pytest.raises(ValueError):
AsyncLLMEngine(vllm_config,
AsyncLLMEngine._get_executor_cls(vllm_config),
log_stats=True)
m.delenv("VLLM_USE_V1")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio from vllm.v1.engine.async_llm import AsyncLLM
import time
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from weakref import ReferenceType
import vllm.envs as envs AsyncLLMEngine = AsyncLLM # type: ignore
from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, deprecate_kwargs, weak_bind
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
class AsyncEngineDeadError(RuntimeError):
pass
def _log_task_completion(task: asyncio.Task,
error_callback: Callable[[Exception], None]) -> None:
"""This function is only intended for the `engine.run_engine_loop()` task.
In particular, that task runs a `while True` loop that can only exit if
there is an exception.
"""
exception = None
try:
return_value = task.result()
raise AssertionError(
f"The engine background task should never finish without an "
f"exception. {return_value}")
except asyncio.exceptions.CancelledError:
# We assume that if the task is cancelled, we are gracefully shutting
# down. This should only happen on program exit.
logger.info("Engine is gracefully shutting down.")
except Exception as e:
exception = e
logger.error("Engine background task failed", exc_info=e)
error_callback(exception)
raise AsyncEngineDeadError(
"Task finished unexpectedly. This should never happen! "
"Please open an issue on GitHub. See stack trace above for the "
"actual cause.") from e
STOP_ITERATION = Exception() # Sentinel
class AsyncStream:
"""A stream of RequestOutputs for a request that can be iterated over
asynchronously via an async generator."""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
def finish(
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
) -> None:
if not self._finished:
self._finished = True
self._queue.put_nowait(
exception if self._is_raisable(exception) else STOP_ITERATION)
@property
def finished(self) -> bool:
return self._finished
async def generator(self) -> AsyncGenerator[RequestOutput, None]:
try:
while True:
result = await self._queue.get()
if self._is_raisable(result):
if result == STOP_ITERATION:
return
raise result
yield result
except GeneratorExit:
self._cancel(self.request_id)
raise asyncio.CancelledError from None
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))
class RequestTracker:
"""Synchronous abstraction for tracking requests."""
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = asyncio.Event()
def __contains__(self, item):
return item in self._request_streams
def __len__(self) -> int:
return len(self._request_streams)
def propagate_exception(self,
exc: Exception,
request_id: Optional[str] = None) -> None:
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self.abort_request(request_id, exception=exc)
else:
# NB: tuple() used here because self.abort_request pops the stream
# out of self._request_streams, so we can't iterate on it directly
for rid in tuple(self._request_streams.keys()):
self.abort_request(rid, exception=exc)
def process_request_output(self,
request_output: RequestOutput,
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
finished = request_output.finished
if finished:
stream = self._request_streams.pop(request_id, None)
else:
stream = self._request_streams.get(request_id)
# Guard against a KeyError which can occur if the request was aborted
# while the output was generated
if stream is not None:
stream.put(request_output)
if finished:
stream.finish()
if verbose and finished:
logger.info("Finished request %s.", request_id)
def process_exception(self,
request_id: str,
exception: BaseException,
*,
verbose: bool = False) -> None:
"""Propagate an exception from the engine."""
if verbose:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id, exception=exception)
def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
abort_request = partial(self.abort_request, verbose=verbose)
stream = AsyncStream(request_id, abort_request)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
}))
self.new_requests_event.set()
if verbose:
logger.info("Added request %s.", request_id)
return stream
def abort_request(self,
request_id: str,
*,
exception: Optional[Union[BaseException,
Type[BaseException]]] = None,
verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info("Aborted request %s.", request_id)
self._aborted_requests.put_nowait(request_id)
stream = self._request_streams.pop(request_id, None)
if stream is not None:
stream.finish(exception=exception)
def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[str] = set()
while not self._aborted_requests.empty():
request_id = self._aborted_requests.get_nowait()
finished_requests.add(request_id)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
request_id = stream.request_id
if request_id in finished_requests:
# The request has already been aborted.
stream.finish(asyncio.CancelledError)
finished_requests.discard(request_id)
else:
self._request_streams[request_id] = stream
new_requests.append(new_request)
return new_requests, finished_requests
async def wait_for_new_requests(self):
if not self.has_new_requests():
await self.new_requests_event.wait()
self.new_requests_event.clear()
def has_new_requests(self):
return not self._new_requests.empty()
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
# these are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
if not scheduler_outputs.is_empty():
# this will cause mamba_cache/minimax_cache failed
# to release finished_requests_ids of the last steps
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
else:
finished_requests_ids = list()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
if not scheduler_outputs.is_empty():
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
# Execute the model.
outputs = await self.model_executor.execute_model_async(
execute_model_req)
else:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
outputs = []
if not self._has_remaining_steps(seq_group_metadata_list):
# is_first_step_output is True only when the num_steps of all
# the sequences are 1.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc:
assert len(
outputs
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, outputs)
# Tracing
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
return ctx.request_outputs
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
return ctx.request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self) -> AnyTokenizer:
return self.get_tokenizer()
async def add_request_async(
self,
request_id: str,
prompt: PromptType,
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Async version of
[`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].
"""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority != 0 and not self.scheduler_config.policy == "priority":
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if arrival_time is None:
arrival_time = time.time()
if data_parallel_rank is not None:
raise ValueError("Targeting data_parallel_rank only supported "
"in v1 client.")
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
# We use the -2 dimension (instead of 0) in case a batched input
# of batch size 1 is passed in.
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
async def check_health_async(self) -> None:
self.model_executor.check_health()
async def collective_rpc_async(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
raise NotImplementedError
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
make it asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
by the generate method when there are requests in the waiting queue. The
generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
to the caller.
Args:
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for [`LLMEngine`][vllm.LLMEngine].
**kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
def __init__(self,
*args: Any,
log_requests: bool = True,
start_engine_loop: bool = True,
**kwargs: Any) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.log_requests = log_requests
self.engine = self._engine_class(*args, **kwargs)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
self.use_process_request_outputs_callback = (
self.engine.model_config.use_async_output_proc)
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
weak_bind(self.process_request_outputs)
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded: Optional[asyncio.Task] = None
self.start_engine_loop = start_engine_loop
self._errored_with: Optional[BaseException] = None
# Lazy initialized fields
self._request_tracker: RequestTracker
def __del__(self):
if rt := getattr(self, "request_tracker", None):
# Wake up engine loop so that it will exit cleanly
rt.new_requests_event.set()
@classmethod
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
return LLMEngine._get_executor_cls(engine_config)
@classmethod
@deprecate_kwargs(
"disable_log_requests",
additional_message=("This argument will have no effect. "
"Use `enable_log_requests` instead."),
)
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
enable_log_requests: bool = False,
disable_log_stats: bool = False,
disable_log_requests: bool = True, # Deprecated, will be removed
) -> "AsyncLLMEngine":
"""Create an AsyncLLMEngine from the EngineArgs."""
return cls(
vllm_config=vllm_config,
executor_class=cls._get_executor_cls(vllm_config),
start_engine_loop=start_engine_loop,
log_requests=enable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
vllm_config = engine_args.create_engine_config(usage_context)
async_engine_cls = cls
if envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
async_engine_cls = V1AsyncLLMEngine
return async_engine_cls.from_vllm_config(
vllm_config=vllm_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_stats=engine_args.disable_log_stats,
enable_log_requests=engine_args.enable_log_requests,
)
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and self._background_loop_unshielded is not None
and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
return self.errored or (self.background_loop is not None and
self._background_loop_unshielded is not None
and self._background_loop_unshielded.done())
@property
def errored(self) -> bool:
return self._errored_with is not None
@property
def dead_error(self) -> BaseException:
return AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
def set_errored(self, exc: Exception) -> None:
self._errored_with = exc
def _error_callback(self, exc: Exception) -> None:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor
async def get_tokenizer(self) -> AnyTokenizer:
return self.engine.get_tokenizer()
def start_background_loop(self) -> None:
"""Start the background loop."""
if self.errored:
raise AsyncEngineDeadError(
"Background loop has errored already.") from self._errored_with
if self.is_running:
raise RuntimeError("Background loop is already running.")
# Initialize the RequestTracker here so it uses the right event loop.
self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop(weakref.ref(self)))
self._background_loop_unshielded.add_done_callback(
partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def shutdown_background_loop(self) -> None:
"""
Shut down the background loop.
This method needs to be called during cleanup to remove
references to `self` and properly GC the resources held
by the async LLM engine (e.g., the executors as well as
their resources).
"""
if self._background_loop_unshielded is not None:
self._background_loop_unshielded.cancel()
self._background_loop_unshielded = None
self.background_loop = None
async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, aborted_requests = (
self._request_tracker.get_new_and_aborted_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
try:
await self.engine.add_request_async(**new_request)
except ValueError as e:
# TODO: use a vLLM specific error for failed validation
self._request_tracker.process_exception(
new_request["request_id"],
e,
verbose=self.log_requests,
)
if aborted_requests:
await self._engine_abort(aborted_requests)
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if not self.use_process_request_outputs_callback:
all_finished = self.process_request_outputs(request_outputs)
else:
# For callback case, we only need to detect when all
# requests are finished
all_finished = all(request_output.finished
for request_output in request_outputs)
return not all_finished
def process_request_outputs(self, request_outputs) -> bool:
# Put the outputs into the corresponding streams.
all_finished = True
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
all_finished = all_finished and request_output.finished
return all_finished
async def _engine_abort(self, request_ids: Iterable[str]):
self.engine.abort_request(request_ids)
@staticmethod
async def run_engine_loop(engine_ref: ReferenceType):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine: Optional[AsyncLLMEngine] = engine_ref()
if not engine:
return
pipeline_parallel_size = \
engine.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True:
if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...")
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# time out, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
await engine.engine.stop_remote_worker_execution_loop_async()
request_tracker = engine._request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del engine
await asyncio.sleep(0)
if engine_ref() is None:
return
await request_tracker.wait_for_new_requests()
engine = engine_ref()
if not engine:
return
logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(engine.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
done, _ = await asyncio.wait(
requests_in_progress,
return_when=asyncio.FIRST_COMPLETED)
for _ in range(pipeline_parallel_size):
await asyncio.sleep(0)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
has_unfinished_requests = (
engine.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
engine.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
engine.set_errored(exc)
raise
await asyncio.sleep(0)
async def add_request(
self,
request_id: str,
prompt: PromptType,
params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[RequestOutput, None]:
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
else:
raise AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if (priority != 0
and not self.engine.scheduler_config.policy == "priority"):
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
prompt=prompt,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
)
return stream.generator()
async def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
handle this request. Only applicable if DP is enabled.
Yields:
The output `RequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
[`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> # note that engine_args here is AsyncEngineArgs instance
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "prompt": "What is LLM?",
>>> "stream": False, # assume the non-streaming case
>>> "temperature": 0.0,
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.generate(
>>> example_input["prompt"],
>>> SamplingParams(temperature=example_input["temperature"]),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
try:
async for output in await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
):
yield LLMEngine.validate_output(output, RequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
raise NotImplementedError(
"Pooling models are not supported in vLLM V0")
async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
if not isinstance(request_id, str):
raise RuntimeError("Only single-request abort supported in"
" deprecated V0")
if not self.is_running:
raise AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
return self._abort(request_id)
def _abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
exception=asyncio.CancelledError,
verbose=self.log_requests)
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
return self.engine.get_vllm_config()
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
return self.engine.get_parallel_config()
async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
return self.engine.get_lora_config()
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
self.engine.do_log_stats()
async def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
t = time.perf_counter()
logger.debug("Starting health check...")
if self.is_stopped:
raise AsyncEngineDeadError("Background loop is stopped.")
await self.engine.check_health_async()
logger.debug("Health check took %fs", time.perf_counter() - t)
async def is_tracing_enabled(self) -> bool:
return self.engine.is_tracing_enabled()
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
self.engine.add_logger(logger_name=logger_name, logger=logger)
def remove_logger(self, logger_name: str) -> None:
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None:
self.engine.start_profile()
async def stop_profile(self) -> None:
self.engine.stop_profile()
async def reset_mm_cache(self) -> None:
self.engine.reset_mm_cache()
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
self.engine.reset_prefix_cache(device)
async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache()
self.engine.sleep(level)
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up(tags)
async def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
async def add_lora(self, lora_request: LoRARequest) -> bool:
return self.engine.add_lora(lora_request)
async def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
"""
Perform a collective RPC call to the given path.
"""
return await self.engine.collective_rpc_async(method, timeout, args,
kwargs)
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLM
AsyncLLMEngine = AsyncLLM # type: ignore
...@@ -11,7 +11,6 @@ import uvicorn ...@@ -11,7 +11,6 @@ import uvicorn
from fastapi import FastAPI, Request, Response from fastapi import FastAPI, Request, Response
from vllm import envs from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
...@@ -154,7 +153,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: ...@@ -154,7 +153,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
""" """
@app.exception_handler(RuntimeError) @app.exception_handler(RuntimeError)
@app.exception_handler(AsyncEngineDeadError)
@app.exception_handler(EngineDeadError) @app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError) @app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __): async def runtime_exception_handler(request: Request, __):
......
...@@ -38,7 +38,6 @@ from typing_extensions import assert_never ...@@ -38,7 +38,6 @@ from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (load_chat_template, from vllm.entrypoints.chat_utils import (load_chat_template,
resolve_hf_chat_template, resolve_hf_chat_template,
...@@ -201,50 +200,34 @@ async def build_async_engine_client_from_engine_args( ...@@ -201,50 +200,34 @@ async def build_async_engine_client_from_engine_args(
vllm_config = engine_args.create_engine_config(usage_context=usage_context) vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM. # V1 AsyncLLM.
if envs.VLLM_USE_V1: assert envs.VLLM_USE_V1
if disable_frontend_multiprocessing:
logger.warning(
"V1 is enabled, but got --disable-frontend-multiprocessing. "
"To disable frontend multiprocessing, set VLLM_USE_V1=0.")
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_count = client_config.pop(
"client_count") if client_config else 1
client_index = client_config.pop(
"client_index") if client_config else 0
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
enable_log_requests=engine_args.enable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_count=client_count,
client_index=client_index)
# Don't keep the dummy data in memory
await async_llm.reset_mm_cache()
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
# V0 AsyncLLM. if disable_frontend_multiprocessing:
else: logger.warning(
"V1 is enabled, but got --disable-frontend-multiprocessing. "
"To disable frontend multiprocessing, set VLLM_USE_V1=0.")
engine_client: Optional[EngineClient] = None from vllm.v1.engine.async_llm import AsyncLLM
try: async_llm: Optional[AsyncLLM] = None
engine_client = AsyncLLMEngine.from_vllm_config( client_count = client_config.pop("client_count") if client_config else 1
vllm_config=vllm_config, client_index = client_config.pop("client_index") if client_config else 0
usage_context=usage_context, try:
enable_log_requests=engine_args.enable_log_requests, async_llm = AsyncLLM.from_vllm_config(
disable_log_stats=engine_args.disable_log_stats) vllm_config=vllm_config,
yield engine_client usage_context=usage_context,
finally: enable_log_requests=engine_args.enable_log_requests,
if engine_client and hasattr(engine_client, "shutdown"): disable_log_stats=engine_args.disable_log_stats,
engine_client.shutdown() client_addresses=client_config,
client_count=client_count,
client_index=client_index)
# Don't keep the dummy data in memory
await async_llm.reset_mm_cache()
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
async def validate_json_request(raw_request: Request): async def validate_json_request(raw_request: Request):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment