Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9db93de2
Unverified
Commit
9db93de2
authored
Aug 23, 2024
by
Alexander Matveev
Committed by
GitHub
Aug 23, 2024
Browse files
[Core] Add multi-step support to LLMEngine (#7789)
parent
09c77926
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
195 additions
and
87 deletions
+195
-87
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-1
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+15
-2
tests/lora/test_gemma.py
tests/lora/test_gemma.py
+1
-1
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+0
-0
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+49
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+2
-72
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+126
-11
No files found.
.buildkite/test-pipeline.yaml
View file @
9db93de2
...
...
@@ -335,7 +335,8 @@ steps:
-
vllm/engine
-
tests/multi_step
commands
:
-
pytest -v -s multi_step/test_correctness.py
-
pytest -v -s multi_step/test_correctness_async_llm.py
-
pytest -v -s multi_step/test_correctness_llm.py
-
label
:
Pipeline Parallelism Test
# 23min
working_dir
:
"
/vllm-workspace/tests"
...
...
benchmarks/benchmark_throughput.py
View file @
9db93de2
...
...
@@ -82,6 +82,8 @@ def run_vllm(
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
num_scheduler_steps
:
int
=
1
,
use_v2_block_manager
:
bool
=
False
,
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
)
->
float
:
...
...
@@ -106,6 +108,8 @@ def run_vllm(
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
load_format
=
load_format
,
num_scheduler_steps
=
num_scheduler_steps
,
use_v2_block_manager
=
use_v2_block_manager
,
)
# Add the requests to the engine.
...
...
@@ -232,7 +236,8 @@ def main(args: argparse.Namespace):
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
download_dir
,
args
.
load_format
)
args
.
gpu_memory_utilization
,
args
.
num_scheduler_steps
,
args
.
use_v2_block_manager
,
args
.
download_dir
,
args
.
load_format
)
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
...
...
@@ -353,10 +358,18 @@ if __name__ == "__main__":
choices
=
[
"auto"
,
"cuda"
,
"cpu"
,
"openvino"
,
"tpu"
,
"xpu"
],
help
=
'device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.'
)
parser
.
add_argument
(
"--num-scheduler-steps"
,
type
=
int
,
default
=
1
,
help
=
"Maximum number of forward steps per scheduler call."
)
parser
.
add_argument
(
"--use-v2-block-manager"
,
action
=
'store_true'
,
help
=
"Enable block manager v2."
)
parser
.
add_argument
(
"--enable-prefix-caching"
,
action
=
'store_true'
,
help
=
"
e
nable automatic prefix caching for vLLM backend."
)
help
=
"
E
nable automatic prefix caching for vLLM backend."
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
'store_true'
,
help
=
"enable chunked prefill for vLLM backend."
)
...
...
tests/lora/test_gemma.py
View file @
9db93de2
...
...
@@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
expected_lora_output
=
[
"more important than knowledge.
\n
Author: Albert Einstein
\n
"
,
"everyone else is already taken.
\n
Author: Oscar Wilde
\n
"
,
"so little time
.
\n
Author: Frank Zappa
\n
"
,
"so little time
\n
Author: Frank Zappa
\n
"
,
]
output1
=
do_sample
(
llm
,
gemma_lora_files
,
lora_id
=
1
)
...
...
tests/multi_step/test_correctness.py
→
tests/multi_step/test_correctness
_async_llm
.py
View file @
9db93de2
File moved
tests/multi_step/test_correctness_llm.py
0 → 100644
View file @
9db93de2
# Test the LLMEngine with multi-step-decoding
import
pytest
from
..models.utils
import
check_outputs_equal
MODELS
=
[
"JackFram/llama-160m"
,
]
NUM_SCHEDULER_STEPS
=
[
8
]
# Multi-step decoding steps
NUM_PROMPTS
=
[
10
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    tp_size: int,
    max_tokens: int,
    enforce_eager: bool,
    num_scheduler_steps: int,
    num_prompts: int,
) -> None:
    """Check that multi-step LLMEngine decoding matches HF greedy output.

    Generates ``max_tokens`` greedy tokens for ``num_prompts`` prompts with
    vLLM (block manager v2 enabled, ``num_scheduler_steps`` steps per
    scheduler call) and with the HF runner, then requires the two sets of
    outputs to be equal via ``check_outputs_equal``.
    """
    prompts = example_prompts
    if len(prompts) < num_prompts:
        # Fail with a readable message instead of a ZeroDivisionError when
        # the fixture supplies no prompts at all.
        assert prompts, "example_prompts fixture must not be empty"
        # Repeat the fixture prompts until there are at least num_prompts.
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    with vllm_runner(model,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
                     gpu_memory_utilization=0.7,
                     tensor_parallel_size=tp_size,
                     use_v2_block_manager=True,
                     num_scheduler_steps=num_scheduler_steps) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
vllm/engine/async_llm_engine.py
View file @
9db93de2
import
asyncio
import
time
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
...
...
@@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
(
DecoderPromptComponents
,
LLMEngine
,
PromptComponents
)
PromptComponents
,
SchedulerOutputState
)
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
...
...
@@ -28,8 +26,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
print_warning_once
...
...
@@ -257,24 +254,11 @@ class RequestTracker:
return
not
self
.
_new_requests
.
empty
()
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Last sampler output recorded by _update_cached_scheduler_output; reset
    # to None by _cache_scheduler_outputs_for_multi_step.
    last_output: Optional[SamplerOutput] = None
    # Sequence-group metadata stored by
    # _cache_scheduler_outputs_for_multi_step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler outputs stored together with seq_group_metadata_list.
    scheduler_outputs: Optional[SchedulerOutputs] = None
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
def __init__(self, *args, **kwargs):
    """Initialize the base engine, then allocate the multi-step caches.

    All arguments are forwarded unchanged to the parent constructor.
    """
    super().__init__(*args, **kwargs)

    # One empty SchedulerOutputState per unit of pipeline_parallel_size;
    # other methods index this list by virtual_engine.
    output_states = []
    for _ in range(self.parallel_config.pipeline_parallel_size):
        output_states.append(SchedulerOutputState())
    self.cached_scheduler_outputs = output_states
async
def
step_async
(
self
,
virtual_engine
:
int
...
...
@@ -367,60 +351,6 @@ class _AsyncLLMEngine(LLMEngine):
return
request_outputs
def _has_remaining_steps(
    self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
    """Return whether the given batch still has multi-step iterations left.

    Returns False when multi-step scheduling is disabled or when
    ``seq_group_metadata_list`` is None or empty. Otherwise returns True iff
    the (shared) ``state.remaining_steps`` of the sequence groups is > 0.

    Raises:
        AssertionError: if the sequence groups do not all report the same
            ``state.remaining_steps``.
    """
    if (not self.scheduler_config.is_multi_step
            or not seq_group_metadata_list):
        return False

    # TODO(will) this is a sanity check for now to make sure that all the
    # seqs are on the same steps. Eventually we will want to do some sort of
    # dynamic scheduling when doing multi-step decoding.
    ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
    # Generator (not a list) so the check stops at the first mismatch.
    if any(seq_group.state.remaining_steps != ref_remaining_steps
           for seq_group in seq_group_metadata_list[1:]):
        raise AssertionError("All running sequence groups should "
                             "have the same remaining steps.")

    return ref_remaining_steps > 0
def _cache_scheduler_outputs_for_multi_step(
        self, virtual_engine: int,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        scheduler_outputs: SchedulerOutputs) -> None:
    """Store scheduler results in the cache slot of ``virtual_engine``.

    Records both ``seq_group_metadata_list`` and ``scheduler_outputs`` and
    resets the slot's ``last_output`` to None.
    """
    slot = self.cached_scheduler_outputs[virtual_engine]
    slot.seq_group_metadata_list = seq_group_metadata_list
    slot.scheduler_outputs = scheduler_outputs
    # Any previously recorded sampler output is dropped along with the
    # old batch.
    slot.last_output = None
def _get_last_sampled_token_ids(
        self, virtual_engine: int) -> Optional[torch.Tensor]:
    """Return the cached ``sampled_token_ids_cpu`` for ``virtual_engine``.

    The value is returned only when multi-step is enabled,
    ``pipeline_parallel_size`` is greater than 1, and a last output with a
    non-None ``sampled_token_ids_cpu`` has been cached; otherwise None.
    """
    last = self.cached_scheduler_outputs[virtual_engine].last_output

    if not self.scheduler_config.is_multi_step:
        return None
    if not self.parallel_config.pipeline_parallel_size > 1:
        return None
    if last is None:
        return None
    # A None here already means "nothing cached", so it can be returned as-is.
    return last.sampled_token_ids_cpu
def _update_cached_scheduler_output(
        self, virtual_engine: int,
        output: List[Optional[SamplerOutput]]) -> None:
    """Remember the final sampler output of this step for ``virtual_engine``.

    Does nothing unless ``pipeline_parallel_size`` is greater than 1 and
    ``output`` is non-empty with a non-None first element. Otherwise the
    last element of ``output`` is stored as the slot's ``last_output``.
    """
    if not self.parallel_config.pipeline_parallel_size > 1:
        return
    if not len(output) > 0:
        return
    if output[0] is None:
        return

    final_output = output[-1]
    # Sanity checks: only sampled_token_ids_cpu is expected to be populated.
    assert final_output is not None
    assert final_output.sampled_token_ids_cpu is not None
    assert final_output.sampled_token_ids is None
    assert final_output.sampled_token_probs is None
    self.cached_scheduler_outputs[virtual_engine].last_output = final_output
async def stop_remote_worker_execution_loop_async(self) -> None:
    """Ask the model executor to stop its remote worker execution loop."""
    executor = self.model_executor
    await executor.stop_remote_worker_execution_loop_async()
...
...
vllm/engine/llm_engine.py
View file @
9db93de2
import
time
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
import
torch
from
typing_extensions
import
TypeVar
,
assert_never
import
vllm.envs
as
envs
...
...
@@ -77,6 +79,14 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional
[
MultiModalDataDict
]]
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Last sampler output recorded by _update_cached_scheduler_output; reset
    # to None by _cache_scheduler_outputs_for_multi_step.
    last_output: Optional[SamplerOutput] = None
    # Sequence-group metadata stored by
    # _cache_scheduler_outputs_for_multi_step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler outputs stored together with seq_group_metadata_list.
    scheduler_outputs: Optional[SchedulerOutputs] = None
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -194,7 +204,7 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)"
,
"
num_scheduler_steps=%d,
enable_prefix_caching=%s)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -223,6 +233,7 @@ class LLMEngine:
model_config
.
seed
,
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
cache_config
.
enable_prefix_caching
,
)
# TODO(woosuk): Print more configs in debug mode.
...
...
@@ -380,6 +391,11 @@ class LLMEngine:
),
))
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -1304,16 +1320,40 @@ class LLMEngine:
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise."
)
if
self
.
scheduler_config
.
num_scheduler_steps
>
1
:
raise
NotImplementedError
(
"Multiple scheduler steps (multi-step) are only supported "
"through AsyncLLMEngine. "
)
# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs
=
self
.
cached_scheduler_outputs
[
0
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
0
].
schedule
()
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
0
,
seq_group_metadata_list
,
scheduler_outputs
)
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
0
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
0
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
...
...
@@ -1321,16 +1361,37 @@ class LLMEngine:
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
)
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
0
,
output
)
else
:
output
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
else
:
request_outputs
=
[]
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
...
@@ -1347,6 +1408,60 @@ class LLMEngine:
return
request_outputs
def _has_remaining_steps(
    self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
    """Return whether the given batch still has multi-step iterations left.

    Returns False when multi-step scheduling is disabled or when
    ``seq_group_metadata_list`` is None or empty. Otherwise returns True iff
    the (shared) ``state.remaining_steps`` of the sequence groups is > 0.

    Raises:
        AssertionError: if the sequence groups do not all report the same
            ``state.remaining_steps``.
    """
    if (not self.scheduler_config.is_multi_step
            or not seq_group_metadata_list):
        return False

    # TODO(will) this is a sanity check for now to make sure that all the
    # seqs are on the same steps. Eventually we will want to do some sort of
    # dynamic scheduling when doing multi-step decoding.
    ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
    # Generator (not a list) so the check stops at the first mismatch.
    if any(seq_group.state.remaining_steps != ref_remaining_steps
           for seq_group in seq_group_metadata_list[1:]):
        raise AssertionError("All running sequence groups should "
                             "have the same remaining steps.")

    return ref_remaining_steps > 0
def _cache_scheduler_outputs_for_multi_step(
        self, virtual_engine: int,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        scheduler_outputs: SchedulerOutputs) -> None:
    """Store scheduler results in the cache slot of ``virtual_engine``.

    Records both ``seq_group_metadata_list`` and ``scheduler_outputs`` and
    resets the slot's ``last_output`` to None.
    """
    slot = self.cached_scheduler_outputs[virtual_engine]
    slot.seq_group_metadata_list = seq_group_metadata_list
    slot.scheduler_outputs = scheduler_outputs
    # Any previously recorded sampler output is dropped along with the
    # old batch.
    slot.last_output = None
def _update_cached_scheduler_output(
        self, virtual_engine: int,
        output: List[Optional[SamplerOutput]]) -> None:
    """Remember the final sampler output of this step for ``virtual_engine``.

    Does nothing unless ``pipeline_parallel_size`` is greater than 1 and
    ``output`` is non-empty with a non-None first element. Otherwise the
    last element of ``output`` is stored as the slot's ``last_output``.
    """
    if not self.parallel_config.pipeline_parallel_size > 1:
        return
    if not len(output) > 0:
        return
    if output[0] is None:
        return

    final_output = output[-1]
    # Sanity checks: only sampled_token_ids_cpu is expected to be populated.
    assert final_output is not None
    assert final_output.sampled_token_ids_cpu is not None
    assert final_output.sampled_token_ids is None
    assert final_output.sampled_token_probs is None
    self.cached_scheduler_outputs[virtual_engine].last_output = final_output
def _get_last_sampled_token_ids(
        self, virtual_engine: int) -> Optional[torch.Tensor]:
    """Return the cached ``sampled_token_ids_cpu`` for ``virtual_engine``.

    The value is returned only when multi-step is enabled,
    ``pipeline_parallel_size`` is greater than 1, and a last output with a
    non-None ``sampled_token_ids_cpu`` has been cached; otherwise None.
    """
    last = self.cached_scheduler_outputs[virtual_engine].last_output

    if not self.scheduler_config.is_multi_step:
        return None
    if not self.parallel_config.pipeline_parallel_size > 1:
        return None
    if last is None:
        return None
    # A None here already means "nothing cached", so it can be returned as-is.
    return last.sampled_token_ids_cpu
def
add_logger
(
self
,
logger_name
:
str
,
logger
:
StatLoggerBase
)
->
None
:
if
logger_name
in
self
.
stat_loggers
:
raise
KeyError
(
f
"Logger with name
{
logger_name
}
already exists."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment