Unverified Commit 2eedede8 authored by Megha Agarwal's avatar Megha Agarwal Committed by GitHub
Browse files

[Core] Asynchronous Output Processor (#7049)


Co-authored-by: default avatarAlexander Matveev <alexm@neuralmagic.com>
parent 015e6cc2
......@@ -86,6 +86,7 @@ def run_vllm(
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
......@@ -110,6 +111,7 @@ def run_vllm(
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
)
# Add the requests to the engine.
......@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format)
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
......@@ -418,6 +421,11 @@ if __name__ == "__main__":
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
......
......@@ -88,6 +88,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_models_with_fp8_kv_cache(
vllm_runner,
example_prompts,
......@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
) -> None:
"""
Only checks log probs match between chunked-prefill and
......@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
......@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
......
......@@ -209,7 +209,6 @@ def test_swap_infeasible(
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]
with vllm_runner(
model,
dtype=dtype,
......
......@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
......@@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
block_size = 4
max_seqs = 2
max_model_len = 2
max_model_len = 8
max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
......
......@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
......
......@@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200
IS_ASYNC = False
@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
......@@ -14,99 +16,148 @@ def vllm_model(vllm_runner):
yield vllm_model
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False,
use_async_output_proc: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
if use_async_output_proc:
llm_engine.step()
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
def _set_async_mode(llm_engine, is_async):
llm_engine.scheduler[0].use_async_output_proc = is_async
def _stop_basic(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
expected_reason=".",
use_async_output_proc=is_async)
_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
expected_reason=".",
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
def _stop_multi_tokens(llm_engine, is_async):
_test_stopping(
vllm_model.model.llm_engine,
llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
expected_reason="group of peo",
use_async_output_proc=is_async)
_test_stopping(
vllm_model.model.llm_engine,
llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
expected_reason="group of peo",
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
def _stop_partial_token(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
expected_reason="gani",
use_async_output_proc=is_async)
_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
expected_reason="gani",
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
def _stop_token_id(llm_engine, is_async):
# token id 13013 => " organization"
_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
expected_reason=13013,
use_async_output_proc=is_async)
_test_stopping(vllm_model.model.llm_engine,
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
expected_reason=13013,
use_async_output_proc=is_async)
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_basic(vllm_model.model.llm_engine, is_async=True)
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_basic(vllm_model.model.llm_engine, is_async=False)
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_token_id(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
......@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]
# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if eager_mode:
ms_server_args.append("--enforce-eager")
......
......@@ -140,6 +140,7 @@ class ModelConfig:
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
) -> None:
self.model = model
self.tokenizer = tokenizer
......@@ -172,6 +173,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
......@@ -326,6 +328,49 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
# Nothing to check
return
if parallel_config.pipeline_parallel_size > 1:
logger.warning("Async output processing can not be enabled "
"with pipeline parallel")
self.use_async_output_proc = False
return
if device_config.device_type != "cuda":
logger.warning(
"Async output processing is only supported for CUDA."
" Disabling it for other platforms.")
self.use_async_output_proc = False
return
if envs.VLLM_USE_RAY_SPMD_WORKER:
logger.warning(
"Async output processing can not be enabled with ray spmd")
self.use_async_output_proc = False
return
if self.enforce_eager:
logger.warning(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
self.use_async_output_proc = not self.enforce_eager
return
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.embedding_mode:
self.use_async_output_proc = False
if speculative_config:
logger.warning("Async output processing is not supported with"
" speculative decoding currently.")
self.use_async_output_proc = False
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
......@@ -358,6 +403,11 @@ class ModelConfig:
"fallback to the eager mode.")
self.enforce_eager = True
if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""
......@@ -1769,6 +1819,9 @@ class EngineConfig:
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
......
......@@ -4,7 +4,8 @@ import random
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
......@@ -299,6 +300,7 @@ class Scheduler:
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback_fn: Optional[Callable] = None,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
......@@ -364,10 +366,36 @@ class Scheduler:
self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
scheduled_seq_group_builder)
self._seq_group_metadata_cache: List[PyObjectCache] = []
self._scheduler_running_outputs_cache: List[PyObjectCache] = []
self._scheduled_seq_group_cache: List[PyObjectCache] = []
# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self.output_proc_callback_fn = output_proc_callback_fn
self.use_async_output_proc = self.output_proc_callback_fn is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1
self.cache_id = 0
for i in range(self.num_cache_iters):
self._seq_group_metadata_cache.append(
PyObjectCache(seq_group_metadata_builder))
self._scheduler_running_outputs_cache.append(
PyObjectCache(scheduler_running_outputs_builder))
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))
# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
@property
def lora_enabled(self) -> bool:
......@@ -483,7 +511,7 @@ class Scheduler:
SchedulerRunningOutputs.
"""
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache.get_object()
self._scheduler_running_outputs_cache[self.cache_id].get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
......@@ -510,8 +538,12 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
running_queue = self.running
# Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
......@@ -521,6 +553,28 @@ class Scheduler:
break
running_queue.popleft()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len(
) > self.scheduler_config.max_model_len:
self._async_stopped.append(seq_group)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback_fn is not None
self.output_proc_callback_fn(is_async=True)
self.running = tmp
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
......@@ -556,7 +610,7 @@ class Scheduler:
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache.get_object()
self._scheduled_seq_group_cache[self.cache_id].get_object()
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
......@@ -579,8 +633,8 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
self._scheduler_running_outputs_cache.reset()
self._scheduled_seq_group_cache.reset()
self._scheduler_running_outputs_cache[self.next_cache_id].reset()
self._scheduled_seq_group_cache[self.next_cache_id].reset()
return ret
......@@ -1031,17 +1085,31 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
no_beam_search = (seq_group.sampling_params.best_of == 1
and not seq_group.sampling_params.use_beam_search)
return no_beam_search
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter()
scheduler_outputs = self._schedule()
now = time.time()
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate(
......@@ -1050,6 +1118,11 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache[
self.cache_id].get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
......@@ -1139,6 +1212,10 @@ class Scheduler:
)
seq_group_metadata_list.append(seq_group_metadata)
if allow_async_output_proc:
allow_async_output_proc = self._allow_async_output_proc(
seq_group)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
......@@ -1147,6 +1224,8 @@ class Scheduler:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
self._seq_group_metadata_cache[self.next_cache_id].reset()
scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
......@@ -1158,7 +1237,12 @@ class Scheduler:
else:
seq_group.metrics.scheduler_time = scheduler_time
return seq_group_metadata_list, scheduler_outputs
# Move to next cache (if exists)
self.cache_id = self.next_cache_id
# Return results
return (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)
......@@ -1167,6 +1251,12 @@ class Scheduler:
"""Free a sequence from a block table."""
self.block_manager.free(seq)
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
"""Free finished seqs in a sequence group."""
for seq in seq_group.get_seqs():
if seq.is_finished():
self.free_seq(seq)
def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
......@@ -1179,8 +1269,24 @@ class Scheduler:
self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
# Free finished seqs
self._free_finished_seqs(seq_group)
self.running = remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if self._async_stopped:
for seq_group in self._async_stopped:
self._free_seq_group_cross_attn_blocks(seq_group)
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
self._async_stopped.clear()
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
......
......@@ -147,6 +147,7 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
def __post_init__(self):
if self.tokenizer is None:
......@@ -733,6 +734,12 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.")
parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
return parser
@classmethod
......@@ -792,6 +799,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
)
cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else
......
......@@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)
if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
......@@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
......@@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(self.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
output = []
# Finish the current step for all the sequence groups.
......@@ -337,19 +358,32 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
else:
request_outputs = []
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Cache results in engine
self.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
# Tracing
self.do_tracing(scheduler_outputs)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
# Log stats.
self.do_log_stats(scheduler_outputs, output)
# Tracing
self.do_tracing(scheduler_outputs)
else:
self.request_outputs = []
return request_outputs
return self.request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
......
This diff is collapsed.
......@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
return SingleStepOutputProcessor(
scheduler_config,
detokenizer,
scheduler,
seq_counter,
stop_checker,
)
return SingleStepOutputProcessor(scheduler_config, detokenizer,
scheduler, seq_counter,
stop_checker)
else:
# Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import (
......@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
......
......@@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
def process_outputs(self,
sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
This applies logic like stop condition checking and detokenization,
including freeing finished sequences. It also handles cases where there
are tokens emitted after the EOS token.
This applies logic like stop condition checking and detokenization.
It also handles cases where there are tokens emitted after
the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
# TODO: Add support for async if necessary
assert not is_async
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
assert seqs, "expected running sequences"
......@@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
)
if seq.is_finished():
break
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
......@@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time.
"""
def __init__(
self,
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
stop_checker: StopChecker,
):
def __init__(self, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, scheduler: List[Scheduler],
seq_counter: Counter, stop_checker: StopChecker):
self.scheduler_config = scheduler_config
self.detokenizer = detokenizer
self.scheduler = scheduler
......@@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
return self._process_sequence_group_outputs(sequence_group, outputs[0],
is_async)
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
......@@ -80,14 +83,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.n == 1 and not sampling_params.use_beam_search:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
......@@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler.free_seq(seq)
return
# TODO: Add support for async for beam search
assert not is_async
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
......
......@@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
**kwargs,
) -> None:
'''
......@@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
......@@ -603,7 +605,6 @@ class LLM:
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
......@@ -678,6 +679,10 @@ class LLM:
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
......@@ -700,6 +705,10 @@ class LLM:
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
......
......@@ -64,8 +64,9 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
......@@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@abstractmethod
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
......
......@@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
)(execute_model_req=execute_model_req)
return output
......@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
import msgspec
import torch
......@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
......@@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async postprocessor
output_proc_callback_fn: Optional[Callable] = None
@property
def is_first_multi_step(self) -> bool:
......@@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)
if self.last_sampled_token_ids is not None else None,
output_proc_callback_fn=self.output_proc_callback_fn)
......@@ -6,8 +6,8 @@ import time
import warnings
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Tuple, Type, TypeVar, Union)
import numpy as np
import torch
......@@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
output_proc_callback_fn: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
......@@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
......@@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not self.is_driver_worker:
return []
if model_input.output_proc_callback_fn is not None:
model_input.output_proc_callback_fn(is_async=True)
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment