Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fdc581f
Unverified
Commit
4fdc581f
authored
Oct 24, 2024
by
youkaichao
Committed by
GitHub
Oct 24, 2024
Browse files
[core] simplify seq group code (#9569)
Co-authored-by:
Zhuohan Li
<
zhuohan123@gmail.com
>
parent
3770071e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
62 additions
and
566 deletions
+62
-566
tests/core/test_chunked_prefill_scheduler.py
tests/core/test_chunked_prefill_scheduler.py
+0
-153
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+1
-203
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+19
-21
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+22
-105
vllm/sequence.py
vllm/sequence.py
+19
-83
No files found.
tests/core/test_chunked_prefill_scheduler.py
View file @
4fdc581f
...
@@ -4,7 +4,6 @@ from unittest.mock import MagicMock
...
@@ -4,7 +4,6 @@ from unittest.mock import MagicMock
import
pytest
# noqa
import
pytest
# noqa
from
vllm.config
import
CacheConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.scheduler
import
Scheduler
from
vllm.sequence
import
Logprob
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
SequenceGroup
...
@@ -347,158 +346,6 @@ def test_prompt_limit_exceed():
...
@@ -347,158 +346,6 @@ def test_prompt_limit_exceed():
assert
out
.
ignored_seq_groups
[
0
]
==
seq_group
assert
out
.
ignored_seq_groups
[
0
]
==
seq_group
def
test_swap
():
"""Verify swapping works with chunked prefill requests"""
block_size
=
4
max_seqs
=
30
max_model_len
=
200
max_num_batched_tokens
=
30
scheduler_config
=
SchedulerConfig
(
"generate"
,
max_num_batched_tokens
,
max_seqs
,
max_model_len
,
enable_chunked_prefill
=
True
,
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
16
cache_config
.
num_gpu_blocks
=
16
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
None
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
,
block_size
=
block_size
)
scheduler
.
add_seq_group
(
seq_group
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
# The request is chunked.
# prefill scheduled now.
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
num_prefill_groups
==
1
assert
seq_group
.
is_prefill
()
assert
out
.
num_batched_tokens
==
max_num_batched_tokens
# The last request should be swapped out.
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
return
seq_group
.
request_id
!=
"1"
scheduler
.
block_manager
.
can_append_slots
.
side_effect
=
(
cannot_append_second_group
)
# The running prefill is now swapped.
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
blocks_to_swap_out
!=
[]
assert
out
.
blocks_to_swap_in
==
[]
# Add 1 more task. Swap should be prioritized over new prefill.
_
,
seq_group
=
create_dummy_prompt
(
"2"
,
prompt_length
=
60
)
scheduler
.
add_seq_group
(
seq_group
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
!=
[]
assert
out
.
blocks_to_swap_out
==
[]
def
test_running_prefill_prioritized_over_swap
():
block_size
=
4
max_seqs
=
30
max_model_len
=
200
max_num_batched_tokens
=
30
scheduler_config
=
SchedulerConfig
(
"generate"
,
max_num_batched_tokens
,
max_seqs
,
max_model_len
,
enable_chunked_prefill
=
True
,
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
32
cache_config
.
num_gpu_blocks
=
32
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
None
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
,
block_size
=
block_size
)
scheduler
.
add_seq_group
(
seq_group
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
# The request is chunked.
# prefill scheduled now.
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
num_prefill_groups
==
1
assert
seq_group
.
is_prefill
()
assert
out
.
num_batched_tokens
==
max_num_batched_tokens
# The request should be swapped out.
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
return
seq_group
.
request_id
!=
"1"
scheduler
.
block_manager
.
can_append_slots
.
side_effect
=
(
cannot_append_second_group
)
# The running prefill is now swapped.
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
blocks_to_swap_out
!=
[]
assert
out
.
blocks_to_swap_in
==
[]
# Add 1 more task. Swap is not possible, so prefill is running.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
LATER
_
,
seq_group2
=
create_dummy_prompt
(
"2"
,
prompt_length
=
60
,
block_size
=
block_size
)
scheduler
.
add_seq_group
(
seq_group2
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
==
[]
assert
out
.
blocks_to_swap_out
==
[]
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
# Now although swap is possible, running prefill is prioritized.
scheduler
.
block_manager
.
can_swap_in
.
return_value
=
AllocStatus
.
OK
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
==
[]
assert
out
.
blocks_to_swap_out
==
[]
assert
not
seq_group2
.
is_prefill
()
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
append_new_token
(
seq_group2
,
1
)
# Decoding is prioritized.
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
1
assert
out
.
blocks_to_swap_in
==
[]
assert
out
.
blocks_to_swap_out
==
[]
assert
not
seq_group2
.
is_prefill
()
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
append_new_token
(
seq_group2
,
1
)
# Since we abort the sequence group, we can finally swap.
scheduler
.
abort_seq_group
(
seq_group2
.
request_id
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
!=
[]
assert
out
.
blocks_to_swap_out
==
[]
def
test_chunked_prefill_preempt
():
def
test_chunked_prefill_preempt
():
"""Verify preempt works with chunked prefill requests"""
"""Verify preempt works with chunked prefill requests"""
block_size
=
4
block_size
=
4
...
...
tests/core/test_scheduler.py
View file @
4fdc581f
...
@@ -10,7 +10,7 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
...
@@ -10,7 +10,7 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.core.scheduler
import
Scheduler
,
SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SequenceGroup
,
SequenceStatus
from
vllm.sequence
import
SequenceGroup
from
.utils
import
(
append_new_token
,
append_new_token_seq_group
,
from
.utils
import
(
append_new_token
,
append_new_token_seq_group
,
create_dummy_prompt
,
get_sequence_groups
,
create_dummy_prompt
,
get_sequence_groups
,
...
@@ -296,55 +296,6 @@ def test_scheduler_delay_factor():
...
@@ -296,55 +296,6 @@ def test_scheduler_delay_factor():
append_new_token
(
out
,
1
)
append_new_token
(
out
,
1
)
def
test_swapped_out_prioritized
():
block_size
=
4
scheduler
=
initialize_scheduler
(
max_num_seqs
=
6
,
block_size
=
block_size
,
num_cpu_blocks
=
64
,
num_gpu_blocks
=
64
)
# best_of=2 * 3 == 6 sequences.
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
,
block_size
=
block_size
)
scheduler
.
add_seq_group
(
seq_group
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
# prefill scheduled now.
assert
len
(
out
.
scheduled_seq_groups
)
==
3
append_new_token
(
out
,
1
)
# The last request should be swapped out.
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
return
seq_group
.
request_id
!=
"2"
scheduler
.
block_manager
.
can_append_slots
.
side_effect
=
(
cannot_append_second_group
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
2
assert
out
.
num_batched_tokens
==
2
assert
out
.
blocks_to_swap_out
!=
[]
assert
out
.
blocks_to_swap_in
==
[]
append_new_token
(
out
,
1
)
# Add 1 more task. Swap should be prioritized over prefill.
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
,
block_size
=
block_size
)
scheduler
.
add_seq_group
(
seq_group
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
append_new_token
(
out
,
1
)
assert
len
(
out
.
scheduled_seq_groups
)
==
3
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
3
assert
out
.
blocks_to_swap_in
!=
[]
assert
out
.
blocks_to_swap_out
==
[]
def
initialize_scheduler
(
def
initialize_scheduler
(
*
,
*
,
max_num_seqs
=
1000
,
max_num_seqs
=
1000
,
...
@@ -646,60 +597,6 @@ def test_decode_schedule_preempted():
...
@@ -646,60 +597,6 @@ def test_decode_schedule_preempted():
assert
output
.
blocks_to_copy
==
[]
assert
output
.
blocks_to_copy
==
[]
def
test_decode_swap_beam_search
():
"""
Test best_of > 1 swap out blocks
"""
block_size
=
4
scheduler
=
initialize_scheduler
(
block_size
=
block_size
,
num_gpu_blocks
=
64
,
num_cpu_blocks
=
64
)
curr_loras
=
None
budget
=
create_token_budget
()
for
i
in
range
(
3
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
,
block_size
=
block_size
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_add_seq_group_to_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
budget
.
add_num_seqs
(
seq_group
.
request_id
,
seq_group
.
get_max_num_running_seqs
())
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
seq_group
.
num_seqs
(
SequenceStatus
.
RUNNING
))
# The last request should be swapped out.
scheduler
.
block_manager
.
can_append_slots
=
MagicMock
()
def
cannot_append_second_group
(
seq_group
,
num_lookahead_slots
):
return
seq_group
.
request_id
!=
"2"
scheduler
.
block_manager
.
can_append_slots
.
side_effect
=
(
cannot_append_second_group
)
scheduler
.
block_manager
.
swap_out
=
MagicMock
()
expected_swap_mapping
=
[(
"5"
,
"7"
)]
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
output
=
scheduler
.
_schedule_running
(
budget
,
curr_loras
)
remainig_running
=
scheduler
.
running
assert
len
(
remainig_running
)
==
0
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
output
.
decode_seq_groups
[
0
].
seq_group
.
request_id
==
"0"
assert
output
.
decode_seq_groups
[
1
].
seq_group
.
request_id
==
"1"
assert
len
(
output
.
preempted
)
==
0
assert
len
(
output
.
swapped_out
)
==
1
# Budget should refledct preempted requests.
assert
budget
.
num_batched_tokens
==
2
# since there are 2 sequences, 2 should be subtracted.
assert
budget
.
num_curr_seqs
==
4
# Both should be preempted, not swapped.
assert
output
.
blocks_to_swap_out
==
expected_swap_mapping
# Nothing is copied.
assert
output
.
blocks_to_copy
==
[]
def
test_schedule_decode_blocks_to_copy_update
():
def
test_schedule_decode_blocks_to_copy_update
():
"""
"""
Verify blocks_to_copy is updated.
Verify blocks_to_copy is updated.
...
@@ -736,105 +633,6 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -736,105 +633,6 @@ def test_schedule_decode_blocks_to_copy_update():
assert
output
.
blocks_to_copy
==
[(
2
,
3
)]
assert
output
.
blocks_to_copy
==
[(
2
,
3
)]
def
test_schedule_swapped_simple
():
block_size
=
4
scheduler
=
initialize_scheduler
(
block_size
=
block_size
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
4
,
best_of
=
2
,
block_size
=
block_size
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
4
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_add_seq_group_to_swapped
(
seq_group
)
budget
=
create_token_budget
()
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
0
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
# swap in is the reverse of swap out
blocks_to_swap_in_reverse
=
[]
for
swapin
,
swapout
in
output
.
blocks_to_swap_in
:
blocks_to_swap_in_reverse
.
append
((
swapout
,
swapin
))
assert
blocks_to_swap_out
==
blocks_to_swap_in_reverse
def
test_schedule_swapped_max_token_budget
():
block_size
=
4
scheduler
=
initialize_scheduler
(
block_size
=
block_size
,
num_cpu_blocks
=
32
,
num_gpu_blocks
=
32
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_add_seq_group_to_swapped
(
seq_group
)
budget
=
create_token_budget
(
token_budget
=
1
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
2
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
# Verify num_batched_tokens are respected.
budget
=
create_token_budget
(
token_budget
=
1
)
add_token_budget
(
budget
,
1
,
0
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
1
assert
budget
.
num_batched_tokens
==
1
assert
budget
.
num_curr_seqs
==
0
assert
len
(
output
.
decode_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
def
test_schedule_swapped_max_seqs
():
block_size
=
4
scheduler
=
initialize_scheduler
(
block_size
=
block_size
,
num_cpu_blocks
=
64
,
num_gpu_blocks
=
64
)
curr_loras
=
None
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
4
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
block_size
=
4
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_add_seq_group_to_swapped
(
seq_group
)
budget
=
create_token_budget
(
max_num_seqs
=
2
)
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
assert
len
(
output
.
decode_seq_groups
)
==
2
assert
len
(
output
.
prefill_seq_groups
)
==
0
# Verify num_curr_seqs are respected.
output
=
scheduler
.
_schedule_swapped
(
budget
,
curr_loras
)
remaining_swapped
=
scheduler
.
swapped
assert
len
(
remaining_swapped
)
==
2
assert
budget
.
num_batched_tokens
==
2
assert
budget
.
num_curr_seqs
==
2
assert
len
(
output
.
decode_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
def
test_schedule_swapped_max_loras
():
def
test_schedule_swapped_max_loras
():
block_size
=
4
block_size
=
4
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_loras
=
1
)
...
...
vllm/core/scheduler.py
View file @
4fdc581f
...
@@ -290,7 +290,7 @@ def scheduler_running_outputs_builder():
...
@@ -290,7 +290,7 @@ def scheduler_running_outputs_builder():
def
scheduled_seq_group_builder
():
def
scheduled_seq_group_builder
():
return
ScheduledSequenceGroup
(
SequenceGroup
(
""
,
[],
-
1
),
return
ScheduledSequenceGroup
(
SequenceGroup
.
__new__
(
SequenceGroup
),
token_chunk_size
=
0
)
token_chunk_size
=
0
)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
...
...
vllm/engine/llm_engine.py
View file @
4fdc581f
...
@@ -647,10 +647,24 @@ class LLMEngine:
...
@@ -647,10 +647,24 @@ class LLMEngine:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
SequenceGroup
:
)
->
Optional
[
SequenceGroup
]
:
"""Add a processed request to the engine's request pool.
"""Add a processed request to the engine's request pool.
return the created sequence group.
return the created sequence group.
"""
"""
if
isinstance
(
params
,
SamplingParams
)
and
params
.
n
>
1
:
ParallelSampleSequenceGroup
.
add_request
(
request_id
,
self
,
params
,
processed_inputs
=
processed_inputs
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
return
None
self
.
_validate_model_inputs
(
processed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
...
@@ -721,7 +735,7 @@ class LLMEngine:
...
@@ -721,7 +735,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
Optional
[
SequenceGroup
]
:
)
->
None
:
...
...
@
overload
@
overload
...
@@ -735,7 +749,7 @@ class LLMEngine:
...
@@ -735,7 +749,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
Optional
[
SequenceGroup
]
:
)
->
None
:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
...
@@ -754,7 +768,7 @@ class LLMEngine:
...
@@ -754,7 +768,7 @@ class LLMEngine:
priority
:
int
=
0
,
priority
:
int
=
0
,
*
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
Optional
[
SequenceGroup
]
:
)
->
None
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
The request is added to the request pool and will be processed by the
...
@@ -798,22 +812,6 @@ class LLMEngine:
...
@@ -798,22 +812,6 @@ class LLMEngine:
>>> # continue the request processing
>>> # continue the request processing
>>> ...
>>> ...
"""
"""
if
isinstance
(
params
,
SamplingParams
)
and
params
.
n
>
1
:
ParallelSampleSequenceGroup
.
add_request
(
request_id
,
self
,
params
,
prompt
=
prompt
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
inputs
=
inputs
,
)
return
None
if
inputs
is
not
None
:
if
inputs
is
not
None
:
prompt
=
inputs
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
assert
prompt
is
not
None
and
params
is
not
None
...
@@ -844,7 +842,7 @@ class LLMEngine:
...
@@ -844,7 +842,7 @@ class LLMEngine:
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
"mm_processor_kwargs"
)
"mm_processor_kwargs"
)
return
self
.
_add_processed_request
(
self
.
_add_processed_request
(
request_id
=
request_id
,
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
processed_inputs
=
processed_inputs
,
params
=
params
,
params
=
params
,
...
...
vllm/engine/output_processor/single_step.py
View file @
4fdc581f
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
List
from
vllm.config
import
SchedulerConfig
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.scheduler
import
Scheduler
...
@@ -6,9 +6,8 @@ from vllm.engine.output_processor.interfaces import (
...
@@ -6,9 +6,8 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor
)
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Sequence
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SequenceGroup
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceGroupOutput
)
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -114,104 +113,22 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -114,104 +113,22 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs
:
SequenceGroupOutput
,
outputs
:
SequenceGroupOutput
,
is_async
:
bool
)
->
None
:
is_async
:
bool
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
n
==
1
:
# only have one output sample
sample
=
outputs
.
samples
[
0
]
sample
=
outputs
.
samples
[
0
]
seq
=
seq_group
.
first_seq
# only have one sequence
if
not
is_async
:
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
not
is_async
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
seq
,
sampling_params
)
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
else
:
seq
,
sampling_params
)
new_char_count
=
0
else
:
self
.
stop_checker
.
maybe_stop_sequence
(
new_char_count
=
0
seq
,
self
.
stop_checker
.
maybe_stop_sequence
(
new_char_count
,
seq
,
sampling_params
,
new_char_count
,
lora_req
=
seq_group
.
lora_request
,
sampling_params
,
)
lora_req
=
seq_group
.
lora_request
,
if
seq
.
is_finished
():
)
for
scheduler
in
self
.
scheduler
:
if
seq
.
is_finished
():
scheduler
.
free_seq
(
seq
)
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
return
# TODO: Add support for async for beam search
assert
not
is_async
# Process samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
parent_child_dict
:
Dict
[
int
,
List
[
SequenceOutput
]]
=
{
parent_seq
.
seq_id
:
[]
for
parent_seq
in
parent_seqs
}
for
sample
in
samples
:
# Guard against a KeyError which can occur if the request was
# aborted while the output was generated
if
(
child_list
:
=
parent_child_dict
.
get
(
sample
.
parent_seq_id
))
is
not
None
:
child_list
.
append
(
sample
)
# List of (child, parent)
child_seqs
:
List
[
Tuple
[
Sequence
,
Sequence
]]
=
[]
# Process the child samples for each parent sequence
for
parent
in
parent_seqs
:
child_samples
:
List
[
SequenceOutput
]
=
parent_child_dict
[
parent
.
seq_id
]
if
len
(
child_samples
)
==
0
:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent
.
status
=
SequenceStatus
.
FINISHED_ABORTED
seq_group
.
remove
(
parent
.
seq_id
)
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
parent
)
continue
# Fork the parent sequence if there are multiple child samples.
for
child_sample
in
child_samples
[:
-
1
]:
new_child_seq_id
:
int
=
next
(
self
.
seq_counter
)
child
=
parent
.
fork
(
new_child_seq_id
)
child
.
append_token_id
(
child_sample
.
output_token
,
child_sample
.
logprobs
)
child_seqs
.
append
((
child
,
parent
))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample
=
child_samples
[
-
1
]
parent
.
append_token_id
(
last_child_sample
.
output_token
,
last_child_sample
.
logprobs
)
child_seqs
.
append
((
parent
,
parent
))
for
seq
,
_
in
child_seqs
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
seq
,
sampling_params
)
else
:
new_char_count
=
0
self
.
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
,
sampling_params
,
lora_req
=
seq_group
.
lora_request
,
)
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for
seq
,
parent
in
child_seqs
:
if
seq
is
not
parent
:
seq_group
.
add
(
seq
)
if
not
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
fork_seq
(
parent
,
seq
)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for
seq
,
parent
in
child_seqs
:
if
seq
is
parent
and
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
return
vllm/sequence.py
View file @
4fdc581f
...
@@ -681,6 +681,7 @@ class SequenceGroup:
...
@@ -681,6 +681,7 @@ class SequenceGroup:
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
seqs
=
seqs
self
.
seqs
=
seqs
self
.
first_seq
=
seqs
[
0
]
self
.
arrival_time
=
arrival_time
self
.
arrival_time
=
arrival_time
self
.
is_single_seq
=
len
(
seqs
)
==
1
self
.
is_single_seq
=
len
(
seqs
)
==
1
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
...
@@ -705,15 +706,11 @@ class SequenceGroup:
...
@@ -705,15 +706,11 @@ class SequenceGroup:
@
property
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
def
prompt
(
self
)
->
Optional
[
str
]:
# All sequences in the group should have the same prompt.
return
self
.
first_seq
.
prompt
# We use the prompt of an arbitrary sequence.
return
self
.
seqs
[
0
].
prompt
@
property
@
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
def
prompt_token_ids
(
self
)
->
List
[
int
]:
# All sequences in the group should have the same prompt.
return
self
.
first_seq
.
prompt_token_ids
# We use the prompt of an arbitrary sequence.
return
self
.
seqs
[
0
].
prompt_token_ids
@
property
@
property
def
encoder_prompt
(
self
)
->
Optional
[
str
]:
def
encoder_prompt
(
self
)
->
Optional
[
str
]:
...
@@ -733,17 +730,11 @@ class SequenceGroup:
...
@@ -733,17 +730,11 @@ class SequenceGroup:
@
property
@
property
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
# All sequences in the group should have the same multi-modal data.
return
self
.
first_seq
.
multi_modal_data
# We use the multi-modal data of an arbitrary sequence.
return
self
.
seqs
[
0
].
multi_modal_data
@
property
@
property
def
mm_processor_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
def
mm_processor_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
# As with multi-modal data, all sequences in the group should have the
return
self
.
first_seq
.
mm_processor_kwargs
# same processor kwargs (i.e., mm_processor_kwargs are optionally
# provided per request; note that are independent of whether the model
# decoder-only or an encoder-decoder).
return
self
.
seqs
[
0
].
mm_processor_kwargs
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
...
@@ -808,7 +799,7 @@ class SequenceGroup:
...
@@ -808,7 +799,7 @@ class SequenceGroup:
# in TPOT, rather than recalculating TTFT (since from the )
# in TPOT, rather than recalculating TTFT (since from the )
# POV of the user, there is simply a long generation delay.
# POV of the user, there is simply a long generation delay.
if
(
self
.
metrics
.
first_token_time
is
None
if
(
self
.
metrics
.
first_token_time
is
None
and
self
.
seqs
[
0
]
.
get_output_len
()
==
1
):
and
self
.
first_seq
.
get_output_len
()
==
1
):
self
.
metrics
.
first_token_time
=
time
self
.
metrics
.
first_token_time
=
time
def
maybe_set_first_scheduled_time
(
self
,
time
:
float
)
->
None
:
def
maybe_set_first_scheduled_time
(
self
,
time
:
float
)
->
None
:
...
@@ -825,18 +816,7 @@ class SequenceGroup:
...
@@ -825,18 +816,7 @@ class SequenceGroup:
def
get_max_num_running_seqs
(
self
)
->
int
:
def
get_max_num_running_seqs
(
self
)
->
int
:
"""The maximum number of sequences running in parallel in the remaining
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
lifetime of the request."""
if
self
.
sampling_params
:
return
0
if
self
.
first_seq
.
is_finished
()
else
1
n
=
self
.
sampling_params
.
n
assert
isinstance
(
n
,
int
)
if
n
>
self
.
num_seqs
():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `n` sequences
# running.
return
n
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return
self
.
num_unfinished_seqs
()
def
get_seqs
(
def
get_seqs
(
self
,
self
,
...
@@ -845,10 +825,7 @@ class SequenceGroup:
...
@@ -845,10 +825,7 @@ class SequenceGroup:
if
status
is
None
:
if
status
is
None
:
return
self
.
seqs
return
self
.
seqs
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
status
==
status
else
[]
return
self
.
seqs
if
self
.
seqs
[
0
].
status
==
status
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
status
==
status
]
def
is_encoder_decoder
(
self
)
->
bool
:
def
is_encoder_decoder
(
self
)
->
bool
:
return
self
.
encoder_seq
is
not
None
return
self
.
encoder_seq
is
not
None
...
@@ -856,29 +833,20 @@ class SequenceGroup:
...
@@ -856,29 +833,20 @@ class SequenceGroup:
def
get_encoder_seq
(
self
)
->
Optional
[
Sequence
]:
def
get_encoder_seq
(
self
)
->
Optional
[
Sequence
]:
return
self
.
encoder_seq
return
self
.
encoder_seq
def
get_unfinished_seqs
(
self
)
->
List
[
Sequence
]:
if
self
.
is_single_seq
:
return
self
.
seqs
if
not
self
.
seqs
[
0
].
is_finished
()
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
not
seq
.
is_finished
()]
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
return
self
.
seqs
if
self
.
seqs
[
0
].
is_finished
()
else
[]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
is_finished
()]
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
for
seq
in
self
.
seq
s
:
seq
=
self
.
first_
seq
if
not
seq
.
is_finished
():
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
num_uncomputed_tokens
=
0
num_uncomputed_tokens
=
0
for
seq
in
self
.
seq
s
:
seq
=
self
.
first_
seq
if
not
seq
.
is_finished
():
if
not
seq
.
is_finished
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
num_uncomputed_tokens
return
num_uncomputed_tokens
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
...
@@ -892,46 +860,14 @@ class SequenceGroup:
...
@@ -892,46 +860,14 @@ class SequenceGroup:
return
len
(
self
.
get_seqs
(
status
))
return
len
(
self
.
get_seqs
(
status
))
def
num_unfinished_seqs
(
self
)
->
int
:
if
self
.
is_single_seq
:
return
1
if
not
self
.
seqs
[
0
].
is_finished
()
else
0
return
len
(
self
.
get_unfinished_seqs
())
def
num_finished_seqs
(
self
)
->
int
:
def
num_finished_seqs
(
self
)
->
int
:
if
self
.
is_single_seq
:
return
1
if
self
.
first_seq
.
is_finished
()
else
0
return
1
if
self
.
seqs
[
0
].
is_finished
()
else
0
return
len
(
self
.
get_finished_seqs
())
def
find
(
self
,
seq_id
:
int
)
->
Sequence
:
if
seq_id
not
in
self
.
seqs_dict
:
raise
ValueError
(
f
"Sequence
{
seq_id
}
not found."
)
return
self
.
seqs_dict
[
seq_id
]
def
add
(
self
,
seq
:
Sequence
)
->
None
:
if
seq
.
seq_id
in
self
.
seqs_dict
:
raise
ValueError
(
f
"Sequence
{
seq
.
seq_id
}
already exists."
)
self
.
seqs_dict
[
seq
.
seq_id
]
=
seq
self
.
seqs
.
append
(
seq
)
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
def
remove
(
self
,
seq_id
:
int
)
->
None
:
seq
=
self
.
seqs_dict
.
pop
(
seq_id
,
None
)
if
seq
is
None
:
raise
ValueError
(
f
"Sequence
{
seq_id
}
not found."
)
self
.
seqs
.
remove
(
seq
)
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
if
self
.
is_single_seq
:
return
self
.
first_seq
.
is_finished
()
return
self
.
seqs
[
0
].
is_finished
()
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
# Every sequence should be in the same stage.
return
self
.
first_seq
.
is_prefill
()
return
self
.
seqs
[
0
].
is_prefill
()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
...
@@ -1455,7 +1391,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
...
@@ -1455,7 +1391,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
for
i
in
range
(
original_params
.
n
):
for
i
in
range
(
original_params
.
n
):
request_id_i
=
f
"
{
request_id
}
_parallel_sample_
{
i
}
"
request_id_i
=
f
"
{
request_id
}
_parallel_sample_
{
i
}
"
group
.
seq_id_to_index
[
request_id_i
]
=
i
group
.
seq_id_to_index
[
request_id_i
]
=
i
seq_group
=
engine
.
add_request
(
seq_group
=
engine
.
_
add_
processed_
request
(
request_id_i
,
request_id_i
,
params
=
params
,
params
=
params
,
**
kwargs
,
**
kwargs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment