Unverified Commit 18de8834 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Chunked Prefill][4/n] Chunked prefill scheduler. (#3853)

parent 1d7c940d
...@@ -11,4 +11,4 @@ uvicorn[standard] ...@@ -11,4 +11,4 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server. pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0 prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0 outlines == 0.0.34 # Requires torch >= 2.1.0
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
import time
from typing import Optional
import pytest import pytest
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, from vllm import SamplingParams
SequenceOutput) from vllm.lora.request import LoRARequest
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
SequenceGroup, SequenceGroupOutput, SequenceOutput)
def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> SequenceGroup:
    """Build a SequenceGroup holding one synthetic prompt sequence.

    The prompt token ids are ``0 .. prompt_length - 1`` and the prompt
    string is those ids joined by single spaces.

    Args:
        request_id: Request id for the group. It is also converted with
            ``int()`` and used as the first argument of the sequence, so
            it must be an integer string.
        prompt_length: Number of prompt tokens to generate.
        block_size: Block size handed to the sequence. Any falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.
        lora_request: Forwarded to the SequenceGroup unchanged.
        use_beam_search: Forwarded to SamplingParams.
        best_of: Forwarded to SamplingParams.

    Returns:
        The newly created SequenceGroup containing the single sequence.
    """
    token_ids = list(range(prompt_length))
    prompt_text = " ".join(map(str, token_ids))
    # A falsy block size means "one block as long as the prompt".
    effective_block_size = block_size or prompt_length

    sequence = Sequence(int(request_id), prompt_text, token_ids,
                        effective_block_size)
    sampling_params = SamplingParams(use_beam_search=use_beam_search,
                                     best_of=best_of)
    return SequenceGroup(request_id, [sequence], sampling_params,
                         time.time(), lora_request)
@pytest.fixture @pytest.fixture
...@@ -67,6 +96,29 @@ def test_sequence_data_prefill(): ...@@ -67,6 +96,29 @@ def test_sequence_data_prefill():
# append tokens and reset, simulating recompute # append tokens and reset, simulating recompute
seq_data.append_token_id(1, logprob=0.0) seq_data.append_token_id(1, logprob=0.0)
seq_data.reset_num_computed_tokens() seq_data.reset_state_for_recompute()
assert seq_data.get_num_uncomputed_tokens() == 5 assert seq_data.get_num_uncomputed_tokens() == 5
assert seq_data.get_num_computed_tokens() == 0 assert seq_data.get_num_computed_tokens() == 0
def test_sequence_group_stage():
    """Check prefill/decode stage transitions of a sequence group.

    The group must report prefill until every token has been counted as
    computed, and must go back to prefill after a recompute reset.
    """
    seq_group = create_dummy_prompt("1", 12)
    assert seq_group.is_prefill() is True

    # 12 prompt tokens fed as 6 + 5 + 1: only the last chunk ends prefill.
    for chunk_size, expect_prefill in ((6, True), (5, True), (1, False)):
        seq_group.update_num_computed_tokens(chunk_size)
        assert seq_group.is_prefill() is expect_prefill

    seqs = seq_group.get_seqs()
    assert len(seqs) == 1
    # Append one output token, then reset every sequence for recompute.
    seqs[0].data.append_token_id(1, logprob=0.0)
    for seq in seq_group.get_seqs():
        seq.reset_state_for_recompute()
    assert seq_group.is_prefill() is True

    # After the reset there are 13 tokens to recompute: 5 + 7 + 1.
    for chunk_size, expect_prefill in ((5, True), (7, True), (1, False)):
        seq_group.update_num_computed_tokens(chunk_size)
        assert seq_group.is_prefill() is expect_prefill
...@@ -576,7 +576,8 @@ class SchedulerConfig: ...@@ -576,7 +576,8 @@ class SchedulerConfig:
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len: if (self.max_num_batched_tokens < self.max_model_len
and not self.chunked_prefill_enabled):
raise ValueError( raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). " f"smaller than max_model_len ({self.max_model_len}). "
......
...@@ -38,9 +38,7 @@ class FCFS(Policy): ...@@ -38,9 +38,7 @@ class FCFS(Policy):
class PolicyFactory: class PolicyFactory:
_POLICY_REGISTRY = { _POLICY_REGISTRY = {'fcfs': FCFS}
'fcfs': FCFS,
}
@classmethod @classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy: def get_policy(cls, policy_name: str, **kwargs) -> Policy:
......
This diff is collapsed.
...@@ -607,11 +607,10 @@ class LLMEngine: ...@@ -607,11 +607,10 @@ class LLMEngine:
now = time.time() now = time.time()
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.update_num_computed_tokens(
seq_group.update_num_computed_tokens(token_chunk_size) scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs) self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
......
...@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum): ...@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
return finish_reason return finish_reason
class SequenceStage(enum.Enum):
    """Stage of a sequence: prompt prefill or token decode."""
    # Stage in which some tokens have not been computed yet.
    PREFILL = enum.auto()
    # Stage in which all tokens have been computed.
    DECODE = enum.auto()
@dataclass @dataclass
class RequestMetrics: class RequestMetrics:
"""Metrics associated with a request. """Metrics associated with a request.
...@@ -115,6 +120,7 @@ class SequenceData: ...@@ -115,6 +120,7 @@ class SequenceData:
self.cumulative_logprob = 0.0 self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model). # The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0 self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id) self.output_token_ids.append(token_id)
...@@ -136,16 +142,22 @@ class SequenceData: ...@@ -136,16 +142,22 @@ class SequenceData:
"""Return the number of prefill tokens that are already computed.""" """Return the number of prefill tokens that are already computed."""
return self._num_computed_tokens return self._num_computed_tokens
def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int: def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far.""" """Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
def reset_num_computed_tokens(self) -> None: def reset_state_for_recompute(self) -> None:
"""Reset the number of computed tokens from this sequence. It is """Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted). the beginning again (e.g., sequence is preempted).
""" """
self._num_computed_tokens = 0 self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefill tokens that are not computed.""" """Return the number of prefill tokens that are not computed."""
...@@ -165,6 +177,10 @@ class SequenceData: ...@@ -165,6 +177,10 @@ class SequenceData:
def get_output_token_ids(self) -> int: def get_output_token_ids(self) -> int:
return self.output_token_ids return self.output_token_ids
@property
def stage(self) -> SequenceStage:
    """Return the stage currently stored in ``self._stage``."""
    return self._stage
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceData(" return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
...@@ -234,7 +250,7 @@ class Sequence: ...@@ -234,7 +250,7 @@ class Sequence:
def reset_state_for_recompute(self): def reset_state_for_recompute(self):
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_num_computed_tokens() self.data.reset_state_for_recompute()
def _append_logical_block(self) -> None: def _append_logical_block(self) -> None:
block = LogicalTokenBlock( block = LogicalTokenBlock(
...@@ -320,6 +336,23 @@ class Sequence: ...@@ -320,6 +336,23 @@ class Sequence:
new_seq.seq_id = new_seq_id new_seq.seq_id = new_seq_id
return new_seq return new_seq
def get_num_new_tokens(self) -> int:
    """Return how many tokens of this sequence are to be computed next.

    Returns:
        1 when the sequence data is in the DECODE stage; otherwise the
        number of tokens of this sequence that are not computed yet.
    """
    seq_data = self.data
    if seq_data.stage != SequenceStage.DECODE:
        return seq_data.get_num_uncomputed_tokens()
    # In the decode stage exactly one new token is processed.
    return 1
def is_prefill(self) -> bool:
    """Return True while this sequence's data is in the PREFILL stage."""
    current_stage = self.data.stage
    return current_stage == SequenceStage.PREFILL
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, " return (f"Sequence(seq_id={self.seq_id}, "
f"status={self.status.name}, " f"status={self.status.name}, "
...@@ -461,14 +494,14 @@ class SequenceGroup: ...@@ -461,14 +494,14 @@ class SequenceGroup:
def update_num_computed_tokens(self, num_new_computed_tokens: int): def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far.""" """Update number of tokens computed so far."""
for seq in self.seqs_dict.values(): for seq in self.seqs_dict.values():
seq.data.update_num_computed_tokens(num_new_computed_tokens) if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
# All sequences in the group should have the same prompt, so the num_uncomputed_tokens = 0
# number of unfinished prefill tokens are the same across all for seq in self.get_seqs():
# sequences. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return list( return num_uncomputed_tokens
self.seqs_dict.values())[0].data.get_num_uncomputed_tokens()
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status)) return len(self.get_seqs(status))
...@@ -497,6 +530,10 @@ class SequenceGroup: ...@@ -497,6 +530,10 @@ class SequenceGroup:
def is_finished(self) -> bool: def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.get_seqs()) return all(seq.is_finished() for seq in self.get_seqs())
def is_prefill(self) -> bool:
    """Return whether this group is in the prefill stage.

    All sequences in the group are expected to be in the same stage, so
    only the first sequence returned by ``get_seqs()`` is consulted.
    """
    first_seq = self.get_seqs()[0]
    return first_seq.is_prefill()
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, " return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, " f"sampling_params={self.sampling_params}, "
...@@ -513,8 +550,8 @@ class SequenceGroupMetadata: ...@@ -513,8 +550,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block block_tables: The block tables. (Seq id -> list of physical block
numbers) numbers)
token_chunk_size: The number of tokens to be processed. None if token_chunk_size: The number of tokens to be processed (per sequence).
chunking is not required. None if chunking is not required.
state: Internal state tied to this sequence group. state: Internal state tied to this sequence group.
lora_request: LoRA request. lora_request: LoRA request.
multi_modal_data: Multi modal data. multi_modal_data: Multi modal data.
......
...@@ -222,7 +222,6 @@ class ModelRunner: ...@@ -222,7 +222,6 @@ class ModelRunner:
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end))) input_positions.extend(list(range(computed_len, prefill_end)))
lora_id = seq_group_metadata.lora_int_id lora_id = seq_group_metadata.lora_int_id
if lora_id > 0: if lora_id > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment