Commit 2eedede8 (unverified), authored by Megha Agarwal, committed by GitHub
Browse files

[Core] Asynchronous Output Processor (#7049)


Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
parent 015e6cc2
...@@ -86,6 +86,7 @@ def run_vllm( ...@@ -86,6 +86,7 @@ def run_vllm(
use_v2_block_manager: bool = False, use_v2_block_manager: bool = False,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format, load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM( llm = LLM(
...@@ -110,6 +111,7 @@ def run_vllm( ...@@ -110,6 +111,7 @@ def run_vllm(
load_format=load_format, load_format=load_format,
num_scheduler_steps=num_scheduler_steps, num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager, use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
) )
# Add the requests to the engine. # Add the requests to the engine.
...@@ -237,7 +239,8 @@ def main(args: argparse.Namespace): ...@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
args.enable_prefix_caching, args.enable_chunked_prefill, args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend, args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps, args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format) args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
...@@ -418,6 +421,11 @@ if __name__ == "__main__": ...@@ -418,6 +421,11 @@ if __name__ == "__main__":
'section for more information.\n' 'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes ' '* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n') 'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model
......
...@@ -88,6 +88,9 @@ def test_models( ...@@ -88,6 +88,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot # NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test. # reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor, so async output processing is disabled here.
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_models_with_fp8_kv_cache( def test_models_with_fp8_kv_cache(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
...@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache( ...@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size: int, chunked_prefill_token_size: int,
enforce_eager: bool, enforce_eager: bool,
tensor_parallel_size: int, tensor_parallel_size: int,
disable_async_output_proc: bool,
) -> None: ) -> None:
""" """
Only checks log probs match between chunked-prefill and Only checks log probs match between chunked-prefill and
...@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache( ...@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs, **extra_kwargs,
) as vllm_model: ) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
...@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache( ...@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
**extra_kwargs, **extra_kwargs,
) as vllm_model: ) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
......
...@@ -209,7 +209,6 @@ def test_swap_infeasible( ...@@ -209,7 +209,6 @@ def test_swap_infeasible(
prefill_blocks = 2 prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1] example_prompts = example_prompts[:1]
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
......
...@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int): ...@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
def schedule_and_update_computed_tokens(scheduler): def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule() metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas): for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size) s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out return metas, out
...@@ -180,7 +180,7 @@ def test_maximal_decoding(): ...@@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized.""" """Verify decoding requests are prioritized."""
block_size = 4 block_size = 4
max_seqs = 2 max_seqs = 2
max_model_len = 2 max_model_len = 8
max_num_batched_tokens = 2 max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs, max_seqs,
......
...@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int): ...@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler): def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule() metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas): for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size) s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out return metas, out
......
...@@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams ...@@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
MODEL = "meta-llama/llama-2-7b-hf" MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200 MAX_TOKENS = 200
IS_ASYNC = False
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def vllm_model(vllm_runner): def vllm_model(vllm_runner):
...@@ -14,99 +16,148 @@ def vllm_model(vllm_runner): ...@@ -14,99 +16,148 @@ def vllm_model(vllm_runner):
yield vllm_model yield vllm_model
@pytest.mark.skip_global_cleanup def _test_stopping(llm_engine: LLMEngine,
def test_stop_basic(vllm_model): expected_output: str,
_test_stopping(vllm_model.model.llm_engine, expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False,
use_async_output_proc: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
if use_async_output_proc:
llm_engine.step()
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
def _set_async_mode(llm_engine, is_async):
    """Set ``use_async_output_proc`` on the engine's first scheduler.

    Only ``llm_engine.scheduler[0]`` is touched; any other schedulers the
    engine holds are left as they are.
    """
    first_scheduler = llm_engine.scheduler[0]
    first_scheduler.use_async_output_proc = is_async
def _stop_basic(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["."], stop=["."],
include_in_output=False, include_in_output=False,
expected_output="VLLM is a 100% volunteer organization", expected_output="VLLM is a 100% volunteer organization",
expected_reason=".") expected_reason=".",
use_async_output_proc=is_async)
_test_stopping(vllm_model.model.llm_engine, _test_stopping(llm_engine,
stop=["."], stop=["."],
include_in_output=True, include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.", expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".") expected_reason=".",
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup def _stop_multi_tokens(llm_engine, is_async):
def test_stop_multi_tokens(vllm_model):
_test_stopping( _test_stopping(
vllm_model.model.llm_engine, llm_engine,
stop=["group of peo", "short"], stop=["group of peo", "short"],
include_in_output=False, include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ", expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo") expected_reason="group of peo",
use_async_output_proc=is_async)
_test_stopping( _test_stopping(
vllm_model.model.llm_engine, llm_engine,
stop=["group of peo", "short"], stop=["group of peo", "short"],
include_in_output=True, include_in_output=True,
expected_output= expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo", "VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo") expected_reason="group of peo",
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup def _stop_partial_token(llm_engine, is_async):
def test_stop_partial_token(vllm_model): _test_stopping(llm_engine,
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"], stop=["gani"],
include_in_output=False, include_in_output=False,
expected_output="VLLM is a 100% volunteer or", expected_output="VLLM is a 100% volunteer or",
expected_reason="gani") expected_reason="gani",
use_async_output_proc=is_async)
_test_stopping(vllm_model.model.llm_engine, _test_stopping(llm_engine,
stop=["gani"], stop=["gani"],
include_in_output=True, include_in_output=True,
expected_output="VLLM is a 100% volunteer organi", expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani") expected_reason="gani",
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup def _stop_token_id(llm_engine, is_async):
def test_stop_token_id(vllm_model):
# token id 13013 => " organization" # token id 13013 => " organization"
_test_stopping(vllm_model.model.llm_engine, _test_stopping(llm_engine,
stop_token_ids=[13013], stop_token_ids=[13013],
include_in_output=False, include_in_output=False,
expected_output="VLLM is a 100% volunteer", expected_output="VLLM is a 100% volunteer",
expected_reason=13013) expected_reason=13013,
use_async_output_proc=is_async)
_test_stopping(vllm_model.model.llm_engine, _test_stopping(llm_engine,
stop_token_ids=[13013], stop_token_ids=[13013],
include_in_output=True, include_in_output=True,
expected_output="VLLM is a 100% volunteer organization", expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013) expected_reason=13013,
use_async_output_proc=is_async)
def _test_stopping(llm_engine: LLMEngine, @pytest.mark.skip_global_cleanup
expected_output: str, def test_stop_basic(vllm_model):
expected_reason: Any, _set_async_mode(vllm_model.model.llm_engine, True)
stop: Optional[List[str]] = None, _stop_basic(vllm_model.model.llm_engine, is_async=True)
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
output: Optional[CompletionOutput] = None _set_async_mode(vllm_model.model.llm_engine, False)
output_text = "" _stop_basic(vllm_model.model.llm_engine, is_async=False)
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None @pytest.mark.skip_global_cleanup
assert output_text == expected_output def test_stop_multi_tokens(vllm_model):
assert stop_reason == expected_reason _set_async_mode(vllm_model.model.llm_engine, True)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_token_id(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
...@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ...@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \ ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"] ["--num-scheduler-steps", f"{num_scheduler_steps}"]
# Disable output proc callback as it's not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if eager_mode: if eager_mode:
ms_server_args.append("--enforce-eager") ms_server_args.append("--enforce-eager")
......
...@@ -140,6 +140,7 @@ class ModelConfig: ...@@ -140,6 +140,7 @@ class ModelConfig:
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None, served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -172,6 +173,7 @@ class ModelConfig: ...@@ -172,6 +173,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision) self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Choose a default enforce_eager value if the user did not specify # Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None) # a value (enforce_eager is None)
...@@ -326,6 +328,49 @@ class ModelConfig: ...@@ -326,6 +328,49 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
def verify_async_output_proc(self, parallel_config, speculative_config,
                             device_config) -> None:
    """Turn off async output processing when the configuration rules it out.

    Does nothing if ``self.use_async_output_proc`` is already false.
    Otherwise the flag is cleared (with a warning) for the first of these
    conditions that holds, checked in this order: pipeline parallel size
    above 1, a non-CUDA device type, the ray SPMD worker env flag, or
    ``enforce_eager``.  If none of them holds, the flag is cleared silently
    in embedding mode, and cleared with a warning when a speculative config
    is present.
    """
    if not self.use_async_output_proc:
        # Nothing to check
        return

    # The first matching blocker decides the warning; later ones (including
    # the env lookup) are not evaluated once an earlier one has matched.
    blocker_msg = None
    if parallel_config.pipeline_parallel_size > 1:
        blocker_msg = ("Async output processing can not be enabled "
                       "with pipeline parallel")
    elif device_config.device_type != "cuda":
        blocker_msg = ("Async output processing is only supported for CUDA."
                       " Disabling it for other platforms.")
    elif envs.VLLM_USE_RAY_SPMD_WORKER:
        blocker_msg = (
            "Async output processing can not be enabled with ray spmd")
    elif self.enforce_eager:
        blocker_msg = (
            "To see benefits of async output processing, enable CUDA "
            "graph. Since, enforce-eager is enabled, async output "
            "processor cannot be used")

    if blocker_msg is not None:
        logger.warning(blocker_msg)
        self.use_async_output_proc = False
        return

    # Async postprocessor is not necessary with embedding mode
    # since there is no token generation
    if self.embedding_mode:
        self.use_async_output_proc = False

    # Checked even after the embedding-mode branch, so the warning is still
    # emitted when both apply.
    if speculative_config:
        logger.warning("Async output processing is not supported with"
                       " speculative decoding currently.")
        self.use_async_output_proc = False
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
...@@ -358,6 +403,11 @@ class ModelConfig: ...@@ -358,6 +403,11 @@ class ModelConfig:
"fallback to the eager mode.") "fallback to the eager mode.")
self.enforce_eager = True self.enforce_eager = True
if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
def get_hf_config_sliding_window(self) -> Optional[int]: def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.""" """Get the sliding window size, or None if disabled."""
...@@ -1769,6 +1819,9 @@ class EngineConfig: ...@@ -1769,6 +1819,9 @@ class EngineConfig:
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
""" """
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
......
...@@ -4,7 +4,8 @@ import random ...@@ -4,7 +4,8 @@ import random
import time import time
from collections import deque from collections import deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
...@@ -299,6 +300,7 @@ class Scheduler: ...@@ -299,6 +300,7 @@ class Scheduler:
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1, pipeline_parallel_size: int = 1,
output_proc_callback_fn: Optional[Callable] = None,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -364,10 +366,36 @@ class Scheduler: ...@@ -364,10 +366,36 @@ class Scheduler:
self.num_cumulative_preemption: int = 0 self.num_cumulative_preemption: int = 0
# Used to cache python objects # Used to cache python objects
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( self._seq_group_metadata_cache: List[PyObjectCache] = []
scheduler_running_outputs_builder) self._scheduler_running_outputs_cache: List[PyObjectCache] = []
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( self._scheduled_seq_group_cache: List[PyObjectCache] = []
scheduled_seq_group_builder)
# For async output processing, the cache buffers are swapped between
# iterations: because output processing lags one step behind, cached
# objects cannot be reused on the very next schedule() call, only on
# the one after it.
self.output_proc_callback_fn = output_proc_callback_fn
self.use_async_output_proc = self.output_proc_callback_fn is not None
self.num_cache_iters = 2 if self.use_async_output_proc else 1
self.cache_id = 0
for i in range(self.num_cache_iters):
self._seq_group_metadata_cache.append(
PyObjectCache(seq_group_metadata_builder))
self._scheduler_running_outputs_cache.append(
PyObjectCache(scheduler_running_outputs_builder))
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))
# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []
@property
def next_cache_id(self):
    """Index of the cache buffer after the current one, wrapping around."""
    successor = self.cache_id + 1
    return successor % self.num_cache_iters
@property @property
def lora_enabled(self) -> bool: def lora_enabled(self) -> bool:
...@@ -483,7 +511,7 @@ class Scheduler: ...@@ -483,7 +511,7 @@ class Scheduler:
SchedulerRunningOutputs. SchedulerRunningOutputs.
""" """
ret: SchedulerRunningOutputs = \ ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache.get_object() self._scheduler_running_outputs_cache[self.cache_id].get_object()
ret.blocks_to_swap_out.clear() ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear() ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear() ret.decode_seq_groups.clear()
...@@ -510,8 +538,12 @@ class Scheduler: ...@@ -510,8 +538,12 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot # NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state. # to keep all the sequence groups in the RUNNING state.
running_queue = self.running # Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue: while running_queue:
seq_group = running_queue[0] seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens( num_running_tokens = self._get_num_new_tokens(
...@@ -521,6 +553,28 @@ class Scheduler: ...@@ -521,6 +553,28 @@ class Scheduler:
break break
running_queue.popleft() running_queue.popleft()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len(
) > self.scheduler_config.max_model_len:
self._async_stopped.append(seq_group)
continue
# With async postprocessor, when preemption kicks in, we first need
# to drain the async postprocessor, so that all async block_table
# freeing is applied before the preemption freeing is applied.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback_fn is not None
self.output_proc_callback_fn(is_async=True)
self.running = tmp
while not self._can_append_slots(seq_group): while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id, budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens) num_running_tokens)
...@@ -556,7 +610,7 @@ class Scheduler: ...@@ -556,7 +610,7 @@ class Scheduler:
is_prefill = seq_group.is_prefill() is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \ scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache.get_object() self._scheduled_seq_group_cache[self.cache_id].get_object()
scheduled_seq_group.seq_group = seq_group scheduled_seq_group.seq_group = seq_group
if is_prefill: if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens scheduled_seq_group.token_chunk_size = num_running_tokens
...@@ -579,8 +633,8 @@ class Scheduler: ...@@ -579,8 +633,8 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id) curr_loras.add(seq_group.lora_int_id)
self._scheduler_running_outputs_cache.reset() self._scheduler_running_outputs_cache[self.next_cache_id].reset()
self._scheduled_seq_group_cache.reset() self._scheduled_seq_group_cache[self.next_cache_id].reset()
return ret return ret
...@@ -1031,17 +1085,31 @@ class Scheduler: ...@@ -1031,17 +1085,31 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
) )
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
no_beam_search = (seq_group.sampling_params.best_of == 1
and not seq_group.sampling_params.use_beam_search)
return no_beam_search
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups. # Schedule sequence groups.
# This function call changes the internal states of the scheduler # This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting. # such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter() scheduler_start_time = time.perf_counter()
scheduler_outputs = self._schedule() scheduler_outputs = self._schedule()
now = time.time() now = time.time()
if not self.cache_config.enable_prefix_caching: if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = [] common_computed_block_nums = []
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate( for i, scheduled_seq_group in enumerate(
...@@ -1050,6 +1118,11 @@ class Scheduler: ...@@ -1050,6 +1118,11 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now) seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache[
self.cache_id].get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData # seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {} seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers # seq_id -> physical block numbers
...@@ -1139,6 +1212,10 @@ class Scheduler: ...@@ -1139,6 +1212,10 @@ class Scheduler:
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
if allow_async_output_proc:
allow_async_output_proc = self._allow_async_output_proc(
seq_group)
# Now that the batch has been created, we can assume all blocks in the # Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation. # batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution # This is because the engine assumes that a failure in model execution
...@@ -1147,6 +1224,8 @@ class Scheduler: ...@@ -1147,6 +1224,8 @@ class Scheduler:
self.block_manager.mark_blocks_as_computed( self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group) scheduled_seq_group.seq_group)
self._seq_group_metadata_cache[self.next_cache_id].reset()
scheduler_time = time.perf_counter() - scheduler_start_time scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently # Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant # running. This will help estimate if the scheduler is a significant
...@@ -1158,7 +1237,12 @@ class Scheduler: ...@@ -1158,7 +1237,12 @@ class Scheduler:
else: else:
seq_group.metrics.scheduler_time = scheduler_time seq_group.metrics.scheduler_time = scheduler_time
return seq_group_metadata_list, scheduler_outputs # Move to next cache (if exists)
self.cache_id = self.next_cache_id
# Return results
return (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq) self.block_manager.fork(parent_seq, child_seq)
...@@ -1167,6 +1251,12 @@ class Scheduler: ...@@ -1167,6 +1251,12 @@ class Scheduler:
"""Free a sequence from a block table.""" """Free a sequence from a block table."""
self.block_manager.free(seq) self.block_manager.free(seq)
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
    """Call ``free_seq`` on each sequence of the group that is finished."""
    # Lazy filter: each sequence is checked and freed before the next one
    # is inspected.
    finished_seqs = (candidate for candidate in seq_group.get_seqs()
                     if candidate.is_finished())
    for finished in finished_seqs:
        self.free_seq(finished)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque() remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running: for seq_group in self.running:
...@@ -1179,8 +1269,24 @@ class Scheduler: ...@@ -1179,8 +1269,24 @@ class Scheduler:
self._finished_requests_ids.append(seq_group.request_id) self._finished_requests_ids.append(seq_group.request_id)
else: else:
remaining.append(seq_group) remaining.append(seq_group)
# Free finished seqs
self._free_finished_seqs(seq_group)
self.running = remaining self.running = remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if self._async_stopped:
for seq_group in self._async_stopped:
self._free_seq_group_cross_attn_blocks(seq_group)
self._finished_requests_ids.append(seq_group.request_id)
# Free finished seqs
self._free_finished_seqs(seq_group)
self._async_stopped.clear()
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
......
...@@ -147,6 +147,7 @@ class EngineArgs: ...@@ -147,6 +147,7 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -733,6 +734,12 @@ class EngineArgs: ...@@ -733,6 +734,12 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking " "modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.") "operations and hence might have a performance impact.")
parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
return parser return parser
@classmethod @classmethod
...@@ -792,6 +799,7 @@ class EngineArgs: ...@@ -792,6 +799,7 @@ class EngineArgs:
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
) )
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else
......
...@@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
cached_outputs = self.cached_scheduler_outputs[virtual_engine] cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
# skip the scheduler if there are any remaining steps in the seq groups. # skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current # This ensures that the scheduler is only called again when the current
# batch has completed. # batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list): if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[ (seq_group_metadata_list, scheduler_outputs,
virtual_engine].schedule() allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
if (self.scheduler_config.is_multi_step if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0): and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have # cache the scheduler outputs for the next iteration if we have
# lookahead slots # lookahead slots
self._cache_scheduler_outputs_for_multi_step( self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs) virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert scheduler_outputs is not None assert scheduler_outputs is not None
assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[ finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids() virtual_engine].get_and_reset_finished_requests_ids()
...@@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
# We use ExecuteModelRequest to pass the last sampled_token_ids # We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input. # to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids) last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
# Execute the model. # Execute the model.
output = await self.model_executor.execute_model_async( output = await self.model_executor.execute_model_async(
execute_model_req) execute_model_req)
...@@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output) self._update_cached_scheduler_output(virtual_engine, output)
else: else:
if len(self.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
output = [] output = []
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
...@@ -337,11 +358,21 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -337,11 +358,21 @@ class _AsyncLLMEngine(LLMEngine):
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[ self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState() virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups, # Cache results in engine
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) self.output_queue.append(
else: (output, seq_group_metadata_list, scheduler_outputs))
request_outputs = []
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
...@@ -349,7 +380,10 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -349,7 +380,10 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs)
return request_outputs else:
self.request_outputs = []
return self.request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None: async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
......
This diff is collapsed.
...@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
# Importing here to avoid cycle. # Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import ( from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor) SingleStepOutputProcessor)
return SingleStepOutputProcessor( return SingleStepOutputProcessor(scheduler_config, detokenizer,
scheduler_config, scheduler, seq_counter,
detokenizer, stop_checker)
scheduler,
seq_counter,
stop_checker,
)
else: else:
# Importing here to avoid cycle. # Importing here to avoid cycle.
from vllm.engine.output_processor.multi_step import ( from vllm.engine.output_processor.multi_step import (
...@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
@abstractmethod @abstractmethod
def process_outputs(self, sequence_group: SequenceGroup, def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None: outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Process new token ids for the sequence group. Handles logic such as """Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the detokenization, stop checking, and freeing/forking sequences in the
scheduler. scheduler.
......
...@@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Prompt logprob is not supported by multi step workers. " "Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).") "(e.g., speculative decode uses multi step workers).")
def process_outputs(self, sequence_group: SequenceGroup, def process_outputs(self,
outputs: List[SequenceGroupOutput]) -> None: sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
"""Append new tokens in the outputs to sequences in the sequence group. """Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than This only supports sequence groups of size 1. It supports greater than
one new token per sequence. one new token per sequence.
This applies logic like stop condition checking and detokenization, This applies logic like stop condition checking and detokenization.
including freeing finished sequences. It also handles cases where there It also handles cases where there are tokens emitted after
are tokens emitted after the EOS token. the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
""" """
# TODO: Add support for async if necessary
assert not is_async
seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
assert seqs, "expected running sequences" assert seqs, "expected running sequences"
...@@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
) )
if seq.is_finished(): if seq.is_finished():
break break
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
...@@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time. that is currently difficult to schedule multiple steps ahead of time.
""" """
def __init__( def __init__(self, scheduler_config: SchedulerConfig,
self, detokenizer: Detokenizer, scheduler: List[Scheduler],
scheduler_config: SchedulerConfig, seq_counter: Counter, stop_checker: StopChecker):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.detokenizer = detokenizer self.detokenizer = detokenizer
self.scheduler = scheduler self.scheduler = scheduler
...@@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self.stop_checker = stop_checker self.stop_checker = stop_checker
def process_outputs(self, sequence_group: SequenceGroup, def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None: outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
"""Append all new tokens to sequences in the sequence group. Fork any """Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones. surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions. as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
""" """
assert (len(outputs) == 1 assert (len(outputs) == 1
), f"{type(self)} does not support multiple outputs per step" ), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0]) return self._process_sequence_group_outputs(sequence_group, outputs[0],
is_async)
def process_prompt_logprob(self, seq_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None: outputs: List[SequenceGroupOutput]) -> None:
...@@ -80,13 +83,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -80,13 +83,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq_group.prompt_logprobs.extend(prompt_logprobs) seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None: outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
if sampling_params.n == 1 and not sampling_params.use_beam_search: if sampling_params.n == 1 and not sampling_params.use_beam_search:
# only have one output sample # only have one output sample
sample = outputs.samples[0] sample = outputs.samples[0]
# only have one sequence # only have one sequence
seq = seq_group.seqs[0] seq = seq_group.seqs[0]
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer: if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace( new_char_count = self.detokenizer.decode_sequence_inplace(
...@@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler.free_seq(seq) scheduler.free_seq(seq)
return return
# TODO: Add support for async for beam search
assert not is_async
# Process samples # Process samples
samples = outputs.samples samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
......
...@@ -129,6 +129,7 @@ class LLM: ...@@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
''' '''
...@@ -170,6 +171,7 @@ class LLM: ...@@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture=max_context_len_to_capture, max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args( self.llm_engine = LLMEngine.from_engine_args(
...@@ -603,7 +605,6 @@ class LLM: ...@@ -603,7 +605,6 @@ class LLM:
inputs = [inputs] inputs = [inputs]
num_requests = len(inputs) num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests: if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params " raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
...@@ -678,6 +679,10 @@ class LLM: ...@@ -678,6 +679,10 @@ class LLM:
postfix=(f"est. speed input: {0:.2f} toks/s, " postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"), f"output: {0:.2f} toks/s"),
) )
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0 total_in_toks = 0
...@@ -700,6 +705,10 @@ class LLM: ...@@ -700,6 +705,10 @@ class LLM:
f"est. speed input: {in_spd:.2f} toks/s, " f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s") f"output: {out_spd:.2f} toks/s")
pbar.update(1) pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID. # Sort the outputs by request ID.
......
...@@ -65,7 +65,8 @@ class DistributedGPUExecutor(GPUExecutor): ...@@ -65,7 +65,8 @@ class DistributedGPUExecutor(GPUExecutor):
def execute_model( def execute_model(
self, self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers( self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop", "start_worker_execution_loop",
...@@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): ...@@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@abstractmethod @abstractmethod
async def _driver_execute_model_async( async def _driver_execute_model_async(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
"""Execute the model asynchronously in the driver worker. """Execute the model asynchronously in the driver worker.
......
...@@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): ...@@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> List[Union[SamplerOutput, PoolerOutput]]: ) -> List[Union[SamplerOutput, PoolerOutput]]:
output = await make_async(self.driver_worker.execute_model output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, ) )(execute_model_req=execute_model_req)
return output return output
...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod ...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Tuple, Union, cast) Optional, Set, Tuple, Union, cast)
import msgspec import msgspec
import torch import torch
...@@ -474,11 +474,8 @@ class Sequence: ...@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute() self.data.reset_state_for_recompute()
def append_token_id( def append_token_id(self, token_id: int, logprobs: Dict[int,
self, Logprob]) -> None:
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
assert token_id in logprobs assert token_id in logprobs
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
...@@ -1293,6 +1290,8 @@ class ExecuteModelRequest( ...@@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list) finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding. # The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None last_sampled_token_ids: Optional[torch.Tensor] = None
# Async postprocessor
output_proc_callback_fn: Optional[Callable] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
...@@ -1338,4 +1337,5 @@ class ExecuteModelRequest( ...@@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps, num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids, finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone() last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None) if self.last_sampled_token_ids is not None else None,
output_proc_callback_fn=self.output_proc_callback_fn)
...@@ -6,8 +6,8 @@ import time ...@@ -6,8 +6,8 @@ import time
import warnings import warnings
import weakref import weakref
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
TypeVar, Union) Tuple, Type, TypeVar, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0 virtual_engine: int = 0
output_proc_callback_fn: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
...@@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not self.is_driver_worker: if not self.is_driver_worker:
return [] return []
if model_input.output_proc_callback_fn is not None:
model_input.output_proc_callback_fn(is_async=True)
# Sample the next token. # Sample the next token.
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.model.sample(
logits=logits, logits=logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment