Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ecf7b17
Unverified
Commit
2ecf7b17
authored
Aug 14, 2024
by
William Lin
Committed by
GitHub
Aug 14, 2024
Browse files
[core] [3/N] multi-step args and sequence.py (#7452)
parent
3f674a49
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
100 additions
and
5 deletions
+100
-5
vllm/config.py
vllm/config.py
+13
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+5
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+25
-3
vllm/sequence.py
vllm/sequence.py
+57
-1
No files found.
vllm/config.py
View file @
2ecf7b17
...
@@ -847,7 +847,8 @@ class SchedulerConfig:
...
@@ -847,7 +847,8 @@ class SchedulerConfig:
delay_factor
:
float
=
0.0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
)
->
None
:
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
else
:
else
:
...
@@ -876,6 +877,7 @@ class SchedulerConfig:
...
@@ -876,6 +877,7 @@ class SchedulerConfig:
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
@@ -901,6 +903,16 @@ class SchedulerConfig:
...
@@ -901,6 +903,16 @@ class SchedulerConfig:
f
"(
{
self
.
num_lookahead_slots
}
) must be greater than or "
f
"(
{
self
.
num_lookahead_slots
}
) must be greater than or "
"equal to 0."
)
"equal to 0."
)
if
self
.
num_scheduler_steps
<
1
:
raise
ValueError
(
"num_scheduler_steps "
f
"(
{
self
.
num_scheduler_steps
}
) must be greater than or "
"equal to 1."
)
@
property
def
is_multi_step
(
self
)
->
bool
:
return
self
.
num_scheduler_steps
>
1
class
DeviceConfig
:
class
DeviceConfig
:
device
:
Optional
[
torch
.
device
]
device
:
Optional
[
torch
.
device
]
...
...
vllm/core/scheduler.py
View file @
2ecf7b17
...
@@ -805,6 +805,9 @@ class Scheduler:
...
@@ -805,6 +805,9 @@ class Scheduler:
curr_loras
.
add
(
lora_int_id
)
curr_loras
.
add
(
lora_int_id
)
waiting_queue
.
popleft
()
waiting_queue
.
popleft
()
self
.
_allocate_and_set_running
(
seq_group
)
self
.
_allocate_and_set_running
(
seq_group
)
seq_group
.
init_multi_step
(
num_scheduler_steps
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
)
+
1
)
seq_groups
.
append
(
seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
token_chunk_size
=
num_new_tokens
))
token_chunk_size
=
num_new_tokens
))
...
@@ -1108,6 +1111,7 @@ class Scheduler:
...
@@ -1108,6 +1111,7 @@ class Scheduler:
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
encoder_seq_data
=
encoder_seq_data
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
cross_block_table
=
cross_block_table
,
state
=
seq_group
.
state
,
# `multi_modal_data` will only be present for the 1st comm
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# between engine and worker.
# the subsequent comms can still use delta, but
# the subsequent comms can still use delta, but
...
@@ -1184,6 +1188,7 @@ class Scheduler:
...
@@ -1184,6 +1188,7 @@ class Scheduler:
slots.
slots.
"""
"""
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
seq_group
.
init_multi_step
(
num_scheduler_steps
=
num_lookahead_slots
+
1
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
...
...
vllm/engine/arg_utils.py
View file @
2ecf7b17
...
@@ -115,6 +115,7 @@ class EngineArgs:
...
@@ -115,6 +115,7 @@ class EngineArgs:
lora_dtype
:
str
=
'auto'
lora_dtype
:
str
=
'auto'
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
num_scheduler_steps
:
int
=
1
ray_workers_use_nsight
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
num_lookahead_slots
:
int
=
0
...
@@ -543,6 +544,11 @@ class EngineArgs:
...
@@ -543,6 +544,11 @@ class EngineArgs:
"tpu"
,
"xpu"
"tpu"
,
"xpu"
],
],
help
=
'Device type for vLLM execution.'
)
help
=
'Device type for vLLM execution.'
)
parser
.
add_argument
(
'--num-scheduler-steps'
,
type
=
int
,
default
=
1
,
help
=
(
'Maximum number of forward steps per '
'scheduler call.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--scheduler-delay-factor'
,
'--scheduler-delay-factor'
,
...
@@ -858,18 +864,34 @@ class EngineArgs:
...
@@ -858,18 +864,34 @@ class EngineArgs:
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
disable_logprobs
=
self
.
disable_logprobs_during_spec_decoding
,
)
)
if
self
.
num_scheduler_steps
>
1
:
raise
NotImplementedError
(
"Multi-step is not yet supported."
)
if
speculative_config
is
not
None
:
raise
ValueError
(
"Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)"
)
if
self
.
enable_chunked_prefill
:
raise
ValueError
(
"Chunked prefill is not supported with "
"multi-step (--num-scheduler-steps > 1)"
)
# make sure num_lookahead_slots is set the higher value depending on
# if we are using speculative decoding or multi-step
num_lookahead_slots
=
max
(
self
.
num_lookahead_slots
,
self
.
num_scheduler_steps
-
1
)
num_lookahead_slots
=
num_lookahead_slots
\
if
speculative_config
is
None
\
else
speculative_config
.
num_lookahead_slots
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_seqs
=
self
.
max_num_seqs
,
max_num_seqs
=
self
.
max_num_seqs
,
max_model_len
=
model_config
.
max_model_len
,
max_model_len
=
model_config
.
max_model_len
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
num_lookahead_slots
=
(
self
.
num_lookahead_slots
num_lookahead_slots
=
num_lookahead_slots
,
if
speculative_config
is
None
else
speculative_config
.
num_lookahead_slots
),
delay_factor
=
self
.
scheduler_delay_factor
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
embedding_mode
=
model_config
.
embedding_mode
,
embedding_mode
=
model_config
.
embedding_mode
,
preemption_mode
=
self
.
preemption_mode
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
)
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
...
...
vllm/sequence.py
View file @
2ecf7b17
...
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
...
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
,
cast
)
Union
,
cast
)
import
numpy
import
torch
import
torch
from
vllm.inputs.parse
import
is_valid_encoder_decoder_llm_inputs
from
vllm.inputs.parse
import
is_valid_encoder_decoder_llm_inputs
...
@@ -489,6 +490,19 @@ class Sequence:
...
@@ -489,6 +490,19 @@ class Sequence:
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
@
dataclass
class
SequenceGroupState
:
"""Mutable state tied to a specific sequence group"""
# for multi-step decoding
num_steps
:
int
=
1
current_step
:
int
=
0
@
property
def
remaining_steps
(
self
)
->
int
:
return
self
.
num_steps
-
self
.
current_step
class
SequenceGroup
:
class
SequenceGroup
:
"""A group of sequences that are generated from the same prompt.
"""A group of sequences that are generated from the same prompt.
...
@@ -534,6 +548,7 @@ class SequenceGroup:
...
@@ -534,6 +548,7 @@ class SequenceGroup:
time_in_queue
=
None
)
time_in_queue
=
None
)
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
embeddings
=
embeddings
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
pooling_params
=
pooling_params
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
...
@@ -588,6 +603,10 @@ class SequenceGroup:
...
@@ -588,6 +603,10 @@ class SequenceGroup:
return
self
.
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
\
return
self
.
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
\
if
self
.
prompt_adapter_request
else
0
if
self
.
prompt_adapter_request
else
0
def
init_multi_step
(
self
,
num_scheduler_steps
:
int
)
->
None
:
self
.
state
.
num_steps
=
num_scheduler_steps
self
.
state
.
current_step
=
0
def
get_last_latency
(
self
,
now
:
float
)
->
Optional
[
float
]:
def
get_last_latency
(
self
,
now
:
float
)
->
Optional
[
float
]:
"""Sets the last token time for Request level timings."""
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
# If still in prefill phase, raise Error.
...
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
...
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
...
@@ -781,6 +801,7 @@ class SequenceGroupMetadata:
...
@@ -781,6 +801,7 @@ class SequenceGroupMetadata:
token_chunk_size
:
Optional
[
int
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -796,6 +817,7 @@ class SequenceGroupMetadata:
...
@@ -796,6 +817,7 @@ class SequenceGroupMetadata:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
computed_block_nums
=
computed_block_nums
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
encoder_seq_data
=
encoder_seq_data
self
.
encoder_seq_data
=
encoder_seq_data
self
.
cross_block_table
=
cross_block_table
self
.
cross_block_table
=
cross_block_table
self
.
_token_chunk_size
=
token_chunk_size
self
.
_token_chunk_size
=
token_chunk_size
...
@@ -834,6 +856,10 @@ class SequenceGroupMetadata:
...
@@ -834,6 +856,10 @@ class SequenceGroupMetadata:
assert
self
.
_token_chunk_size
is
not
None
assert
self
.
_token_chunk_size
is
not
None
return
self
.
_token_chunk_size
return
self
.
_token_chunk_size
def
finish_step
(
self
)
->
None
:
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
self
.
state
.
current_step
+=
1
class
SequenceOutput
:
class
SequenceOutput
:
"""The model output associated with a sequence.
"""The model output associated with a sequence.
...
@@ -971,6 +997,7 @@ class SamplerOutput:
...
@@ -971,6 +997,7 @@ class SamplerOutput:
# On-device tensor containing the sampled token ids.
# On-device tensor containing the sampled token ids.
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
sampled_token_ids_numpy
:
Optional
[
numpy
.
ndarray
]
=
None
# Spec decode metrics populated by workers.
# Spec decode metrics populated by workers.
spec_decode_worker_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
spec_decode_worker_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
...
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
...
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
num_steps
:
int
=
1
num_steps
:
int
=
1
# Finished request ids since last step.
# Finished request ids since last step.
finished_requests_ids
:
List
[
str
]
=
field
(
default_factory
=
list
)
finished_requests_ids
:
List
[
str
]
=
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
is_first_multi_step
(
self
)
->
bool
:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert
len
(
self
.
seq_group_metadata_list
)
>
0
first_seq_group
=
self
.
seq_group_metadata_list
[
0
]
return
first_seq_group
.
state
.
current_step
==
0
@
property
def
is_last_step
(
self
)
->
bool
:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert
len
(
self
.
seq_group_metadata_list
)
>
0
first_seq_group
=
self
.
seq_group_metadata_list
[
0
]
num_steps
=
first_seq_group
.
state
.
num_steps
current_step
=
first_seq_group
.
state
.
current_step
return
num_steps
-
current_step
==
1
@
property
def
current_step
(
self
)
->
int
:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert
len
(
self
.
seq_group_metadata_list
)
>
0
return
self
.
seq_group_metadata_list
[
0
].
state
.
current_step
def
clone
(
def
clone
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
...
@@ -1127,4 +1181,6 @@ class ExecuteModelRequest:
...
@@ -1127,4 +1181,6 @@ class ExecuteModelRequest:
running_queue_size
=
self
.
running_queue_size
,
running_queue_size
=
self
.
running_queue_size
,
previous_hidden_states
=
self
.
previous_hidden_states
,
previous_hidden_states
=
self
.
previous_hidden_states
,
num_steps
=
self
.
num_steps
,
num_steps
=
self
.
num_steps
,
finished_requests_ids
=
self
.
finished_requests_ids
)
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment