Unverified commit b4b195b3, authored by Lily Liu, committed by GitHub

fix max seq len (#489)

parent 20b0d88d
@@ -204,10 +204,10 @@ class SchedulerConfig:
     """
     def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_seq_len: int) -> None:
+                 max_model_len: int) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
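For orientation, here is a minimal sketch of the `SchedulerConfig` shape after this rename; the docstring and the example values below are illustrative, not the project's exact source:

```python
class SchedulerConfig:
    """Scheduler configuration (sketch of the post-rename shape).

    max_model_len is the model's context-length cap, distinct from
    max_num_batched_tokens, which bounds one scheduler iteration.
    """

    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
                 max_model_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len


# Example values: a 2048-token context model, batches capped at 2560 tokens.
config = SchedulerConfig(max_num_batched_tokens=2560,
                         max_num_seqs=256,
                         max_model_len=2048)
```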
@@ -190,7 +190,9 @@ class Scheduler:
                 break
             num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-            if num_prompt_tokens > self.scheduler_config.max_seq_len:
+            if num_prompt_tokens > min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens):
                 logger.warning(
                     f"Input prompt ({num_prompt_tokens} tokens) is too long"
                     " and exceeds limit of "
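The new condition rejects a prompt unless it fits both caps: the model's context window and the per-iteration token budget. A small self-contained sketch of the same logic, with a hypothetical `prompt_fits` helper standing in for the inline check:

```python
def prompt_fits(num_prompt_tokens: int,
                max_model_len: int,
                max_num_batched_tokens: int) -> bool:
    # Hypothetical helper, not in the diff: a prompt is schedulable only
    # if it fits both the model's context window and the per-iteration
    # token budget, so the smaller of the two limits wins.
    return num_prompt_tokens <= min(max_model_len, max_num_batched_tokens)


# With a 2048-token context window and a 2560-token batch budget:
assert prompt_fits(2048, 2048, 2560)      # exactly at the context cap
assert not prompt_fits(2300, 2048, 2560)  # under 2560, but over 2048
```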
@@ -155,11 +155,10 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
-                                'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
+        max_model_len = getattr(model_config.hf_config,
+                                'max_position_embeddings', float('inf'))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_model_len)
+                                           self.max_num_seqs, max_model_len)
         return model_config, cache_config, parallel_config, scheduler_config
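After this hunk, `max_model_len` comes straight from the Hugging Face config's `max_position_embeddings`, with `float('inf')` as the fallback for models that don't declare one; the earlier `min()` with `max_num_batched_tokens` moves into the scheduler. The `derive_max_model_len` helper below is a hypothetical condensation of those two lines:

```python
from types import SimpleNamespace


def derive_max_model_len(hf_config) -> float:
    # Sketch of the new derivation: read the model's positional limit
    # straight from the HF config, falling back to "unlimited" when the
    # architecture does not declare one.
    return getattr(hf_config, 'max_position_embeddings', float('inf'))


# A config that declares a limit, and one that does not:
assert derive_max_model_len(SimpleNamespace(max_position_embeddings=2048)) == 2048
assert derive_max_model_len(SimpleNamespace()) == float('inf')
```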
@@ -300,8 +300,7 @@ class LLMEngine:
                 continue
             # Check if the sequence has reached max_seq_len.
-            if (seq.get_len() >
-                    self.scheduler.scheduler_config.max_seq_len):
+            if seq.get_len() > self.scheduler_config.max_model_len:
                 self.scheduler.free_seq(
                     seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue
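The engine-side check collapses to a single comparison against `max_model_len`. A hypothetical distilled version of that condition:

```python
def should_cap(seq_len: int, max_model_len: int) -> bool:
    # Hypothetical condensed form of the engine-side check: once a
    # sequence's total length (prompt plus generated tokens) exceeds the
    # model's context window, it is finished with FINISHED_LENGTH_CAPPED.
    return seq_len > max_model_len


assert not should_cap(2048, 2048)  # still within the window
assert should_cap(2049, 2048)      # one token past it: cap the sequence
```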