norm / vllm / Commits

Commit dafd924c (Unverified)
Authored Jun 30, 2023 by Lily Liu; committed by GitHub on Jun 30, 2023
Parent: 598dc4b7

Raise error for long prompt (#273)
Showing 5 changed files with 42 additions and 11 deletions:
  vllm/config.py             +4   -0
  vllm/core/scheduler.py     +20  -7
  vllm/engine/arg_utils.py   +5   -1
  vllm/engine/llm_engine.py  +9   -3
  vllm/sequence.py           +4   -0
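In short, a prompt whose token count already reaches the scheduler's new max_seq_len is never scheduled: it is marked FINISHED_IGNORED, a warning is logged, and the request still comes back through the normal output path with finish_reason "length". A rough usage sketch of that behaviour (the model name, its 2048-token position limit, and the exact prompt length are illustrative assumptions, not part of this commit):

    from vllm import LLM, SamplingParams

    # Illustrative sketch: model choice and token counts are assumptions.
    llm = LLM(model="facebook/opt-125m")   # max_position_embeddings == 2048
    params = SamplingParams(max_tokens=16)

    # After this commit, an over-long prompt is not scheduled at all; the
    # scheduler logs a warning, marks it FINISHED_IGNORED, and the request
    # completes with finish_reason "length" and no generated text.
    outputs = llm.generate(["word " * 5000], params)
    print(outputs[0].outputs[0].finish_reason)   # expected: "length"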
vllm/config.py

@@ -186,14 +186,18 @@ class SchedulerConfig:
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
+        max_seq_len: Maximum length of a sequence (including prompt
+            and generated text).
     """

     def __init__(
         self,
         max_num_batched_tokens: int,
         max_num_seqs: int,
+        max_seq_len: int,
     ) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
+        self.max_seq_len = max_seq_len


 _STR_DTYPE_TO_TORCH_DTYPE = {
...
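For reference, max_seq_len is just a third field on the config; a minimal construction sketch (the numeric values are illustrative, and in practice EngineArgs derives max_seq_len as shown in vllm/engine/arg_utils.py below):

    from vllm.config import SchedulerConfig

    # Illustrative values only; EngineArgs computes the real max_seq_len.
    scheduler_config = SchedulerConfig(max_num_batched_tokens=2560,
                                       max_num_seqs=256,
                                       max_seq_len=2048)
    assert scheduler_config.max_seq_len == 2048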
vllm/core/scheduler.py

@@ -102,11 +102,12 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
+        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
         now = time.time()
...

@@ -187,12 +188,24 @@ class Scheduler:
             # If the sequence group has been preempted in this step, stop.
             if seq_group in preempted:
                 break

+            num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+            if num_prompt_tokens >= self.scheduler_config.max_seq_len:
+                logger.warn(
+                    f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                    " and exceeds limit of "
+                    f"{self.scheduler_config.max_seq_len}")
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.FINISHED_IGNORED
+                ignored_seq_groups.append(seq_group)
+                self.waiting.pop(0)
+                break
+
             # If the sequence group cannot be allocated, stop.
             if not self.block_manager.can_allocate(seq_group):
                 break

             # If the number of batched tokens exceeds the limit, stop.
-            num_prompt_tokens = seq_group.get_seqs()[0].get_len()
             if (num_batched_tokens + num_prompt_tokens
                 > self.scheduler_config.max_num_batched_tokens):
                 break
...

@@ -218,7 +231,7 @@ class Scheduler:
             blocks_to_copy=blocks_to_copy,
         )
         if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids
+            return scheduler_outputs, prompt_group_ids, ignored_seq_groups

         # TODO(woosuk): Move the below code to the engine.
         now = time.time()
...

@@ -258,13 +271,13 @@ class Scheduler:
             f"Pending: {len(self.waiting)} reqs, "
             f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
             f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids
+        return scheduler_outputs, prompt_group_ids, ignored_seq_groups

-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids = self._schedule()
+        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
...

@@ -286,7 +299,7 @@ class Scheduler:
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs
+        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups

     def update(
         self,
...
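Because _schedule() and schedule() now return a third element, every caller has to unpack it; a minimal sketch of the new calling convention (variable names mirror the diff, and the scheduler instance is assumed to exist):

    # Callers now unpack three values instead of two.
    (seq_group_metadata_list, scheduler_outputs,
     ignored_seq_groups) = scheduler.schedule()

    # ignored_seq_groups holds SequenceGroups whose prompt already exceeded
    # max_seq_len: they were never allocated KV-cache blocks, and every
    # sequence in them carries SequenceStatus.FINISHED_IGNORED.
    for seq_group in ignored_seq_groups:
        print(f"ignoring over-long request {seq_group.request_id}")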
vllm/engine/arg_utils.py

@@ -123,8 +123,12 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
+        max_seq_len = min(
+            self.max_num_batched_tokens,
+            getattr(model_config.hf_config, "max_position_embeddings",
+                    float("inf")))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
+                                           self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config
...
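The derived limit is simply the smaller of the batching budget and the model's positional limit; a worked example with illustrative numbers (a 2560-token batch budget and a 2048-position model, neither of which is asserted by this commit):

    # Illustrative numbers only.
    max_num_batched_tokens = 2560
    max_position_embeddings = 2048            # from model_config.hf_config
    max_seq_len = min(max_num_batched_tokens, max_position_embeddings)
    assert max_seq_len == 2048

    # Configs without max_position_embeddings fall back to float("inf"),
    # so the limit degenerates to max_num_batched_tokens.
    assert min(2560, float("inf")) == 2560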
vllm/engine/llm_engine.py

@@ -226,8 +226,8 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
+        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
+        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
             # Nothing to do.
             return []
...

@@ -251,7 +251,7 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups:
+        for seq_group in seq_groups + ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
...

@@ -288,6 +288,12 @@ class LLMEngine:
             if stopped:
                 continue

+            # Check if the sequence has reached max_seq_len.
+            if (seq.get_len() >= self.scheduler.scheduler_config.max_seq_len):
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
+                continue
+
             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
                 self.scheduler.free_seq(
...
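Two engine-side effects follow from the hunks above: ignored groups flow through the same RequestOutput path as finished ones, and a running sequence is now also cut off once prompt plus generated tokens reach max_seq_len. A small numeric sketch of the new cap (numbers are illustrative; the status-to-reason mapping comes from vllm/sequence.py below):

    # Illustrative numbers: a 2000-token prompt under a 2048-token limit.
    prompt_len, generated = 2000, 48
    max_seq_len = 2048
    if prompt_len + generated >= max_seq_len:
        # The engine frees the sequence with FINISHED_LENGTH_CAPPED, which
        # is reported to the user as finish_reason "length".
        finish_reason = "length"
    assert finish_reason == "length"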
vllm/sequence.py

@@ -13,6 +13,7 @@ class SequenceStatus(enum.Enum):
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
     FINISHED_ABORTED = enum.auto()
+    FINISHED_IGNORED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
...

@@ -20,6 +21,7 @@ class SequenceStatus(enum.Enum):
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
             SequenceStatus.FINISHED_ABORTED,
+            SequenceStatus.FINISHED_IGNORED,
         ]

     @staticmethod
...

@@ -30,6 +32,8 @@ class SequenceStatus(enum.Enum):
             finish_reason = "length"
         elif status == SequenceStatus.FINISHED_ABORTED:
             finish_reason = "abort"
+        elif status == SequenceStatus.FINISHED_IGNORED:
+            finish_reason = "length"
         else:
             finish_reason = None
         return finish_reason
...
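With FINISHED_IGNORED added to the terminal states, an ignored request is reported exactly like a length-capped one; a quick sketch (the get_finished_reason helper name is assumed from the surrounding file and is not visible in these hunks):

    from vllm.sequence import SequenceStatus

    assert SequenceStatus.is_finished(SequenceStatus.FINISHED_IGNORED)
    # Helper name assumed; the hunk above only shows its body.
    assert SequenceStatus.get_finished_reason(
        SequenceStatus.FINISHED_IGNORED) == "length"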