Unverified Commit 3bcb8c75 authored by Joe Runde's avatar Joe Runde Committed by GitHub
Browse files

[Core] Reduce TTFT with concurrent partial prefills (#10235)


Signed-off-by: default avatarJoe Runde <Joseph.Runde@ibm.com>
Signed-off-by: default avatarPrashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: default avatarPrashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: default avatarCody Yu <hao.yu.cody@gmail.com>
parent 5e5c8e09
...@@ -8,7 +8,6 @@ prefill requests are chunked. ...@@ -8,7 +8,6 @@ prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`. Run `pytest tests/models/test_chunked_prefill.py`.
""" """
import os import os
from contextlib import nullcontext
import pytest import pytest
...@@ -233,7 +232,6 @@ def test_with_prefix_caching( ...@@ -233,7 +232,6 @@ def test_with_prefix_caching(
max_num_batched_tokens = max_num_seqs = chunk_size max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore outputs = {} # type: ignore
check_result = True
for enable in (True, False): for enable in (True, False):
with vllm_runner( with vllm_runner(
model, model,
...@@ -245,19 +243,11 @@ def test_with_prefix_caching( ...@@ -245,19 +243,11 @@ def test_with_prefix_caching(
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
) as vllm_model: ) as vllm_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = [] outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts: for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt], outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens) max_tokens)
# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal( check_outputs_equal(
outputs_0_lst=outputs[False], outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True], outputs_1_lst=outputs[True],
......
...@@ -7,6 +7,9 @@ import pytest # noqa ...@@ -7,6 +7,9 @@ import pytest # noqa
from vllm.config import CacheConfig, SchedulerConfig from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SequenceGroup from vllm.sequence import Logprob, SequenceGroup
from .utils import create_dummy_prompt from .utils import create_dummy_prompt
...@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output): ...@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups] return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(seq_group, token_id: int): def append_new_token(seq_group: SequenceGroup, token_id: int):
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)}) seq.append_token_id(token_id, {token_id: Logprob(token_id)})
...@@ -123,6 +126,232 @@ def test_chunk(): ...@@ -123,6 +126,232 @@ def test_chunk():
assert out.num_batched_tokens == 57 assert out.num_batched_tokens == 57
def test_concurrent_chunking():
"""Verify prefills are chunked properly when
--max-num-partial-prefills is > 1"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
# Verify both requests are chunked with half of max_num_batched_tokens each
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 32
assert seq_group_meta[1].token_chunk_size == 32
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64
# After one iteration, both should have 60 - 32 = 28 tokens left to prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56
def test_concurrent_chunking_large_requests():
"""Verify large prefill requests are run one at a time"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
# Verify only a single request is chunked, and it gets all 64 tokens
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 1
assert seq_group_meta[0].token_chunk_size == 64
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64
def test_short_prompts_jump_long_prompts_in_queue():
"""Verify large prefill requests are punted behind smaller ones if
another large prefill request is already running"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)
long_seqs: List[SequenceGroup] = []
short_seqs: List[SequenceGroup] = []
# Add 2 large seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
long_seqs.append(seq_group)
assert seq_group.is_prefill()
# Add 2 small seq groups behind them
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i + 2),
prompt_length=40, # Very small prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
short_seqs.append(seq_group)
assert seq_group.is_prefill()
# Verify one large req and 1 small req chunked
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens
assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens
# all 4 are prefilling
assert long_seqs[0].is_prefill()
assert long_seqs[1].is_prefill()
assert short_seqs[0].is_prefill()
assert short_seqs[1].is_prefill()
# First short and first long sequences have been scheduled
assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
assert short_seqs[1].first_seq.get_num_computed_tokens() == 0
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64
# in the second iteration,
# the first small request had only 8 tokens left
# so it went to decode
# The other small req is scheduled
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# the new small req got 64 - (32+8) tokens
assert seq_group_meta[0].token_chunk_size == 24
assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32
# the other small request had only 8 tokens left
assert seq_group_meta[2].token_chunk_size == 8 # 40-32
# The first small request got to decode now
assert long_seqs[0].is_prefill()
assert long_seqs[1].is_prefill()
assert not short_seqs[0].is_prefill()
assert short_seqs[1].is_prefill()
# Both small requests have started in front of the second long request
assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
assert short_seqs[1].first_seq.get_num_computed_tokens() == 24
assert out.num_prefill_groups == 3
assert out.num_batched_tokens == 64
# the first small seq group has a new token appended.
append_new_token(short_seqs[0], 1)
# in the third iteration,
# the first small request is already decoding
# the second small request only has 16 tokens left and will enter decoding
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large still got 32
# small req finished prefilling 40-24=16 tokens
assert seq_group_meta[1].token_chunk_size == 16
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 49 # (32+16+1 decode)
# both small requests have now reached decode
assert long_seqs[0].is_prefill()
assert long_seqs[1].is_prefill()
assert not short_seqs[0].is_prefill()
assert not short_seqs[1].is_prefill()
assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
assert short_seqs[1].first_seq.get_num_computed_tokens() == 40
# both the small seq groups have a new token appended
append_new_token(short_seqs[0], 1)
append_new_token(short_seqs[1], 1)
# in the fourth iteration, both small requests are decoding
# so large request gets all the budget
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# large req gets 62 tokens (minus 2 for decode)
assert seq_group_meta[0].token_chunk_size == 62
assert seq_group_meta[1].token_chunk_size == 1 # decode
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64
assert long_seqs[0].first_seq.get_num_computed_tokens() == 158
# assert long_seqs[0].is_prefill()
# assert long_seqs[1].is_prefill()
# assert not short_seqs[0].is_prefill()
# assert not short_seqs[1].is_prefill()
# # both the small seq groups have a new token appended
# append_new_token(short_seqs[0], 1)
# append_new_token(short_seqs[1], 1)
# # in the fifth iteration, large request gets all the budget
# # while both small requests are decoding
# seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# assert seq_group_meta[0].token_chunk_size == 62
# assert seq_group_meta[1].token_chunk_size == 1 # decode
# assert seq_group_meta[2].token_chunk_size == 1 # decode
# assert out.num_prefill_groups == 1
# assert out.num_batched_tokens == 64
def test_complex(): def test_complex():
block_size = 4 block_size = 4
max_seqs = 60 max_seqs = 60
...@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs(): ...@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
assert not running[1].is_prefill() assert not running[1].is_prefill()
def test_perfix_caching(): def test_prefix_caching():
"""Verify allocating full blocks when prefix caching is enabled.""" """Verify allocating full blocks when prefix caching is enabled."""
block_size = 4 block_size = 4
max_seqs = 10 max_seqs = 10
...@@ -548,3 +777,86 @@ def test_perfix_caching(): ...@@ -548,3 +777,86 @@ def test_perfix_caching():
assert seq_group_meta[1].token_chunk_size == 12 assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2 assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62 assert out.num_batched_tokens == 62
def test_prefix_caching_with_concurrent_partial_prefills():
"""Verify allocating full blocks when prefix caching is enabled with
--max-num-partial-prefills > 1."""
block_size = 4
max_seqs = 10
max_model_len = 8000
max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens
scheduler_config = SchedulerConfig("generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# To partially prefill both sequences, both can chunk up to 30 tokens
# But the next lowest multiple of the block size (4) is 28
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56
# On the next iteration, both sequences should finish prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# Both sequences have 50 - 28 = 22 tokens left to prefill.
# This is not a multiple of the block size, but we don't care since we don't
# cache the final partial block of prefix sequences
assert seq_group_meta[0].token_chunk_size == 22
assert seq_group_meta[1].token_chunk_size == 22
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 44
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
max_num_partial_prefills: int):
"""Make sure the model can actually sample with concurrent
partial prefills
"""
prompt = "hello" * 40
engine_args = EngineArgs(
model=model,
max_num_partial_prefills=max_num_partial_prefills,
max_num_batched_tokens=40,
max_num_seqs=8,
enable_chunked_prefill=True,
gpu_memory_utilization=0.8,
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(temperature=0)
for req_num in range(max_num_partial_prefills):
engine.add_request(f"{req_num}", prompt, sampling_params)
# first step
request_outputs = engine.step()
# means all are prefilling
assert len(request_outputs) == 0
assert len(engine.scheduler[0].running) == max_num_partial_prefills
...@@ -1430,6 +1430,17 @@ class SchedulerConfig: ...@@ -1430,6 +1430,17 @@ class SchedulerConfig:
# Maximum length of a sequence (including prompt and generated text). # Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192 max_model_len: int = 8192
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills: int = 1
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills: int = 1
# calculate context length that determines which sequences are
# considered "long"
long_prefill_token_threshold: int = 0
# The number of slots to allocate per sequence per # The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative # step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be # decoding to store KV activations of tokens which may or may not be
...@@ -1537,6 +1548,18 @@ class SchedulerConfig: ...@@ -1537,6 +1548,18 @@ class SchedulerConfig:
self.max_num_batched_tokens) self.max_num_batched_tokens)
self.chunked_prefill_enabled = self.enable_chunked_prefill self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len *
0.04)
logger.info(
"Concurrent partial prefills enabled with "
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
"long_prefill_token_threshold=%d",
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -1568,6 +1591,29 @@ class SchedulerConfig: ...@@ -1568,6 +1591,29 @@ class SchedulerConfig:
f"({self.num_scheduler_steps}) must be greater than or " f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.") "equal to 1.")
if self.max_num_partial_prefills < 1:
raise ValueError(
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
"must be greater than or equal to 1.")
elif self.max_num_partial_prefills > 1:
if not self.chunked_prefill_enabled:
raise ValueError("Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1.")
if self.long_prefill_token_threshold > self.max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({self.max_model_len}).")
if (self.max_long_partial_prefills
< 1) or (self.max_long_partial_prefills
> self.max_num_partial_prefills):
raise ValueError(
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
@property @property
def is_multi_step(self) -> bool: def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1 return self.num_scheduler_steps > 1
......
This diff is collapsed.
...@@ -120,6 +120,9 @@ class EngineArgs: ...@@ -120,6 +120,9 @@ class EngineArgs:
cpu_offload_gb: float = 0 # GiB cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_partial_prefills: Optional[int] = 1
max_long_partial_prefills: Optional[int] = 1
long_prefill_token_threshold: Optional[int] = 0
max_num_seqs: Optional[int] = None max_num_seqs: Optional[int] = None
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False disable_log_stats: bool = False
...@@ -515,6 +518,31 @@ class EngineArgs: ...@@ -515,6 +518,31 @@ class EngineArgs:
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per ' help='Maximum number of batched tokens per '
'iteration.') 'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills."
"Defaults to 1",
)
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency. Defaults to 1.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens. Defaults to 4%% of "
"the model's context length.",
)
parser.add_argument('--max-num-seqs', parser.add_argument('--max-num-seqs',
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
...@@ -1244,7 +1272,11 @@ class EngineArgs: ...@@ -1244,7 +1272,11 @@ class EngineArgs:
multi_step_stream_outputs=self.multi_step_stream_outputs, multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray), and parallel_config.use_ray),
policy=self.scheduling_policy) policy=self.scheduling_policy,
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
)
lora_config = LoRAConfig( lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias, bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
......
...@@ -958,7 +958,9 @@ def get_logprobs( ...@@ -958,7 +958,9 @@ def get_logprobs(
if len(query_indices) == 0: if len(query_indices) == 0:
empty_sampled_logprob: SampleLogprobs = [] empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob] num_seq_groups = len(sampling_metadata.seq_groups)
return [empty_prompt_logprob
] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups
selected_logprobs, ranks = None, None selected_logprobs, ranks = None, None
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
...@@ -1225,6 +1227,10 @@ def _build_sampler_output( ...@@ -1225,6 +1227,10 @@ def _build_sampler_output(
assert sample_logprobs is not None assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results, assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType) SampleResultArgsType)
assert len(sampling_metadata.seq_groups) \
== len(maybe_deferred_sample_results) \
== len(prompt_logprobs) \
== len(sample_logprobs)
deferred_sample_results_args = None deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs, for (seq_group, sample_result, group_prompt_logprobs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment