Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b51c1cc9
Unverified
Commit
b51c1cc9
authored
Mar 29, 2024
by
SangBin Cho
Committed by
GitHub
Mar 28, 2024
Browse files
[2/N] Chunked prefill data update (#3538)
parent
ce567a29
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
272 additions
and
76 deletions
+272
-76
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+13
-1
tests/conftest.py
tests/conftest.py
+4
-0
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+13
-9
tests/test_sequence.py
tests/test_sequence.py
+23
-1
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+20
-17
vllm/config.py
vllm/config.py
+4
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+81
-27
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+15
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+14
-7
vllm/sequence.py
vllm/sequence.py
+55
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+30
-9
No files found.
benchmarks/benchmark_latency.py
View file @
b51c1cc9
...
@@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
...
@@ -26,7 +26,9 @@ def main(args: argparse.Namespace):
kv_cache_dtype
=
args
.
kv_cache_dtype
,
kv_cache_dtype
=
args
.
kv_cache_dtype
,
device
=
args
.
device
,
device
=
args
.
device
,
ray_workers_use_nsight
=
args
.
ray_workers_use_nsight
,
ray_workers_use_nsight
=
args
.
ray_workers_use_nsight
,
download_dir
=
args
.
download_dir
)
enable_chunked_prefill
=
args
.
enable_chunked_prefill
,
download_dir
=
args
.
download_dir
,
block_size
=
args
.
block_size
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
n
=
args
.
n
,
n
=
args
.
n
,
...
@@ -145,6 +147,16 @@ if __name__ == '__main__':
...
@@ -145,6 +147,16 @@ if __name__ == '__main__':
default
=
"cuda"
,
default
=
"cuda"
,
choices
=
[
"cuda"
],
choices
=
[
"cuda"
],
help
=
'device type for vLLM execution, supporting CUDA only currently.'
)
help
=
'device type for vLLM execution, supporting CUDA only currently.'
)
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
default
=
16
,
help
=
'block size of key/value cache'
)
parser
.
add_argument
(
'--enable-chunked-prefill'
,
type
=
bool
,
default
=
False
,
help
=
'If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--ray-workers-use-nsight"
,
"--ray-workers-use-nsight"
,
action
=
'store_true'
,
action
=
'store_true'
,
...
...
tests/conftest.py
View file @
b51c1cc9
...
@@ -256,6 +256,8 @@ class VllmRunner:
...
@@ -256,6 +256,8 @@ class VllmRunner:
dtype
:
str
=
"half"
,
dtype
:
str
=
"half"
,
disable_log_stats
:
bool
=
True
,
disable_log_stats
:
bool
=
True
,
tensor_parallel_size
:
int
=
1
,
tensor_parallel_size
:
int
=
1
,
block_size
:
int
=
16
,
enable_chunked_prefill
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
self
.
model
=
LLM
(
self
.
model
=
LLM
(
...
@@ -266,6 +268,8 @@ class VllmRunner:
...
@@ -266,6 +268,8 @@ class VllmRunner:
swap_space
=
0
,
swap_space
=
0
,
disable_log_stats
=
disable_log_stats
,
disable_log_stats
=
disable_log_stats
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
block_size
=
block_size
,
enable_chunked_prefill
=
enable_chunked_prefill
,
**
kwargs
,
**
kwargs
,
)
)
...
...
tests/core/test_scheduler.py
View file @
b51c1cc9
...
@@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup
...
@@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup
from
.utils
import
create_dummy_prompt
from
.utils
import
create_dummy_prompt
def
get_sequence_groups
(
scheduler_output
):
return
[
s
.
seq_group
for
s
in
scheduler_output
.
scheduled_seq_groups
]
def
test_scheduler_add_seq_group
():
def
test_scheduler_add_seq_group
():
block_size
=
4
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
...
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
...
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_cpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
cache_config
.
num_gpu_blocks
=
8
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
None
)
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
None
)
running
:
List
[
SequenceGroup
]
=
[]
# Add seq groups to scheduler.
# Add seq groups to scheduler.
running
:
List
[
SequenceGroup
]
=
[]
for
i
in
range
(
num_seq_group
):
for
i
in
range
(
num_seq_group
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
block_size
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
block_size
)
scheduler
.
add_seq_group
(
seq_group
)
scheduler
.
add_seq_group
(
seq_group
)
...
@@ -68,7 +72,7 @@ def test_scheduler_schedule_simple():
...
@@ -68,7 +72,7 @@ def test_scheduler_schedule_simple():
# Schedule seq groups prompts.
# Schedule seq groups prompts.
num_tokens
=
block_size
*
num_seq_group
num_tokens
=
block_size
*
num_seq_group
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq
_groups
)
==
set
(
running
)
assert
set
(
get_sequence
_groups
(
out
)
)
==
set
(
running
)
assert
out
.
num_batched_tokens
==
num_tokens
assert
out
.
num_batched_tokens
==
num_tokens
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
...
@@ -76,7 +80,7 @@ def test_scheduler_schedule_simple():
...
@@ -76,7 +80,7 @@ def test_scheduler_schedule_simple():
# Schedule seq groups generation.
# Schedule seq groups generation.
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq
_groups
)
==
set
(
running
)
assert
set
(
get_sequence
_groups
(
out
)
)
==
set
(
running
)
assert
out
.
num_batched_tokens
==
num_seq_group
assert
out
.
num_batched_tokens
==
num_seq_group
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
...
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
...
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
# Schedule seq groups prompts.
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq
_groups
==
[
seq_group_a
,
seq_group_b
]
assert
get_sequence
_groups
(
out
)
==
[
seq_group_a
,
seq_group_b
]
assert
out
.
num_batched_tokens
==
block_size
*
2
# seq_a and seq_b
assert
out
.
num_batched_tokens
==
block_size
*
2
# seq_a and seq_b
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
...
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
...
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups generation and preempt seq group b.
# Schedule seq groups generation and preempt seq group b.
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq
_groups
==
[
seq_group_a
]
assert
get_sequence
_groups
(
out
)
==
[
seq_group_a
]
assert
out
.
num_batched_tokens
==
1
assert
out
.
num_batched_tokens
==
1
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
...
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
...
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
scheduler
.
abort_seq_group
(
"1"
)
scheduler
.
abort_seq_group
(
"1"
)
seq_group_meta
,
out
=
scheduler
.
schedule
()
seq_group_meta
,
out
=
scheduler
.
schedule
()
assert
out
.
scheduled_seq
_groups
==
[
seq_group_b
]
assert
get_sequence
_groups
(
out
)
==
[
seq_group_b
]
assert
out
.
num_batched_tokens
==
5
# 4 prompt + 1 generation.
assert
out
.
num_batched_tokens
==
5
# 4 prompt + 1 generation.
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
assert
(
not
out
.
blocks_to_copy
and
not
out
.
blocks_to_swap_in
and
not
out
.
blocks_to_swap_out
)
and
not
out
.
blocks_to_swap_out
)
...
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
...
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
# Schedule seq groups prompts.
# Schedule seq groups prompts.
_
,
out
=
scheduler
.
schedule
()
_
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq
_groups
)
==
set
([
all_seq_groups
[
0
]])
assert
set
(
get_sequence
_groups
(
out
)
)
==
set
([
all_seq_groups
[
0
]])
# Schedule seq groups generation.
# Schedule seq groups generation.
_
,
out
=
scheduler
.
schedule
()
_
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq
_groups
)
==
set
([
all_seq_groups
[
0
]])
assert
set
(
get_sequence
_groups
(
out
)
)
==
set
([
all_seq_groups
[
0
]])
# Append 2 more seq group
# Append 2 more seq group
scheduler
.
add_seq_group
(
all_seq_groups
[
1
])
scheduler
.
add_seq_group
(
all_seq_groups
[
1
])
...
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
...
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
# Only 1 seq group should be scheduled since max_seq_group is 2
# Only 1 seq group should be scheduled since max_seq_group is 2
# and one is prompting.
# and one is prompting.
_
,
out
=
scheduler
.
schedule
()
_
,
out
=
scheduler
.
schedule
()
assert
set
(
out
.
scheduled_seq
_groups
)
==
set
([
all_seq_groups
[
1
]])
assert
set
(
get_sequence
_groups
(
out
)
)
==
set
([
all_seq_groups
[
1
]])
def
test_scheduler_delay_factor
():
def
test_scheduler_delay_factor
():
...
...
tests/test_sequence.py
View file @
b51c1cc9
import
pytest
import
pytest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
from
vllm.sequence
import
(
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SequenceOutput
)
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
...
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
sampler_output3
=
SamplerOutput
(
outputs
=
sample_outputs
[:
-
1
])
sampler_output3
=
SamplerOutput
(
outputs
=
sample_outputs
[:
-
1
])
assert
sampler_output1
==
sampler_output2
assert
sampler_output1
==
sampler_output2
assert
sampler_output1
!=
sampler_output3
assert
sampler_output1
!=
sampler_output3
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
prompt_token_ids
=
[
1
,
2
,
3
,
4
])
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
seq_data
.
update_num_computed_tokens
(
2
)
assert
seq_data
.
get_num_uncomputed_tokens
()
==
2
assert
seq_data
.
get_num_computed_tokens
()
==
2
# advance by 1
seq_data
.
update_num_computed_tokens
(
1
)
assert
seq_data
.
get_num_uncomputed_tokens
()
==
1
assert
seq_data
.
get_num_computed_tokens
()
==
3
# append tokens and reset, simulating recompute
seq_data
.
append_token_id
(
1
,
logprob
=
0.0
)
seq_data
.
reset_num_computed_tokens
()
assert
seq_data
.
get_num_uncomputed_tokens
()
==
5
assert
seq_data
.
get_num_computed_tokens
()
==
0
tests/worker/test_model_runner.py
View file @
b51c1cc9
...
@@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
...
@@ -18,15 +18,16 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
seq_data
=
list
(
range
(
prompt_len
))
seq_data
=
SequenceData
(
list
(
range
(
prompt_len
)))
seq_group_metadata_list
.
append
(
seq_group_metadata
=
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
seq_data
)
},
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
block_tables
=
block_tables
,
))
)
assert
seq_group_metadata
.
token_chunk_size
==
seq_data
.
get_len
()
seq_group_metadata_list
.
append
(
seq_group_metadata
)
expected_selected_token_indices
=
[]
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
...
@@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size):
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
seq_data
=
list
(
range
(
prompt_len
))
seq_data
=
list
(
range
(
prompt_len
))
seq_
group_metadata_list
.
append
(
seq_
data
=
SequenceData
(
seq_data
)
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
is_prompt
=
False
,
seq_data
=
{
0
:
SequenceData
(
seq_data
)
},
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
=
(
input_tokens
,
input_positions
,
attn_metadata
,
_
,
_
,
_
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
...
...
vllm/config.py
View file @
b51c1cc9
...
@@ -533,6 +533,8 @@ class SchedulerConfig:
...
@@ -533,6 +533,8 @@ class SchedulerConfig:
delay_factor: Apply a delay (of delay factor multiplied by previous
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -542,6 +544,7 @@ class SchedulerConfig:
...
@@ -542,6 +544,7 @@ class SchedulerConfig:
max_model_len
:
int
,
max_model_len
:
int
,
use_v2_block_manager
:
bool
=
False
,
use_v2_block_manager
:
bool
=
False
,
delay_factor
:
float
=
0.0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
...
@@ -553,6 +556,7 @@ class SchedulerConfig:
...
@@ -553,6 +556,7 @@ class SchedulerConfig:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
delay_factor
=
delay_factor
self
.
delay_factor
=
delay_factor
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/scheduler.py
View file @
b51c1cc9
import
enum
import
enum
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
...
@@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum):
...
@@ -27,11 +28,24 @@ class PreemptionMode(enum.Enum):
RECOMPUTE
=
enum
.
auto
()
RECOMPUTE
=
enum
.
auto
()
# seq_group: SequenceGroup to schedule.
# token_chunk_size: The number of prefill tokens to be processed in the next
# step.
@
dataclass
class
ScheduledSequenceGroup
:
# A sequence group that's scheduled.
seq_group
:
SequenceGroup
# The total chunk size (number of tokens) to process for next iteration.
# 1 for decoding. Same as prompt tokens for prefill, but if prefill is
# chunked, it can be smaller than that.
token_chunk_size
:
int
class
SchedulerOutputs
:
class
SchedulerOutputs
:
def
__init__
(
def
__init__
(
self
,
self
,
scheduled_seq_groups
:
Iterable
[
SequenceGroup
],
scheduled_seq_groups
:
Iterable
[
Scheduled
SequenceGroup
],
prompt_run
:
bool
,
prompt_run
:
bool
,
num_batched_tokens
:
int
,
num_batched_tokens
:
int
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
...
@@ -39,17 +53,41 @@ class SchedulerOutputs:
...
@@ -39,17 +53,41 @@ class SchedulerOutputs:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
ignored_seq_groups
:
List
[
SequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
)
->
None
:
)
->
None
:
self
.
scheduled_seq_groups
=
scheduled_seq_groups
"""A list of sequence groups to be scheduled as a single batch.
self
.
prompt_run
=
prompt_run
self
.
num_batched_tokens
=
num_batched_tokens
Args:
self
.
blocks_to_swap_in
=
blocks_to_swap_in
scheduled_seq_groups: A tuple of scheduled sequence group and its
self
.
blocks_to_swap_out
=
blocks_to_swap_out
token chunk size.
self
.
blocks_to_copy
=
blocks_to_copy
prompt_run: True if all sequence groups are in prefill phase.
If False, all sequence groups are in decoding phase.
num_batched_tokens: Total number of batched tokens.
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
number.
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
number.
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
ignored_seq_groups: Sequence groups that are going to be ignored.
"""
# A tuple of scheduled sequence group and its chunk size.
self
.
scheduled_seq_groups
:
ScheduledSequenceGroup
=
scheduled_seq_groups
# True if all sequence groups are in prefill phase. If False, all
# sequence groups are in decoding phase.
self
.
prompt_run
:
bool
=
prompt_run
# Total number of batched tokens.
self
.
num_batched_tokens
:
int
=
num_batched_tokens
# Blocks to swap in. Dict of CPU -> GPU block number.
self
.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
blocks_to_swap_in
# Blocks to swap out. Dict of GPU -> CPU block number.
self
.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
blocks_to_swap_out
# Blocks to copy. Source to a list of dest blocks.
self
.
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
=
blocks_to_copy
# Sequence groups that are going to be ignored.
self
.
ignored_seq_groups
:
List
[
SequenceGroup
]
=
ignored_seq_groups
# Swap in and swap out should never happen at the same time.
# Swap in and swap out should never happen at the same time.
assert
not
(
blocks_to_swap_in
and
blocks_to_swap_out
)
assert
not
(
blocks_to_swap_in
and
blocks_to_swap_out
)
self
.
ignored_seq_groups
=
ignored_seq_groups
self
.
num_loras
=
len
(
self
.
lora_requests
)
self
.
num_loras
:
int
=
len
(
self
.
lora_requests
)
if
self
.
num_loras
>
0
:
if
self
.
num_loras
>
0
:
self
.
_sort_by_lora_ids
()
self
.
_sort_by_lora_ids
()
...
@@ -59,13 +97,13 @@ class SchedulerOutputs:
...
@@ -59,13 +97,13 @@ class SchedulerOutputs:
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
def
_sort_by_lora_ids
(
self
)
->
bool
:
def
_sort_by_lora_ids
(
self
)
->
bool
:
self
.
scheduled_seq_groups
=
sorted
(
self
.
scheduled_seq_groups
,
self
.
scheduled_seq_groups
=
sorted
(
key
=
lambda
g
:
self
.
scheduled_seq_groups
,
(
g
.
lora_int_id
,
g
.
request_id
))
key
=
lambda
g
:
(
g
.
seq_group
.
lora_int_id
,
g
.
seq_group
.
request_id
))
@
property
@
property
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
return
{
g
.
lora_request
for
g
in
self
.
scheduled_seq_groups
}
return
{
g
.
seq_group
.
lora_request
for
g
in
self
.
scheduled_seq_groups
}
class
Scheduler
:
class
Scheduler
:
...
@@ -198,11 +236,13 @@ class Scheduler:
...
@@ -198,11 +236,13 @@ class Scheduler:
assert
len
(
waiting_seqs
)
==
1
,
(
assert
len
(
waiting_seqs
)
==
1
,
(
"Waiting sequence group should have only one prompt "
"Waiting sequence group should have only one prompt "
"sequence."
)
"sequence."
)
num_prompt_tokens
=
waiting_seqs
[
0
].
get_len
()
# get_len includes output tokens if the request has been
if
num_prompt_tokens
>
self
.
prompt_limit
:
# preempted.
num_prefill_tokens
=
waiting_seqs
[
0
].
get_len
()
if
num_prefill_tokens
>
self
.
prompt_limit
:
logger
.
warning
(
logger
.
warning
(
f
"Input prompt (
{
num_pr
ompt
_tokens
}
tokens) is too
long
"
f
"Input prompt (
{
num_pr
efill
_tokens
}
tokens) is too "
f
" and exceeds limit of
{
self
.
prompt_limit
}
"
)
f
"
long
and exceeds limit of
{
self
.
prompt_limit
}
"
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
...
@@ -215,8 +255,8 @@ class Scheduler:
...
@@ -215,8 +255,8 @@ class Scheduler:
break
break
elif
can_allocate
==
AllocStatus
.
NEVER
:
elif
can_allocate
==
AllocStatus
.
NEVER
:
logger
.
warning
(
logger
.
warning
(
f
"Input prompt (
{
num_pr
ompt
_tokens
}
tokens) is too
long
"
f
"Input prompt (
{
num_pr
efill
_tokens
}
tokens) is too "
f
" and exceeds the capacity of block_manager"
)
f
"
long
and exceeds the capacity of block_manager"
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
...
@@ -235,7 +275,7 @@ class Scheduler:
...
@@ -235,7 +275,7 @@ class Scheduler:
continue
continue
# If the number of batched tokens exceeds the limit, stop.
# If the number of batched tokens exceeds the limit, stop.
num_batched_tokens
+=
num_pr
ompt
_tokens
num_batched_tokens
+=
num_pr
efill
_tokens
if
(
num_batched_tokens
>
if
(
num_batched_tokens
>
self
.
scheduler_config
.
max_num_batched_tokens
):
self
.
scheduler_config
.
max_num_batched_tokens
):
break
break
...
@@ -253,8 +293,10 @@ class Scheduler:
...
@@ -253,8 +293,10 @@ class Scheduler:
self
.
_allocate
(
seq_group
)
self
.
_allocate
(
seq_group
)
self
.
running
.
append
(
seq_group
)
self
.
running
.
append
(
seq_group
)
num_curr_seqs
+=
num_new_seqs
num_curr_seqs
+=
num_new_seqs
scheduled
.
append
(
seq_group
)
scheduled
.
append
(
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
token_chunk_size
=
num_prefill_tokens
))
self
.
waiting
.
extendleft
(
leftover_waiting_sequences
)
self
.
waiting
.
extendleft
(
leftover_waiting_sequences
)
if
scheduled
or
ignored_seq_groups
:
if
scheduled
or
ignored_seq_groups
:
...
@@ -352,7 +394,11 @@ class Scheduler:
...
@@ -352,7 +394,11 @@ class Scheduler:
for
seq_group
in
self
.
running
)
for
seq_group
in
self
.
running
)
scheduler_outputs
=
SchedulerOutputs
(
scheduler_outputs
=
SchedulerOutputs
(
scheduled_seq_groups
=
self
.
running
,
scheduled_seq_groups
=
[
ScheduledSequenceGroup
(
seq_group
=
running_group
,
token_chunk_size
=
1
)
for
running_group
in
self
.
running
],
prompt_run
=
False
,
prompt_run
=
False
,
num_batched_tokens
=
num_batched_tokens
,
num_batched_tokens
=
num_batched_tokens
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
@@ -371,10 +417,14 @@ class Scheduler:
...
@@ -371,10 +417,14 @@ class Scheduler:
# Create input data structures.
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group
.
maybe_set_first_scheduled_time
(
now
)
# seq_id -> SequenceData
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
# seq_id -> physical block numbers
block_tables
:
Dict
[
int
,
List
[
int
]]
=
{}
block_tables
:
Dict
[
int
,
List
[
int
]]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
...
@@ -393,6 +443,7 @@ class Scheduler:
...
@@ -393,6 +443,7 @@ class Scheduler:
seq_data
=
seq_data
,
seq_data
=
seq_data
,
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
state
=
seq_group
.
state
,
state
=
seq_group
.
state
,
...
@@ -409,8 +460,9 @@ class Scheduler:
...
@@ -409,8 +460,9 @@ class Scheduler:
# batch will have been computed before the next scheduling invocation.
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
# will crash the vLLM instance / will not retry.
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
self
.
block_manager
.
mark_blocks_as_computed
(
seq_group
)
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
return
seq_group_metadata_list
,
scheduler_outputs
return
seq_group_metadata_list
,
scheduler_outputs
...
@@ -418,6 +470,7 @@ class Scheduler:
...
@@ -418,6 +470,7 @@ class Scheduler:
self
.
block_manager
.
fork
(
parent_seq
,
child_seq
)
self
.
block_manager
.
fork
(
parent_seq
,
child_seq
)
def
free_seq
(
self
,
seq
:
Sequence
)
->
None
:
def
free_seq
(
self
,
seq
:
Sequence
)
->
None
:
"""Free a sequence from a block table."""
self
.
block_manager
.
free
(
seq
)
self
.
block_manager
.
free
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
def
free_finished_seq_groups
(
self
)
->
None
:
...
@@ -480,7 +533,8 @@ class Scheduler:
...
@@ -480,7 +533,8 @@ class Scheduler:
assert
len
(
seqs
)
==
1
assert
len
(
seqs
)
==
1
for
seq
in
seqs
:
for
seq
in
seqs
:
seq
.
status
=
SequenceStatus
.
WAITING
seq
.
status
=
SequenceStatus
.
WAITING
self
.
block_manager
.
free
(
seq
)
self
.
free_seq
(
seq
)
seq
.
reset_state_for_recompute
()
# NOTE: For FCFS, we insert the preempted sequence group to the front
# NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue.
# of the waiting queue.
self
.
waiting
.
appendleft
(
seq_group
)
self
.
waiting
.
appendleft
(
seq_group
)
...
...
vllm/engine/arg_utils.py
View file @
b51c1cc9
...
@@ -62,6 +62,7 @@ class EngineArgs:
...
@@ -62,6 +62,7 @@ class EngineArgs:
image_input_shape
:
Optional
[
str
]
=
None
image_input_shape
:
Optional
[
str
]
=
None
image_feature_size
:
Optional
[
int
]
=
None
image_feature_size
:
Optional
[
int
]
=
None
scheduler_delay_factor
:
float
=
0.0
scheduler_delay_factor
:
float
=
0.0
enable_chunked_prefill
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
...
@@ -356,6 +357,12 @@ class EngineArgs:
...
@@ -356,6 +357,12 @@ class EngineArgs:
default
=
EngineArgs
.
scheduler_delay_factor
,
default
=
EngineArgs
.
scheduler_delay_factor
,
help
=
'Apply a delay (of delay factor multiplied by previous'
help
=
'Apply a delay (of delay factor multiplied by previous'
'prompt latency) before scheduling next prompt.'
)
'prompt latency) before scheduling next prompt.'
)
parser
.
add_argument
(
'--enable-chunked-prefill'
,
type
=
bool
,
default
=
False
,
help
=
'If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens'
)
return
parser
return
parser
@
classmethod
@
classmethod
...
@@ -394,11 +401,14 @@ class EngineArgs:
...
@@ -394,11 +401,14 @@ class EngineArgs:
self
.
tokenizer_pool_type
,
self
.
tokenizer_pool_type
,
self
.
tokenizer_pool_extra_config
,
self
.
tokenizer_pool_extra_config
,
),
self
.
ray_workers_use_nsight
)
),
self
.
ray_workers_use_nsight
)
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
self
.
max_num_seqs
,
self
.
max_num_seqs
,
model_config
.
max_model_len
,
model_config
.
max_model_len
,
self
.
use_v2_block_manager
,
self
.
use_v2_block_manager
,
self
.
scheduler_delay_factor
)
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
max_loras
=
self
.
max_loras
,
...
...
vllm/engine/llm_engine.py
View file @
b51c1cc9
...
@@ -553,7 +553,10 @@ class LLMEngine:
...
@@ -553,7 +553,10 @@ class LLMEngine:
# Update the scheduled sequence groups with the model outputs.
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups
=
scheduler_outputs
.
scheduled_seq_groups
scheduled_seq_groups
=
scheduler_outputs
.
scheduled_seq_groups
for
seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
output
):
for
scheduled_seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
output
):
seq_group
=
scheduled_seq_group
.
seq_group
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
# Free the finished sequence groups.
...
@@ -561,7 +564,8 @@ class LLMEngine:
...
@@ -561,7 +564,8 @@ class LLMEngine:
# Create the outputs.
# Create the outputs.
request_outputs
:
List
[
RequestOutput
]
=
[]
request_outputs
:
List
[
RequestOutput
]
=
[]
for
seq_group
in
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
...
@@ -676,17 +680,20 @@ class LLMEngine:
...
@@ -676,17 +680,20 @@ class LLMEngine:
# Number of Tokens.
# Number of Tokens.
if
prompt_run
:
if
prompt_run
:
num_prompt_tokens
=
sum
(
num_prompt_tokens
=
sum
(
len
(
seq_group
.
prompt_token_ids
)
len
(
scheduled_seq_group
.
seq_group
.
prompt_token_ids
)
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
num_generation_tokens
=
sum
(
num_generation_tokens
=
sum
(
seq_group
.
num_seqs
()
scheduled_seq_group
.
seq_group
.
num_seqs
()
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
else
:
else
:
num_generation_tokens
=
scheduler_outputs
.
num_batched_tokens
num_generation_tokens
=
scheduler_outputs
.
num_batched_tokens
# Latency Timings.
# Latency Timings.
time_last_iters
=
[]
time_last_iters
=
[]
for
seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
# Time since last token.
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters
.
append
(
seq_group
.
get_last_latency
(
now
))
time_last_iters
.
append
(
seq_group
.
get_last_latency
(
now
))
...
...
vllm/sequence.py
View file @
b51c1cc9
...
@@ -113,6 +113,8 @@ class SequenceData:
...
@@ -113,6 +113,8 @@ class SequenceData:
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
self
.
output_token_ids
=
output_token_ids
self
.
output_token_ids
=
output_token_ids
self
.
cumulative_logprob
=
0.0
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
self
.
_num_computed_tokens
=
0
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
output_token_ids
.
append
(
token_id
)
self
.
output_token_ids
.
append
(
token_id
)
...
@@ -130,6 +132,28 @@ class SequenceData:
...
@@ -130,6 +132,28 @@ class SequenceData:
def
get_token_ids
(
self
)
->
List
[
int
]:
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
prompt_token_ids
+
self
.
output_token_ids
return
self
.
prompt_token_ids
+
self
.
output_token_ids
def
get_num_computed_tokens
(
self
)
->
int
:
"""Return the number of prefill tokens that are already computed."""
return
self
.
_num_computed_tokens
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
)
->
int
:
"""Update number of tokens computed so far."""
self
.
_num_computed_tokens
+=
num_new_computed_tokens
def
reset_num_computed_tokens
(
self
)
->
None
:
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
the beginning again (e.g., sequence is preempted).
"""
self
.
_num_computed_tokens
=
0
def
get_num_uncomputed_tokens
(
self
)
->
int
:
"""Return the number of prefill tokens that are not computed."""
# we use `get_len()` which includes prompt_len + output_len instead
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
return
self
.
get_len
()
-
self
.
get_num_computed_tokens
()
def
get_last_token_id
(
self
)
->
int
:
def
get_last_token_id
(
self
)
->
int
:
if
not
self
.
output_token_ids
:
if
not
self
.
output_token_ids
:
return
self
.
prompt_token_ids
[
-
1
]
return
self
.
prompt_token_ids
[
-
1
]
...
@@ -208,6 +232,10 @@ class Sequence:
...
@@ -208,6 +232,10 @@ class Sequence:
def
num_hashed_tokens_of_block
(
self
,
logical_idx
:
int
):
def
num_hashed_tokens_of_block
(
self
,
logical_idx
:
int
):
return
logical_idx
*
self
.
block_size
+
self
.
block_size
return
logical_idx
*
self
.
block_size
+
self
.
block_size
def
reset_state_for_recompute
(
self
):
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_num_computed_tokens
()
def
_append_logical_block
(
self
)
->
None
:
def
_append_logical_block
(
self
)
->
None
:
block
=
LogicalTokenBlock
(
block
=
LogicalTokenBlock
(
block_number
=
len
(
self
.
logical_token_blocks
),
block_number
=
len
(
self
.
logical_token_blocks
),
...
@@ -430,6 +458,18 @@ class SequenceGroup:
...
@@ -430,6 +458,18 @@ class SequenceGroup:
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
return
[
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
seq
.
is_finished
()]
return
[
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
seq
.
is_finished
()]
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
for
seq
in
self
.
seqs_dict
.
values
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
# All sequences in the group should have the same prompt, so the
# number of unfinished prefill tokens is the same across all
# sequences.
return
list
(
self
.
seqs_dict
.
values
())[
0
].
data
.
get_num_uncomputed_tokens
()
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
def
num_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
)
->
int
:
return
len
(
self
.
get_seqs
(
status
))
return
len
(
self
.
get_seqs
(
status
))
...
@@ -473,6 +513,8 @@ class SequenceGroupMetadata:
...
@@ -473,6 +513,8 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
block_tables: The block tables. (Seq id -> list of physical block
numbers)
numbers)
token_chunk_size: The number of tokens to be processed. None if
chunking is not required.
state: Internal state tied to this sequence group.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
lora_request: LoRA request.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
...
@@ -485,6 +527,7 @@ class SequenceGroupMetadata:
...
@@ -485,6 +527,7 @@ class SequenceGroupMetadata:
seq_data
:
Dict
[
int
,
SequenceData
],
seq_data
:
Dict
[
int
,
SequenceData
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
block_tables
:
Dict
[
int
,
List
[
int
]],
block_tables
:
Dict
[
int
,
List
[
int
]],
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
...
@@ -499,11 +542,23 @@ class SequenceGroupMetadata:
...
@@ -499,11 +542,23 @@ class SequenceGroupMetadata:
self
.
computed_block_nums
=
computed_block_nums
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
_token_chunk_size
=
token_chunk_size
if
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
self
.
_token_chunk_size
=
list
(
seq_data
.
values
())[
0
].
get_len
()
else
:
self
.
_token_chunk_size
=
1
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
@
property
def
token_chunk_size
(
self
)
->
int
:
"""Return the number of tokens to be processed (chunk size)."""
return
self
.
_token_chunk_size
class
SequenceOutput
:
class
SequenceOutput
:
"""The model output associated with a sequence.
"""The model output associated with a sequence.
...
...
vllm/worker/model_runner.py
View file @
b51c1cc9
...
@@ -150,39 +150,58 @@ class ModelRunner:
...
@@ -150,39 +150,58 @@ class ModelRunner:
subquery_lens
:
List
[
int
]
=
[]
subquery_lens
:
List
[
int
]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_id
=
seq_ids
[
0
]
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
computed_block_nums
is
not
None
):
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"now."
)
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
# We should use get_len here because in case of preemption
# it contains output tokens.
prefill_end
=
min
(
seq_data
.
get_len
(),
computed_len
+
token_chunk_size
)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_len
=
len
(
prompt_tokens
)
prompt_len
=
len
(
prompt_tokens
)
# Right now, prefill_end is always the same as the length of the
# sequence. However, this assumption may change once chunked prefill
# is introduced.
assert
prefill_end
==
seq_data
.
get_len
()
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
computed_len
=
0
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
computed_block_nums
is
not
None
and
len
(
if
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
:
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
:
# Prefix is not supported with sliding_window
# Prefix is not supported with sliding_window
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
prefix_block_tables
.
append
(
computed_block_nums
)
context_len
=
computed_len
else
:
else
:
prefix_block_tables
.
append
([])
prefix_block_tables
.
append
([])
context_len
=
0
# Right now, the prefill start is always 0. However, this
# assumption may change once chunked prefill is introduced.
assert
computed_len
==
0
# actual prompt lens
# actual prompt lens
context_lens
.
append
(
context_len
)
context_lens
.
append
(
computed_len
)
subquery_lens
.
append
(
prompt_len
-
computed_len
)
subquery_lens
.
append
(
prompt_len
-
computed_len
)
input_tokens
.
extend
(
prompt_tokens
)
input_tokens
.
extend
(
prompt_tokens
)
# NOTE(woosuk): Here we assume that the first token in the prompt
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
input_positions
.
extend
(
list
(
range
(
computed_len
,
prefill_end
)))
list
(
range
(
computed_len
,
computed_len
+
len
(
prompt_tokens
))))
lora_id
=
seq_group_metadata
.
lora_int_id
lora_id
=
seq_group_metadata
.
lora_int_id
...
@@ -218,7 +237,8 @@ class ModelRunner:
...
@@ -218,7 +237,8 @@ class ModelRunner:
"Prefix caching is currently not supported with "
"Prefix caching is currently not supported with "
"sliding window attention"
)
"sliding window attention"
)
start_idx
=
max
(
0
,
prompt_len
-
self
.
sliding_window
)
start_idx
=
max
(
0
,
prompt_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
prompt_len
):
for
i
in
range
(
computed_len
,
prefill_end
):
if
i
<
start_idx
:
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
continue
...
@@ -331,6 +351,7 @@ class ModelRunner:
...
@@ -331,6 +351,7 @@ class ModelRunner:
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
assert
not
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
lora_id
=
seq_group_metadata
.
lora_int_id
lora_id
=
seq_group_metadata
.
lora_int_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment