Unverified Commit dafd924c authored by Lily Liu, committed by GitHub

Raise error for long prompt (#273)

parent 598dc4b7
@@ -186,14 +186,18 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
+        max_seq_len: Maximum length of a sequence (including prompt
+            and generated text).
     """
     def __init__(
         self,
         max_num_batched_tokens: int,
         max_num_seqs: int,
+        max_seq_len: int,
     ) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
+        self.max_seq_len = max_seq_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
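For illustration only (not part of the commit): a minimal sketch of constructing the extended SchedulerConfig, assuming the class is importable from the package's config module; the numeric values are made up.

# Hypothetical values; max_seq_len caps prompt plus generated tokens per sequence.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560,
    max_num_seqs=256,
    max_seq_len=2048,
)
assert scheduler_config.max_seq_len == 2048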
@@ -102,11 +102,12 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
+        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
         now = time.time()
@@ -187,12 +188,24 @@ class Scheduler:
                 # If the sequence group has been preempted in this step, stop.
                 if seq_group in preempted:
                     break
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if num_prompt_tokens >= self.scheduler_config.max_seq_len:
+                    logger.warn(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        " and exceeds limit of "
+                        f"{self.scheduler_config.max_seq_len}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
                 # If the sequence group cannot be allocated, stop.
                 if not self.block_manager.can_allocate(seq_group):
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                     > self.scheduler_config.max_num_batched_tokens):
                     break
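A standalone sketch of the new admission check, for illustration only (the function name and the numbers are made up, not the scheduler's actual API):

def should_ignore_prompt(num_prompt_tokens: int, max_seq_len: int) -> bool:
    # A prompt that already meets or exceeds max_seq_len can never fit within
    # the sequence-length budget, so it is marked FINISHED_IGNORED and popped
    # from the waiting queue instead of being scheduled.
    return num_prompt_tokens >= max_seq_len

assert should_ignore_prompt(4096, 2048) is True
assert should_ignore_prompt(100, 2048) is False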
@@ -218,7 +231,7 @@ class Scheduler:
             blocks_to_copy=blocks_to_copy,
         )
         if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids
+            return scheduler_outputs, prompt_group_ids, ignored_seq_groups

         # TODO(woosuk): Move the below code to the engine.
         now = time.time()
@@ -258,13 +271,13 @@ class Scheduler:
             f"Pending: {len(self.waiting)} reqs, "
             f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
             f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids
+        return scheduler_outputs, prompt_group_ids, ignored_seq_groups

-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids = self._schedule()
+        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -286,7 +299,7 @@ class Scheduler:
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs
+        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups

     def update(
         self,
@@ -123,8 +123,12 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
+        max_seq_len = min(
+            self.max_num_batched_tokens,
+            getattr(model_config.hf_config, "max_position_embeddings",
+                    float("inf")))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
+                                           self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config
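The derivation above caps max_seq_len at the model's context window when the Hugging Face config exposes max_position_embeddings, and otherwise falls back to max_num_batched_tokens alone. A standalone sketch, with SimpleNamespace standing in for the HF config object and made-up numbers:

from types import SimpleNamespace

def derive_max_seq_len(hf_config, max_num_batched_tokens: int) -> int:
    # getattr falls back to "no model limit" when the config lacks
    # max_position_embeddings, so min() then keeps max_num_batched_tokens.
    return min(max_num_batched_tokens,
               getattr(hf_config, "max_position_embeddings", float("inf")))

assert derive_max_seq_len(SimpleNamespace(max_position_embeddings=2048), 2560) == 2048
assert derive_max_seq_len(SimpleNamespace(), 2560) == 2560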
@@ -226,8 +226,8 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
+        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
             # Nothing to do.
             return []
@@ -251,7 +251,7 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups:
+        for seq_group in seq_groups + ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
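Because ignored groups are appended to the regular outputs, a caller sees an over-long prompt come back as an already-finished request. An illustrative caller loop; the engine variable and the RequestOutput attributes used here (request_id, finished, outputs, finish_reason) are assumptions, not shown in this diff:

for request_output in llm_engine.step():
    if request_output.finished:
        # An over-long prompt is returned already finished; FINISHED_IGNORED
        # surfaces as finish_reason "length".
        print(request_output.request_id,
              [output.finish_reason for output in request_output.outputs])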
@@ -288,6 +288,12 @@ class LLMEngine:
             if stopped:
                 continue

+            # Check if the sequence has reached max_seq_len.
+            if (seq.get_len() >=
+                    self.scheduler.scheduler_config.max_seq_len):
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
+                continue
             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
                 self.scheduler.free_seq(
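This check finishes a sequence once its total length (prompt plus generated tokens) reaches max_seq_len, so the generation budget is whatever the prompt leaves over. A rough, illustrative calculation with made-up numbers:

max_seq_len = 2048
prompt_len = 2000
# Generation stops once prompt_len + num_generated >= max_seq_len.
max_new_tokens = max(max_seq_len - prompt_len, 0)
assert max_new_tokens == 48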
@@ -13,6 +13,7 @@ class SequenceStatus(enum.Enum):
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
+    FINISHED_IGNORED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
@@ -20,6 +21,7 @@ class SequenceStatus(enum.Enum):
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
             SequenceStatus.FINISHED_ABORTED,
+            SequenceStatus.FINISHED_IGNORED,
         ]

     @staticmethod
@@ -30,6 +32,8 @@ class SequenceStatus(enum.Enum):
             finish_reason = "length"
         elif status == SequenceStatus.FINISHED_ABORTED:
             finish_reason = "abort"
+        elif status == SequenceStatus.FINISHED_IGNORED:
+            finish_reason = "length"
         else:
             finish_reason = None
         return finish_reason
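The new branch makes ignored (over-long) prompts report the same finish reason as length-capped sequences. A self-contained sketch of that mapping, using a stand-in enum rather than the real SequenceStatus class:

import enum

class Status(enum.Enum):
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

FINISH_REASONS = {
    Status.FINISHED_LENGTH_CAPPED: "length",
    Status.FINISHED_ABORTED: "abort",
    # Ignored prompts are reported as stopped for length, matching the diff.
    Status.FINISHED_IGNORED: "length",
}

assert FINISH_REASONS[Status.FINISHED_IGNORED] == "length"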