Unverified Commit 18de8834 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Chunked Prefill][4/n] Chunked prefill scheduler. (#3853)

parent 1d7c940d
This diff is collapsed.
This diff is collapsed.
import time
from typing import Optional
import pytest
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
SequenceGroup, SequenceGroupOutput, SequenceOutput)
def create_dummy_prompt(
request_id: str,
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> SequenceGroup:
if not block_size:
block_size = prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(
request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)
return seq_group
@pytest.fixture
......@@ -67,6 +96,29 @@ def test_sequence_data_prefill():
# append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0)
seq_data.reset_num_computed_tokens()
seq_data.reset_state_for_recompute()
assert seq_data.get_num_uncomputed_tokens() == 5
assert seq_data.get_num_computed_tokens() == 0
def test_sequence_group_stage():
seq_group = create_dummy_prompt("1", 12)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(6)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(5)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
seqs = seq_group.get_seqs()
assert len(seqs) == 1
seqs[0].data.append_token_id(1, logprob=0.0)
for seq in seq_group.get_seqs():
seq.reset_state_for_recompute()
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(5)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(7)
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
......@@ -576,7 +576,8 @@ class SchedulerConfig:
self._verify_args()
def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
if (self.max_num_batched_tokens < self.max_model_len
and not self.chunked_prefill_enabled):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
......
......@@ -38,9 +38,7 @@ class FCFS(Policy):
class PolicyFactory:
_POLICY_REGISTRY = {
'fcfs': FCFS,
}
_POLICY_REGISTRY = {'fcfs': FCFS}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
......
This diff is collapsed.
......@@ -607,11 +607,10 @@ class LLMEngine:
now = time.time()
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.update_num_computed_tokens(token_chunk_size)
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
......
......@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
return finish_reason
class SequenceStage(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
@dataclass
class RequestMetrics:
"""Metrics associated with a request.
......@@ -115,6 +120,7 @@ class SequenceData:
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
......@@ -136,16 +142,22 @@ class SequenceData:
"""Return the number of prefill tokens that are already computed."""
return self._num_computed_tokens
def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int:
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
def reset_num_computed_tokens(self) -> None:
def reset_state_for_recompute(self) -> None:
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted).
"""
self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefil tokens that are not computed."""
......@@ -165,6 +177,10 @@ class SequenceData:
def get_output_token_ids(self) -> int:
return self.output_token_ids
@property
def stage(self) -> SequenceStage:
return self._stage
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, "
......@@ -234,7 +250,7 @@ class Sequence:
def reset_state_for_recompute(self):
"""Reset the sequence states for recomputation."""
self.data.reset_num_computed_tokens()
self.data.reset_state_for_recompute()
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
......@@ -320,6 +336,23 @@ class Sequence:
new_seq.seq_id = new_seq_id
return new_seq
def get_num_new_tokens(self) -> int:
"""Get the number of new tokens to be computed.
Args:
remainig_token_budget: The remaining token budgets.
Returns:
The new number of tokens to be computed. I.e., 1 for decode, prompt
size for prefill. If there's not enough remainig_token_budget, it
can return the chunked number of new tokens.
"""
if self.data.stage == SequenceStage.DECODE:
return 1
return self.data.get_num_uncomputed_tokens()
def is_prefill(self) -> bool:
return self.data.stage == SequenceStage.PREFILL
def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, "
f"status={self.status.name}, "
......@@ -461,14 +494,14 @@ class SequenceGroup:
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
for seq in self.seqs_dict.values():
if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int:
# All sequences in the group should have the same prompt, so the
# number of unfinished prefill tokens are the same across all
# sequences.
return list(
self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()
num_uncomputed_tokens = 0
for seq in self.get_seqs():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
......@@ -497,6 +530,10 @@ class SequenceGroup:
def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.get_seqs())
def is_prefill(self) -> bool:
# Every sequences should be in the same stage.
return self.get_seqs()[0].is_prefill()
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, "
......@@ -513,8 +550,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
token_chunk_size: The number of tokens to be processed. None if
chunking is not required.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
multi_modal_data: Multi modal data.
......
......@@ -222,7 +222,6 @@ class ModelRunner:
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end)))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment