Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
e21d7687
Unverified
Commit
e21d7687
authored
Sep 17, 2023
by
陈序
Committed by
GitHub
Sep 17, 2023
Browse files
Fix hanging when prompt exceeds limit (#1029)
parent
ff36139f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
13 deletions
+10
-13
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+9
-12
No files found.
vllm/core/scheduler.py
View file @
e21d7687
...
...
@@ -175,7 +175,7 @@ class Scheduler:
num_curr_seqs
+=
num_new_seqs
scheduled
.
append
(
seq_group
)
if
scheduled
:
if
scheduled
or
ignored_seq_groups
:
scheduler_outputs
=
SchedulerOutputs
(
scheduled_seq_groups
=
scheduled
,
prompt_run
=
True
,
...
...
vllm/engine/llm_engine.py
View file @
e21d7687
...
...
@@ -294,14 +294,12 @@ class LLMEngine:
def
_schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
Optional
[
List
[
RequestOutput
]]
]
:
List
[
RequestOutput
]]:
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
if
scheduler_outputs
.
is_empty
():
return
seq_group_metadata_list
,
scheduler_outputs
,
[
RequestOutput
.
from_seq_group
(
seq_group
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
]
return
seq_group_metadata_list
,
scheduler_outputs
,
None
return
seq_group_metadata_list
,
scheduler_outputs
,
[
RequestOutput
.
from_seq_group
(
seq_group
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
]
def
_check_beam_search_early_stopping
(
self
,
...
...
@@ -545,10 +543,9 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
(
seq_group_metadata_list
,
scheduler_outputs
,
early_return
)
=
self
.
_schedule
()
if
early_return
is
not
None
:
return
early_return
seq_group_metadata_list
,
scheduler_outputs
,
ignored
=
self
.
_schedule
()
if
scheduler_outputs
.
is_empty
():
return
ignored
# Execute the model.
output
=
self
.
_run_workers
(
...
...
@@ -559,7 +556,7 @@ class LLMEngine:
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
)
return
self
.
_process_model_outputs
(
output
,
scheduler_outputs
)
return
self
.
_process_model_outputs
(
output
,
scheduler_outputs
)
+
ignored
def
_log_system_stats
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment