Unverified Commit 6ace6fba authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[V1] `AsyncLLM` Implementation (#9826)


Signed-off-by: default avatarNick Hill <nickhill@us.ibm.com>
Signed-off-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 08f93e74
......@@ -165,6 +165,14 @@ steps:
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- label: V1 Test
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/v1
commands:
- pytest -v -s v1
- label: Examples Test # 15min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
......
"""
This file tests the accuracy of vLLM via LMEval.
It uses the `vllm` model backend of lm_eval directly (no
OpenAI-compatible server is launched) to evaluate gsm8k,
once with the V0 engine and once with the V1 engine selected
through VLLM_USE_V1, and checks that the measured score
matches the expected value within a tolerance.
"""
import lm_eval
import pytest
from vllm.platforms import current_platform
# Model evaluated by the accuracy test.
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
# NOTE(review): not referenced in the visible part of this file; presumably
# carried over from the server-based variant of this test -- confirm.
NUM_CONCURRENT = 500
# lm_eval task and the metric key read back from its results.
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
# Tolerance applied around EXPECTED_VALUE when checking the measured score.
RTOL = 0.03
EXPECTED_VALUE = 0.58
def run_test():
    """Run the end to end accuracy test.

    Evaluates MODEL_NAME on TASK through lm_eval's "vllm" model backend and
    asserts that the FILTER metric lies strictly within RTOL of
    EXPECTED_VALUE.
    """
    model_args = f"pretrained={MODEL_NAME},max_model_len=2048"

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        # Pass the TASK constant so the task that is run always matches the
        # key used to look up the result below.
        tasks=TASK,
        batch_size="auto",
    )

    measured_value = results["results"][TASK][FILTER]
    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="V1 is currently only supported on CUDA.")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Check accuracy with VLLM_USE_V1 set to "1"."""
    with monkeypatch.context() as patched_env:
        # The override only lives for the duration of this context.
        patched_env.setenv("VLLM_USE_V1", "1")
        run_test()
def test_lm_eval_accuracy_v0_engine(monkeypatch):
    """Check accuracy with VLLM_USE_V1 set to "0"."""
    with monkeypatch.context() as patched_env:
        # The override only lives for the duration of this context.
        patched_env.setenv("VLLM_USE_V1", "0")
        run_test()
......@@ -37,11 +37,11 @@ if current_platform.is_tpu():
MAX_WAIT_SECONDS = 600
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
def test_lm_eval_accuracy(more_args):
def run_test(more_args):
"""Run the end to end accuracy test."""
args = list(DEFAULT_ARGS)
args.extend(more_args)
print(f"Running with: {args}")
with RemoteOpenAIServer(
......@@ -64,3 +64,22 @@ def test_lm_eval_accuracy(more_args):
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="V1 currently only supported on CUDA")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
    """Check accuracy with VLLM_USE_V1 set to "1" and no extra arguments."""
    with monkeypatch.context() as patched_env:
        # The override only lives for the duration of this context.
        patched_env.setenv("VLLM_USE_V1", "1")
        run_test([])
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
def test_lm_eval_accuracy_v0_engine(monkeypatch, more_args):
    """Check accuracy with VLLM_USE_V1 set to "0" for each entry of
    MORE_ARGS_LIST."""
    with monkeypatch.context() as patched_env:
        # The override only lives for the duration of this context.
        patched_env.setenv("VLLM_USE_V1", "0")
        run_test(more_args)
import asyncio
from typing import Tuple
import pytest
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM
# Skip the whole module when the current platform is not CUDA.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

# Engine arguments used to build the AsyncLLM under test.
ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
                              disable_log_requests=True)
async def generate(engine: AsyncLLM, request_id: str,
                   max_tokens: int) -> Tuple[int, str]:
    """Drive a single request to completion and count the outputs it yields.

    The request uses a fixed prompt with ``temperature=0`` and the given
    ``max_tokens``.

    Returns:
        A ``(number_of_outputs, request_id)`` pair.
    """
    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0)
    stream = engine.generate(request_id=request_id,
                             prompt="Hello my name is Robert and",
                             sampling_params=sampling_params)

    num_outputs = 0
    async for _ in stream:
        num_outputs += 1
        # Yield to the event loop between outputs.
        await asyncio.sleep(0.)

    return num_outputs, request_id
@pytest.mark.asyncio
async def test_load(monkeypatch):
    """Submit many concurrent requests and check each yields every output.

    NUM_REQUESTS requests are started as asyncio tasks against one AsyncLLM;
    the test fails on the first request whose output count differs from
    NUM_EXPECTED_TOKENS.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
        tasks = []
        # Always shut the engine down, even when a request raises or the
        # assertion below fails, so it is not left running after this test.
        try:
            NUM_REQUESTS = 10000
            NUM_EXPECTED_TOKENS = 10

            request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

            # Create concurrent requests.
            for request_id in request_ids:
                tasks.append(
                    asyncio.create_task(
                        generate(engine, request_id, NUM_EXPECTED_TOKENS)))

            # Confirm that we got all the EXPECTED tokens from the requests.
            failed_request_id = None
            tokens = None
            for task in tasks:
                num_generated_tokens, request_id = await task
                if (num_generated_tokens != NUM_EXPECTED_TOKENS
                        and failed_request_id is None):
                    failed_request_id = request_id
                    tokens = num_generated_tokens

            assert failed_request_id is None, (
                f"{failed_request_id} generated {tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")
        finally:
            # Cancelling an already finished task is a no-op; this only
            # affects tasks left pending by an early failure.
            for task in tasks:
                task.cancel()
            engine.shutdown()
from typing import List
import pytest
from transformers import AutoTokenizer
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest
TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Reference texts; each one contains the stop string at the same index of
# STOP_STRINGS.
FULL_STRINGS = [
    "My name is Robert from Neural Magic and I love working on vLLM so much!",
    "Red Hat is the best open source company by far across Linux, K8s, and AI.",
    "Nick is the name of my brother in addition to my colleague from Red Hat.",
]
STOP_STRINGS = ["I love working on", "company by far", "brother in"]

# Tokenize every reference text exactly once; the prompt/generation splits
# below are slices of these ids rather than fresh tokenizer calls.
FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS]

# The first PROMPT_LEN token ids act as the prompt, the rest as the generation.
PROMPT_LEN = 5
PROMPT_TOKENS = [token_ids[:PROMPT_LEN] for token_ids in FULL_TOKENS]
GENERATION_TOKENS = [token_ids[PROMPT_LEN:] for token_ids in FULL_TOKENS]

# Prompt text as decoded from the prompt token ids.
PROMPT_STRINGS = [
    tokenizer.decode(prompt_tokens, skip_special_tokens=True)
    for prompt_tokens in PROMPT_TOKENS
]
PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS]

# Expected generated text: the reference text with the decoded prompt's
# character length cut off the front.
GENERATION_STRINGS = [
    text[prompt_len:]
    for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN)
]
class MockEngineCore:
    """Mock outputs from premade token lists.

    Each call to ``get_outputs`` emits, for every request that still has
    tokens left, the single token at the current step.
    """

    def __init__(self, tokens_list: List[List[int]]):
        self.tokens_list = tokens_list
        self.current_idx = 0

    def get_outputs(self) -> List[EngineCoreOutput]:
        """Return one output per unfinished request and advance one step."""
        step = self.current_idx
        self.current_idx = step + 1

        produced = []
        for req_idx, token_ids in enumerate(self.tokens_list):
            # This request has already emitted all of its tokens.
            if step >= len(token_ids):
                continue

            is_last_token = step == len(token_ids) - 1
            produced.append(
                EngineCoreOutput(
                    request_id=f"request-{req_idx}",
                    new_token_ids=[token_ids[step]],
                    finished=is_last_token,
                    finish_reason="stopped" if is_last_token else None))

        return produced
@pytest.mark.parametrize(
    "request_output_kind",
    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind):
    """Replay premade tokens through the Detokenizer and compare the text
    and token ids it reports against the reference generation."""
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    # Make N requests (no stop strings) and add them to the detokenizer.
    requests = []
    for idx, (prompt, prompt_tokens) in enumerate(
            zip(PROMPT_STRINGS, PROMPT_TOKENS)):
        requests.append(
            DetokenizerRequest(
                request_id=f"request-{idx}",
                prompt=prompt,
                prompt_token_ids=prompt_tokens,
                skip_special_tokens=False,
                spaces_between_special_tokens=False,
                output_kind=request_output_kind,
                stop=[],
                include_stop_str_in_output=False,
            ))
    for request in requests:
        detokenizer.add_request(request)

    gen_strings = {}
    gen_tokens = {}
    # Keep stepping until the mocked EngineCore has nothing left to emit.
    while outputs := engine_core.get_outputs():
        request_outputs, requests_to_abort = detokenizer.step(outputs)
        assert len(requests_to_abort) == 0

        # Accumulate text and token ids per request.
        for request_output in request_outputs:
            request_id = request_output.request_id
            completion = request_output.outputs[0]
            if request_id in gen_strings:
                gen_strings[request_id] += completion.text
                gen_tokens[request_id].extend(completion.token_ids)
            else:
                gen_strings[request_id] = completion.text
                gen_tokens[request_id] = completion.token_ids

    # The accumulated values must match the references.
    references = zip(GENERATION_STRINGS, GENERATION_TOKENS)
    for idx, (ref_gen_str, ref_gen_toks) in enumerate(references):
        gen_str = gen_strings[f"request-{idx}"]
        gen_toks = gen_tokens[f"request-{idx}"]

        assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
        assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"

    assert detokenizer.get_num_unfinished_requests() == 0
    assert not detokenizer.has_unfinished_requests()
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
def test_stop_string(include_stop_str_in_output: bool):
    """Check that stop strings end a request and truncate its text.

    Every request is given STOP_STRINGS; each must be reported in
    ``requests_to_abort``, must emit no output afterwards, and its
    accumulated text must end right before (or right after, when
    ``include_stop_str_in_output`` is set) its stop string.
    """
    detokenizer = Detokenizer(TOKENIZER_NAME)
    engine_core = MockEngineCore(GENERATION_TOKENS)

    # Make N requests.
    requests = [
        DetokenizerRequest(
            request_id=f"request-{idx}",
            prompt=prompt,
            prompt_token_ids=prompt_tokens,
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=RequestOutputKind.DELTA,
            stop=STOP_STRINGS,
            include_stop_str_in_output=include_stop_str_in_output,
        ) for idx, (
            prompt,
            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        detokenizer.add_request(request)

    gen_strings = {}
    aborted = []
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if not outputs:
            break

        # Step the Detokenizer.
        request_outputs, requests_to_abort = detokenizer.step(outputs)
        for request_output in request_outputs:
            # If aborted, we should not get a request output.
            assert request_output.request_id not in aborted
        aborted.extend(requests_to_abort)

        # Update tracking.
        for request_output in request_outputs:
            if request_output.finished:
                assert request_output.outputs[0].finish_reason == "stop"

            request_id = request_output.request_id
            new_text = request_output.outputs[0].text
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
            else:
                gen_strings[request_id] += new_text

    # Confirm the tracked values match what we expected.
    for idx, (ref_gen_str,
              stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)):

        # Request should be aborted.
        request_id = f"request-{idx}"
        assert request_id in aborted

        # Collected values that were generated.
        gen_str = gen_strings[request_id]

        # Construct reference strings.
        stop_str_idx = ref_gen_str.find(stop_str)
        # str.find returns -1 on a miss, which would silently slice off the
        # last character below; fail loudly on a broken fixture instead.
        assert stop_str_idx != -1, f"{stop_str=} not in {ref_gen_str=}"
        ref_str_exc_stop = ref_gen_str[:stop_str_idx]
        ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str

        if include_stop_str_in_output:
            assert gen_str == ref_str_inc_stop, (
                f"{gen_str=}, {ref_str_inc_stop=}")
        else:
            assert gen_str == ref_str_exc_stop, (
                f"{gen_str=}, {ref_str_exc_stop=}")

    assert detokenizer.get_num_unfinished_requests() == 0
    assert not detokenizer.has_unfinished_requests()
import time
import uuid
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core import EngineCore
# Skip the whole module when the current platform is not CUDA.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

# Model served by the EngineCore under test; its tokenizer is used to
# pre-tokenize the fixed prompt below.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request() -> EngineCoreRequest:
    """Build an EngineCoreRequest for the fixed prompt with a fresh id.

    Returns:
        A request using default SamplingParams, no EOS token id and no LoRA
        request, stamped with the current time as its arrival time.
    """
    return EngineCoreRequest(
        # request_id is declared as `str`, so pass the UUID's string form
        # rather than the uuid.UUID object itself.
        request_id=str(uuid.uuid4()),
        prompt=PROMPT,
        prompt_token_ids=PROMPT_TOKENS,
        sampling_params=SamplingParams(),
        eos_token_id=None,
        arrival_time=time.time(),
        lora_request=None,
    )
def test_engine_core(monkeypatch):
    """Exercise the add / step / abort lifecycle of an in-process EngineCore
    by checking the scheduler's waiting and running queues after each
    operation."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Setup the EngineCore.
        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config()
        executor_class = AsyncLLM._get_executor_cls(vllm_config)

        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class,
                                 usage_context=UsageContext.UNKNOWN_CONTEXT)

        def check_queues(num_waiting: int, num_running: int) -> None:
            # Compare the scheduler queue sizes with the expected ones.
            assert len(engine_core.scheduler.waiting) == num_waiting
            assert len(engine_core.scheduler.running) == num_running

        # ---- Test basic request lifecycle. ----

        # First request.
        engine_core.add_request(make_request())
        check_queues(1, 0)
        engine_core.step()
        check_queues(0, 1)

        # Second request.
        engine_core.add_request(make_request())
        check_queues(1, 1)
        engine_core.step()
        check_queues(0, 2)

        # Add two requests in a row.
        engine_core.add_request(make_request())
        engine_core.add_request(make_request())
        check_queues(2, 2)
        engine_core.step()
        check_queues(0, 4)

        # Loop through until they are all done.
        while len(engine_core.step()) > 0:
            pass
        check_queues(0, 0)

        # ---- Test abort cycle. ----

        # Basic abort.
        req = make_request()
        request_id = req.request_id
        engine_core.add_request(req)
        check_queues(1, 0)
        engine_core.step()
        check_queues(0, 1)
        engine_core.abort_requests([request_id])
        check_queues(0, 0)

        # Add, step, abort 1 of the 3.
        req0 = make_request()
        req1 = make_request()
        req2 = make_request()

        engine_core.add_request(req0)
        engine_core.add_request(req1)
        check_queues(2, 0)
        engine_core.step()
        check_queues(0, 2)

        engine_core.add_request(req2)
        check_queues(1, 2)
        engine_core.step()
        check_queues(0, 3)

        # Abort just one.
        engine_core.abort_requests([req1.request_id])
        check_queues(0, 2)
        engine_core.step()
        check_queues(0, 2)

        # Abort the other requests at the same time.
        engine_core.abort_requests([req2.request_id, req0.request_id])
        check_queues(0, 0)
import asyncio
import time
import uuid
from typing import Dict, List
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import EngineCoreClient
# Skip the whole module when the current platform is not CUDA.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

# Model served by the engine core behind the client under test; its tokenizer
# is used to pre-tokenize the fixed prompt below.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request(params: SamplingParams) -> EngineCoreRequest:
    """Build an EngineCoreRequest for the fixed prompt.

    Args:
        params: Sampling parameters attached to the request.

    Returns:
        A request with a fresh UUID string as id, the current time as its
        arrival time, no EOS token id and no LoRA request.
    """
    fresh_id = str(uuid.uuid4())
    request = EngineCoreRequest(
        request_id=fresh_id,
        prompt=PROMPT,
        prompt_token_ids=PROMPT_TOKENS,
        sampling_params=params,
        eos_token_id=None,
        arrival_time=time.time(),
        lora_request=None,
    )
    return request
def loop_until_done(client: EngineCoreClient, outputs: Dict):
    """Collect outputs from ``client`` until a batch is empty or fully
    finished.

    Every output is appended to ``outputs[out.request_id]``. Polling stops
    after the first batch that is empty or in which every output has
    ``finished`` set.
    """
    while batch := client.get_output():
        for out in batch:
            outputs[out.request_id].append(out)

        # Stop once nothing in this batch is still running.
        if all(out.finished for out in batch):
            return
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
    """Async variant of the polling loop, built on
    ``client.get_output_async()``.

    Every output is appended to ``outputs[out.request_id]``. Polling stops
    after the first batch that is empty or in which every output has
    ``finished`` set.
    """
    while batch := await client.get_output_async():
        for out in batch:
            outputs[out.request_id].append(out)

        # Stop once nothing in this batch is still running.
        if all(out.finished for out in batch):
            return
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
    """Exercise the synchronous EngineCoreClient: a normal request cycle, a
    cycle where every other request is aborted, and an abort issued after
    the request has been given time to finish."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config()
        executor_class = AsyncLLM._get_executor_cls(vllm_config)
        client = EngineCoreClient.make_client(
            vllm_config,
            executor_class,
            UsageContext.UNKNOWN_CONTEXT,
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=False,
        )

        # Shut the client down even when an assertion below fails, so the
        # engine core is not left running after this test.
        try:
            MAX_TOKENS = 20
            params = SamplingParams(max_tokens=MAX_TOKENS)

            # ---- Normal Request Cycle. ----
            requests = [make_request(params) for _ in range(10)]
            request_ids = [req.request_id for req in requests]

            # Add requests to the engine.
            for request in requests:
                client.add_request(request)
                time.sleep(0.01)

            outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
            loop_until_done(client, outputs)

            for req_id in request_ids:
                assert len(outputs[req_id]) == MAX_TOKENS, (
                    f"{outputs[req_id]=}, {MAX_TOKENS=}")

            # ---- Abort Request Cycle. ----

            # Note: this code pathway will only work for multiprocessing
            # since we have to call get_output() explicitly

            # Add requests to the engine; abort every other one right away.
            for idx, request in enumerate(requests):
                client.add_request(request)
                time.sleep(0.01)
                if idx % 2 == 0:
                    client.abort_requests([request.request_id])

            outputs = {req_id: [] for req_id in request_ids}
            loop_until_done(client, outputs)

            for idx, req_id in enumerate(request_ids):
                if idx % 2 == 0:
                    assert len(outputs[req_id]) < MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
                else:
                    assert len(outputs[req_id]) == MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")

            # ---- Abort after request is finished. ----

            # Note: this code pathway will only work for multiprocessing
            # since we have to call get_output() explicitly

            request = requests[0]
            client.add_request(request)
            # Give the request time to run before the abort is sent.
            time.sleep(10.)

            client.abort_requests([request.request_id])
        finally:
            # Shutdown the client.
            client.shutdown()
@pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch):
    """Exercise the asyncio EngineCoreClient in multiprocess mode: a normal
    request cycle followed by a cycle where every other request is
    aborted."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config()
        executor_class = AsyncLLM._get_executor_cls(vllm_config)
        client = EngineCoreClient.make_client(
            vllm_config,
            executor_class,
            UsageContext.UNKNOWN_CONTEXT,
            multiprocess_mode=True,
            asyncio_mode=True,
        )

        # Shut the client down even when an assertion below fails, so the
        # engine core is not left running after this test.
        try:
            MAX_TOKENS = 20
            params = SamplingParams(max_tokens=MAX_TOKENS)

            # ---- Normal Request Cycle. ----
            requests = [make_request(params) for _ in range(10)]
            request_ids = [req.request_id for req in requests]

            # Add requests to the engine.
            for request in requests:
                await client.add_request_async(request)
                await asyncio.sleep(0.01)

            outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
            await loop_until_done_async(client, outputs)

            for req_id in request_ids:
                assert len(outputs[req_id]) == MAX_TOKENS, (
                    f"{outputs[req_id]=}, {MAX_TOKENS=}")

            # ---- Abort Request Cycle. ----

            # Add requests to the engine; abort every other one right away.
            for idx, request in enumerate(requests):
                await client.add_request_async(request)
                await asyncio.sleep(0.01)
                if idx % 2 == 0:
                    await client.abort_requests_async([request.request_id])

            outputs = {req_id: [] for req_id in request_ids}
            await loop_until_done_async(client, outputs)

            for idx, req_id in enumerate(request_ids):
                if idx % 2 == 0:
                    assert len(outputs[req_id]) < MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
                else:
                    assert len(outputs[req_id]) == MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
        finally:
            # Shutdown the client.
            client.shutdown()
......@@ -2106,3 +2106,44 @@ class VllmConfig:
self.model_config is not None and self.load_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)
def __str__(self):
    """Render selected settings of the sub-configs as one
    ``name=value, ...`` line."""
    model_config = self.model_config
    load_config = self.load_config
    parallel_config = self.parallel_config
    cache_config = self.cache_config

    # (format spec, value) pairs, in output order.
    fields = (
        ("model=%r", model_config.model),
        ("speculative_config=%r", self.speculative_config),
        ("tokenizer=%r", model_config.tokenizer),
        ("skip_tokenizer_init=%s", model_config.skip_tokenizer_init),
        ("tokenizer_mode=%s", model_config.tokenizer_mode),
        ("revision=%s", model_config.revision),
        ("override_neuron_config=%s", model_config.override_neuron_config),
        ("tokenizer_revision=%s", model_config.tokenizer_revision),
        ("trust_remote_code=%s", model_config.trust_remote_code),
        ("dtype=%s", model_config.dtype),
        ("max_seq_len=%d", model_config.max_model_len),
        ("download_dir=%r", load_config.download_dir),
        ("load_format=%s", load_config.load_format),
        ("tensor_parallel_size=%d", parallel_config.tensor_parallel_size),
        ("pipeline_parallel_size=%d",
         parallel_config.pipeline_parallel_size),
        ("disable_custom_all_reduce=%s",
         parallel_config.disable_custom_all_reduce),
        ("quantization=%s", model_config.quantization),
        ("enforce_eager=%s", model_config.enforce_eager),
        ("kv_cache_dtype=%s", cache_config.cache_dtype),
        ("quantization_param_path=%s", model_config.quantization_param_path),
        ("device_config=%s", self.device_config.device),
        ("decoding_config=%r", self.decoding_config),
        ("observability_config=%r", self.observability_config),
        ("seed=%d", model_config.seed),
        ("served_model_name=%s", model_config.served_model_name),
        ("num_scheduler_steps=%d",
         self.scheduler_config.num_scheduler_steps),
        ("enable_prefix_caching=%s", cache_config.enable_prefix_caching),
        ("use_async_output_proc=%s", model_config.use_async_output_proc),
        ("mm_processor_kwargs=%s", model_config.mm_processor_kwargs),
    )

    template = ", ".join(spec for spec, _ in fields)
    return template % tuple(value for _, value in fields)
\ No newline at end of file
......@@ -6,7 +6,6 @@ from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
import vllm.envs
from vllm import AsyncEngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
......@@ -113,17 +112,9 @@ class MQLLMEngine:
load_general_plugins()
engine_config = engine_args.create_engine_config()
if vllm.envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
engine_class = V1LLMEngine
else:
engine_class = LLMEngine
executor_class = engine_class._get_executor_cls(engine_config)
executor_class = LLMEngine._get_executor_cls(engine_config)
use_async_sockets = (engine_config.model_config.use_async_output_proc
and not vllm.envs.VLLM_USE_V1)
use_async_sockets = engine_config.model_config.use_async_output_proc
return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
......
from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
......@@ -67,9 +67,13 @@ class StopChecker:
return
# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
......@@ -85,33 +89,40 @@ class StopChecker:
return
@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> Optional[str]:
def check_stop_strings(
output_text: str,
new_char_count: int,
stop: List[str],
include_in_output: bool,
) -> Optional[Tuple[str, int]]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
Returns tuple (stop_string, offset) if matched or else None.
Where stop_string is the matched stop string and offset is the
length to which output_text should be truncated, or -1 for no
truncation.
"""
if not new_char_count or not sampling_params.stop:
if not new_char_count or not stop:
return None
for stop_str in sampling_params.stop:
for stop_str in stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = seq.output_text.find(
stop_str, -new_char_count - stop_string_len)
stop_index = output_text.find(stop_str,
-new_char_count - stop_string_len)
if stop_index == -1:
continue
if sampling_params.include_stop_str_in_output:
if include_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(seq.output_text):
if stop_index >= len(output_text):
# No truncation required.
return stop_str
return stop_str, -1
# Truncate the output text to either the beginning
# or end of the stop string.
seq.output_text = seq.output_text[:stop_index]
return stop_str
return stop_str, stop_index
return None
......@@ -210,8 +210,11 @@ class LLM:
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
# TODO(rob): enable mp by default (issue with fork vs spawn)
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
@staticmethod
......
......@@ -26,7 +26,6 @@ from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
......@@ -61,6 +60,11 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
if envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLMEngine # type: ignore
else:
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
TIMEOUT_KEEP_ALIVE = 5 # seconds
prometheus_multiproc_dir: tempfile.TemporaryDirectory
......@@ -126,7 +130,8 @@ async def build_async_engine_client_from_engine_args(
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
......@@ -143,6 +148,8 @@ async def build_async_engine_client_from_engine_args(
None, build_engine)
yield engine_client
if hasattr(engine_client, "shutdown"):
engine_client.shutdown()
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
......
......@@ -72,6 +72,7 @@ if TYPE_CHECKING:
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
def get_default_cache_root():
......@@ -473,6 +474,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
}
# end-env-vars-definition
......
......@@ -113,6 +113,36 @@ class RequestOutput:
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod
def new(
    cls,
    request_id: str,
    prompt: Optional[str],
    prompt_token_ids: Optional[List[int]],
    text: str,
    token_ids: List[int],
    finished: bool = False,
) -> "RequestOutput":
    """Build a RequestOutput holding a single completion.

    The completion gets index 0 and carries ``text`` and ``token_ids``;
    all logprob fields are left as None.
    """
    # TODO: Support `n` > 1.
    single_completion = CompletionOutput(
        index=0,
        text=text,
        token_ids=token_ids,
        cumulative_logprob=None,
        logprobs=None,  # TODO
    )
    request_output = RequestOutput(
        request_id=request_id,
        prompt=prompt,
        prompt_token_ids=prompt_token_ids,
        prompt_logprobs=None,  # TODO
        outputs=[single_completion],
        finished=finished,
    )
    return request_output
@classmethod
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
......
......@@ -70,7 +70,7 @@ class KVCacheManager:
Args:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request.
"""
......@@ -105,7 +105,7 @@ class KVCacheManager:
Args:
request: The request to append slots.
num_tokens: The number of tokens to append.
Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
......@@ -176,7 +176,7 @@ class KVCacheManager:
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: The blocks that have already been computed.
Returns:
A list of new allocated blocks.
"""
......@@ -240,7 +240,8 @@ class KVCacheManager:
Args:
request: The request to free the blocks.
"""
blocks = self.req_to_blocks.pop(request.request_id)
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
......@@ -259,13 +260,13 @@ class KVCacheManager:
"""Get new blocks from the free block pool, and add token IDs to
allocated blocks if caching is enabled.
Note that we do not check block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
token_ids: The token IDs in the blocks. None if caching is disabled.
parent_block: The parent block. Used to include block chain
in the block hash.
Returns:
A list of new block.
"""
......
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Deque, Dict, Iterable, List, Optional, Set, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
......@@ -237,13 +238,12 @@ class Scheduler:
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[Tuple[Request, int]]:
) -> List[EngineCoreOutput]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
# (request, num_sampled_tokens)
sampled: List[Tuple[Request, int]] = []
engine_core_outputs: List[EngineCoreOutput] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
......@@ -257,17 +257,29 @@ class Scheduler:
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.append_output_token_ids(token_id)
sampled.append((request, 1))
num_new_tokens = 1
# TODO: Update the KV cache manager for prefix caching.
# Check if the request is finished.
# Check for stop and update request state.
# This must be called before me make the EngineCoreOutput.
stopped = self._check_stop(request)
# Add EngineCoreOutput for this Request.
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=request.output_token_ids[-num_new_tokens:],
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)
# Breakout of the loop.
if stopped:
continue
new_running.append(request)
self.running = new_running
return sampled
return engine_core_outputs
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
......
import enum
from dataclasses import dataclass
from typing import List, Optional, Union
import msgspec
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
@dataclass
class DetokenizerRequest:
    """Per-request settings handed to the detokenizer."""

    # Identifier of the request.
    request_id: str
    # Prompt text (may be None) and the corresponding token ids.
    prompt: Optional[str]
    prompt_token_ids: List[int]
    # Detokenization options.
    # NOTE(review): presumably mirror the SamplingParams fields of the same
    # name -- confirm.
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    # Kind of RequestOutput to produce (see RequestOutputKind).
    output_kind: RequestOutputKind

    # Stop strings for this request, and whether a matched stop string is
    # included in the output.
    stop: List[str]
    include_stop_str_in_output: bool
class EngineCoreRequest(msgspec.Struct, omit_defaults=True):
    """msgspec struct describing one request to add to the engine core."""

    # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
    # but this object is currently not playing well with msgspec
    # due to circular imports and typing we have in data.py

    # Identifier of the request.
    request_id: str
    #NOTE(Nick): I don't think we need to pass prompt here since it should
    # always be tokenized?
    prompt: Optional[str]
    prompt_token_ids: List[int]
    sampling_params: SamplingParams
    # EOS token id for the request, or None.
    eos_token_id: Optional[int]
    # Arrival timestamp of the request.
    # NOTE(review): the visible callers pass time.time() -- confirm the unit.
    arrival_time: float
    # LoRA request attached to this request, or None.
    lora_request: Optional[LoRARequest]
class EngineCoreOutput(msgspec.Struct,
                       array_like=True,
                       omit_defaults=True,
                       gc=False):
    """msgspec struct carrying the new token ids and finish state of one
    request.

    With ``array_like=True`` msgspec encodes fields by position, so the
    field order below is part of the serialized layout.
    """

    # Identifier of the request this output belongs to.
    request_id: str
    # Token ids newly produced for the request.
    new_token_ids: List[int]
    # Whether the request has finished.
    finished: bool
    # Set when the request finished; None by default.
    finish_reason: Optional[str] = None
    # NOTE(review): presumably the stop token id or stop string that ended
    # the request -- confirm.
    stop_reason: Union[int, str, None] = None
class EngineCoreOutputs(msgspec.Struct,
                        array_like=True,
                        omit_defaults=True,
                        gc=False):
    """msgspec struct wrapping a list of EngineCoreOutput items."""

    #NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout and using an int enum for finish/stop reason

    # [num_reqs]
    outputs: List[EngineCoreOutput]
class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as hex byte strings, so it can be sent over sockets
    without separate encoding step.
    """
    # Each value is a single byte.
    ADD = b'\x00'
    ABORT = b'\x01'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment