Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18de8834
Unverified
Commit
18de8834
authored
Apr 06, 2024
by
SangBin Cho
Committed by
GitHub
Apr 05, 2024
Browse files
[Chunked Prefill][4/n] Chunked prefill scheduler. (#3853)
parent
1d7c940d
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1218 additions
and
183 deletions
+1218
-183
requirements-common.txt
requirements-common.txt
+1
-1
tests/core/test_chunked_prefill_scheduler.py
tests/core/test_chunked_prefill_scheduler.py
+563
-0
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+201
-65
tests/test_sequence.py
tests/test_sequence.py
+55
-3
vllm/config.py
vllm/config.py
+2
-1
vllm/core/policy.py
vllm/core/policy.py
+1
-3
vllm/core/scheduler.py
vllm/core/scheduler.py
+345
-95
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-3
vllm/sequence.py
vllm/sequence.py
+48
-11
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+0
-1
No files found.
requirements-common.txt
View file @
18de8834
...
@@ -11,4 +11,4 @@ uvicorn[standard]
...
@@ -11,4 +11,4 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34
# Requires torch >= 2.1.0
outlines == 0.0.34 # Requires torch >= 2.1.0
\ No newline at end of file
tests/core/test_chunked_prefill_scheduler.py
0 → 100644
View file @
18de8834
This diff is collapsed.
Click to expand it.
tests/core/test_scheduler.py
View file @
18de8834
This diff is collapsed.
Click to expand it.
tests/test_sequence.py
View file @
18de8834
import
time
from
typing
import
Optional
import
pytest
import
pytest
from
vllm.sequence
import
(
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
from
vllm
import
SamplingParams
SequenceOutput
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceData
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
)
def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> SequenceGroup:
    """Build a one-sequence ``SequenceGroup`` around a synthetic prompt.

    The prompt token ids are ``0 .. prompt_length - 1`` and the prompt
    text is those numbers joined by single spaces.

    Args:
        request_id: Request id for the group; it is also converted with
            ``int()`` to serve as the id of the single sequence.
        prompt_length: Number of prompt tokens to generate.
        block_size: Block size handed to ``Sequence``. A falsy value
            (``None`` or ``0``) falls back to ``prompt_length``.
        lora_request: Forwarded unchanged to ``SequenceGroup``.
        use_beam_search: Forwarded to ``SamplingParams``.
        best_of: Forwarded to ``SamplingParams``.

    Returns:
        The newly created ``SequenceGroup`` holding the single sequence.
    """
    effective_block_size = block_size if block_size else prompt_length

    token_ids = list(range(prompt_length))
    prompt_text = " ".join(map(str, token_ids))

    sequence = Sequence(int(request_id), prompt_text, token_ids,
                        effective_block_size)
    sampling_params = SamplingParams(use_beam_search=use_beam_search,
                                     best_of=best_of)
    # The arrival time is taken at construction time of the group.
    return SequenceGroup(request_id, [sequence], sampling_params,
                         time.time(), lora_request)
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -67,6 +96,29 @@ def test_sequence_data_prefill():
...
@@ -67,6 +96,29 @@ def test_sequence_data_prefill():
# append tokens and reset, simulating recompute
# append tokens and reset, simulating recompute
seq_data
.
append_token_id
(
1
,
logprob
=
0.0
)
seq_data
.
append_token_id
(
1
,
logprob
=
0.0
)
seq_data
.
reset_
num_computed_tokens
()
seq_data
.
reset_
state_for_recompute
()
assert
seq_data
.
get_num_uncomputed_tokens
()
==
5
assert
seq_data
.
get_num_uncomputed_tokens
()
==
5
assert
seq_data
.
get_num_computed_tokens
()
==
0
assert
seq_data
.
get_num_computed_tokens
()
==
0
def test_sequence_group_stage():
    """A group stays in prefill until every token has been computed.

    Also checks that resetting the sequences for recompute puts the group
    back into prefill, and that the appended token must then be computed
    as well before prefill ends.
    """
    group = create_dummy_prompt("1", 12)
    assert group.is_prefill() is True

    # 12 prompt tokens fed as 6 + 5 + 1: only the final chunk ends prefill.
    for chunk, expect_prefill in ((6, True), (5, True), (1, False)):
        group.update_num_computed_tokens(chunk)
        assert group.is_prefill() is expect_prefill

    sequences = group.get_seqs()
    assert len(sequences) == 1
    sequences[0].data.append_token_id(1, logprob=0.0)

    # Simulate a recompute: every sequence starts over from prefill.
    for sequence in group.get_seqs():
        sequence.reset_state_for_recompute()
    assert group.is_prefill() is True

    # Now 13 tokens (12 prompt + 1 appended) fed as 5 + 7 + 1.
    for chunk, expect_prefill in ((5, True), (7, True), (1, False)):
        group.update_num_computed_tokens(chunk)
        assert group.is_prefill() is expect_prefill
vllm/config.py
View file @
18de8834
...
@@ -576,7 +576,8 @@ class SchedulerConfig:
...
@@ -576,7 +576,8 @@ class SchedulerConfig:
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
if
self
.
max_num_batched_tokens
<
self
.
max_model_len
:
if
(
self
.
max_num_batched_tokens
<
self
.
max_model_len
and
not
self
.
chunked_prefill_enabled
):
raise
ValueError
(
raise
ValueError
(
f
"max_num_batched_tokens (
{
self
.
max_num_batched_tokens
}
) is "
f
"max_num_batched_tokens (
{
self
.
max_num_batched_tokens
}
) is "
f
"smaller than max_model_len (
{
self
.
max_model_len
}
). "
f
"smaller than max_model_len (
{
self
.
max_model_len
}
). "
...
...
vllm/core/policy.py
View file @
18de8834
...
@@ -38,9 +38,7 @@ class FCFS(Policy):
...
@@ -38,9 +38,7 @@ class FCFS(Policy):
class
PolicyFactory
:
class
PolicyFactory
:
_POLICY_REGISTRY
=
{
_POLICY_REGISTRY
=
{
'fcfs'
:
FCFS
}
'fcfs'
:
FCFS
,
}
@
classmethod
@
classmethod
def
get_policy
(
cls
,
policy_name
:
str
,
**
kwargs
)
->
Policy
:
def
get_policy
(
cls
,
policy_name
:
str
,
**
kwargs
)
->
Policy
:
...
...
vllm/core/scheduler.py
View file @
18de8834
This diff is collapsed.
Click to expand it.
vllm/engine/llm_engine.py
View file @
18de8834
...
@@ -607,11 +607,10 @@ class LLMEngine:
...
@@ -607,11 +607,10 @@ class LLMEngine:
now
=
time
.
time
()
now
=
time
.
time
()
# Update the scheduled sequence groups with the model outputs.
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups
=
scheduler_outputs
.
scheduled_seq_groups
scheduled_seq_groups
=
scheduler_outputs
.
scheduled_seq_groups
for
scheduled_seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
output
):
for
scheduled_seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
output
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
# Free the finished sequence groups.
...
...
vllm/sequence.py
View file @
18de8834
...
@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
...
@@ -69,6 +69,11 @@ class SequenceStatus(enum.Enum):
return
finish_reason
return
finish_reason
class SequenceStage(enum.Enum):
    """Processing stage tracked for a sequence's data."""

    # Stage a sequence starts in, and returns to when reset for recompute.
    PREFILL = enum.auto()
    # Stage entered once no uncomputed tokens remain.
    DECODE = enum.auto()
@
dataclass
@
dataclass
class
RequestMetrics
:
class
RequestMetrics
:
"""Metrics associated with a request.
"""Metrics associated with a request.
...
@@ -115,6 +120,7 @@ class SequenceData:
...
@@ -115,6 +120,7 @@ class SequenceData:
self
.
cumulative_logprob
=
0.0
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
# The number of tokens that are computed (that run against the model).
self
.
_num_computed_tokens
=
0
self
.
_num_computed_tokens
=
0
self
.
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
output_token_ids
.
append
(
token_id
)
self
.
output_token_ids
.
append
(
token_id
)
...
@@ -136,16 +142,22 @@ class SequenceData:
...
@@ -136,16 +142,22 @@ class SequenceData:
"""Return the number of prefill tokens that are already computed."""
"""Return the number of prefill tokens that are already computed."""
return
self
.
_num_computed_tokens
return
self
.
_num_computed_tokens
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
)
->
int
:
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
self
.
_num_computed_tokens
+=
num_new_computed_tokens
self
.
_num_computed_tokens
+=
num_new_computed_tokens
assert
self
.
_num_computed_tokens
<=
self
.
get_len
(),
(
self
.
_num_computed_tokens
,
self
.
get_len
())
# If all tokens are computed, it means it is in decoding phase.
if
self
.
get_num_uncomputed_tokens
()
==
0
:
self
.
_stage
=
SequenceStage
.
DECODE
def
reset_
num_computed_tokens
(
self
)
->
None
:
def
reset_
state_for_recompute
(
self
)
->
None
:
"""Reset the number of computed tokens from this sequence. It is
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted).
the beginning again (e.g., sequence is preempted).
"""
"""
self
.
_num_computed_tokens
=
0
self
.
_num_computed_tokens
=
0
self
.
_stage
=
SequenceStage
.
PREFILL
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
"""Return the number of prefill tokens that are not computed."""
"""Return the number of prefill tokens that are not computed."""
...
@@ -165,6 +177,10 @@ class SequenceData:
...
@@ -165,6 +177,10 @@ class SequenceData:
def
get_output_token_ids
(
self
)
->
int
:
def
get_output_token_ids
(
self
)
->
int
:
return
self
.
output_token_ids
return
self
.
output_token_ids
@property
def stage(self) -> SequenceStage:
    """Return the stage currently recorded in ``_stage`` (read-only)."""
    current_stage = self._stage
    return current_stage
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceData("
return
(
f
"SequenceData("
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
...
@@ -234,7 +250,7 @@ class Sequence:
...
@@ -234,7 +250,7 @@ class Sequence:
def
reset_state_for_recompute
(
self
):
def
reset_state_for_recompute
(
self
):
"""Reset the sequence states for recomputation."""
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_
num_computed_tokens
()
self
.
data
.
reset_
state_for_recompute
()
def
_append_logical_block
(
self
)
->
None
:
def
_append_logical_block
(
self
)
->
None
:
block
=
LogicalTokenBlock
(
block
=
LogicalTokenBlock
(
...
@@ -320,6 +336,23 @@ class Sequence:
...
@@ -320,6 +336,23 @@ class Sequence:
new_seq
.
seq_id
=
new_seq_id
new_seq
.
seq_id
=
new_seq_id
return
new_seq
return
new_seq
def get_num_new_tokens(self) -> int:
    """Get the number of new tokens to be computed.

    This method takes no arguments and applies no token budget.

    Returns:
        The number of new tokens to be computed: 1 when the sequence
        data is in the decode stage, otherwise the number of tokens
        that have not been computed yet.
    """
    if self.data.stage == SequenceStage.DECODE:
        return 1
    return self.data.get_num_uncomputed_tokens()
def is_prefill(self) -> bool:
    """Return True when this sequence's data is in the PREFILL stage."""
    current_stage = self.data.stage
    return current_stage == SequenceStage.PREFILL
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"Sequence(seq_id=
{
self
.
seq_id
}
, "
return
(
f
"Sequence(seq_id=
{
self
.
seq_id
}
, "
f
"status=
{
self
.
status
.
name
}
, "
f
"status=
{
self
.
status
.
name
}
, "
...
@@ -461,14 +494,14 @@ class SequenceGroup:
...
@@ -461,14 +494,14 @@ class SequenceGroup:
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
for
seq
in
self
.
seqs_dict
.
values
():
for
seq
in
self
.
seqs_dict
.
values
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
# All sequences in the group should have the same prompt, so the
num_uncomputed_tokens
=
0
# number of unfinished prefill tokens are the same across all
for
seq
in
self
.
get_seqs
():
# sequences.
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
list
(
return
num_uncomputed_tokens
self
.
seqs_dict
.
values
())[
0
].
data
.
get_num_uncomputed_tokens
()
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
return
len
(
self
.
get_seqs
(
status
))
return
len
(
self
.
get_seqs
(
status
))
...
@@ -497,6 +530,10 @@ class SequenceGroup:
...
@@ -497,6 +530,10 @@ class SequenceGroup:
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
def is_prefill(self) -> bool:
    """Return whether the group is in prefill, judged by its first sequence.

    All sequences are expected to be in the same stage, so only the
    first one returned by ``get_seqs()`` is consulted.
    """
    first_seq = self.get_seqs()[0]
    return first_seq.is_prefill()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
...
@@ -513,8 +550,8 @@ class SequenceGroupMetadata:
...
@@ -513,8 +550,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
block_tables: The block tables. (Seq id -> list of physical block
numbers)
numbers)
token_chunk_size: The number of tokens to be processed
. None if
token_chunk_size: The number of tokens to be processed
(per sequence).
chunking is not required.
None if
chunking is not required.
state: Internal state tied to this sequence group.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
lora_request: LoRA request.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
...
...
vllm/worker/model_runner.py
View file @
18de8834
...
@@ -222,7 +222,6 @@ class ModelRunner:
...
@@ -222,7 +222,6 @@ class ModelRunner:
# NOTE(woosuk): Here we assume that the first token in the prompt
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prefill_end
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
prefill_end
)))
lora_id
=
seq_group_metadata
.
lora_int_id
lora_id
=
seq_group_metadata
.
lora_int_id
if
lora_id
>
0
:
if
lora_id
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment