Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2eedede8
Unverified
Commit
2eedede8
authored
Aug 26, 2024
by
Megha Agarwal
Committed by
GitHub
Aug 26, 2024
Browse files
[Core] Asynchronous Output Processor (#7049)
Co-authored-by:
Alexander Matveev
<
alexm@neuralmagic.com
>
parent
015e6cc2
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
629 additions
and
197 deletions
+629
-197
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+9
-1
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+6
-0
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+0
-1
tests/core/test_chunked_prefill_scheduler.py
tests/core/test_chunked_prefill_scheduler.py
+2
-2
tests/core/utils.py
tests/core/utils.py
+1
-1
tests/engine/test_stop_strings.py
tests/engine/test_stop_strings.py
+103
-52
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+3
-0
vllm/config.py
vllm/config.py
+53
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+118
-12
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+47
-13
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+208
-70
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+5
-8
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+16
-9
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+20
-12
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+10
-1
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+4
-3
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+1
-1
vllm/sequence.py
vllm/sequence.py
+8
-8
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+7
-3
No files found.
benchmarks/benchmark_throughput.py
View file @
2eedede8
...
@@ -86,6 +86,7 @@ def run_vllm(
...
@@ -86,6 +86,7 @@ def run_vllm(
use_v2_block_manager
:
bool
=
False
,
use_v2_block_manager
:
bool
=
False
,
download_dir
:
Optional
[
str
]
=
None
,
download_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
EngineArgs
.
load_format
,
load_format
:
str
=
EngineArgs
.
load_format
,
disable_async_output_proc
:
bool
=
False
,
)
->
float
:
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
llm
=
LLM
(
...
@@ -110,6 +111,7 @@ def run_vllm(
...
@@ -110,6 +111,7 @@ def run_vllm(
load_format
=
load_format
,
load_format
=
load_format
,
num_scheduler_steps
=
num_scheduler_steps
,
num_scheduler_steps
=
num_scheduler_steps
,
use_v2_block_manager
=
use_v2_block_manager
,
use_v2_block_manager
=
use_v2_block_manager
,
disable_async_output_proc
=
disable_async_output_proc
,
)
)
# Add the requests to the engine.
# Add the requests to the engine.
...
@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
...
@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
num_scheduler_steps
,
args
.
gpu_memory_utilization
,
args
.
num_scheduler_steps
,
args
.
use_v2_block_manager
,
args
.
download_dir
,
args
.
load_format
)
args
.
use_v2_block_manager
,
args
.
download_dir
,
args
.
load_format
,
args
.
disable_async_output_proc
)
elif
args
.
backend
==
"hf"
:
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
...
@@ -418,6 +421,11 @@ if __name__ == "__main__":
...
@@ -418,6 +421,11 @@ if __name__ == "__main__":
'section for more information.
\n
'
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
'quantization.
\n
'
)
parser
.
add_argument
(
"--disable-async-output-proc"
,
action
=
'store_true'
,
default
=
False
,
help
=
"Disable async output processor for vLLM backend."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
tokenizer
is
None
:
if
args
.
tokenizer
is
None
:
args
.
tokenizer
=
args
.
model
args
.
tokenizer
=
args
.
model
...
...
tests/basic_correctness/test_chunked_prefill.py
View file @
2eedede8
...
@@ -88,6 +88,9 @@ def test_models(
...
@@ -88,6 +88,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
# reset distributed env properly. Use a value > 1 just when you test.
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor
@
pytest
.
mark
.
parametrize
(
"disable_async_output_proc"
,
[
True
])
def
test_models_with_fp8_kv_cache
(
def
test_models_with_fp8_kv_cache
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
...
@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
...
@@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size
:
int
,
chunked_prefill_token_size
:
int
,
enforce_eager
:
bool
,
enforce_eager
:
bool
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
disable_async_output_proc
:
bool
,
)
->
None
:
)
->
None
:
"""
"""
Only checks log probs match between chunked-prefill and
Only checks log probs match between chunked-prefill and
...
@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
...
@@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
disable_async_output_proc
=
disable_async_output_proc
,
**
extra_kwargs
,
**
extra_kwargs
,
)
as
vllm_model
:
)
as
vllm_model
:
no_chunked_prefill_outputs
=
vllm_model
.
generate_greedy_logprobs
(
no_chunked_prefill_outputs
=
vllm_model
.
generate_greedy_logprobs
(
...
@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
...
@@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
disable_async_output_proc
=
disable_async_output_proc
,
**
extra_kwargs
,
**
extra_kwargs
,
)
as
vllm_model
:
)
as
vllm_model
:
chunked_prefill_outputs
=
vllm_model
.
generate_greedy_logprobs
(
chunked_prefill_outputs
=
vllm_model
.
generate_greedy_logprobs
(
...
...
tests/basic_correctness/test_preemption.py
View file @
2eedede8
...
@@ -209,7 +209,6 @@ def test_swap_infeasible(
...
@@ -209,7 +209,6 @@ def test_swap_infeasible(
prefill_blocks
=
2
prefill_blocks
=
2
decode_blocks
=
max_tokens
//
BLOCK_SIZE
decode_blocks
=
max_tokens
//
BLOCK_SIZE
example_prompts
=
example_prompts
[:
1
]
example_prompts
=
example_prompts
[:
1
]
with
vllm_runner
(
with
vllm_runner
(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
...
...
tests/core/test_chunked_prefill_scheduler.py
View file @
2eedede8
...
@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
...
@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
def
schedule_and_update_computed_tokens
(
scheduler
):
def
schedule_and_update_computed_tokens
(
scheduler
):
metas
,
out
=
scheduler
.
schedule
()
metas
,
out
,
_
=
scheduler
.
schedule
()
for
s
,
meta
in
zip
(
out
.
scheduled_seq_groups
,
metas
):
for
s
,
meta
in
zip
(
out
.
scheduled_seq_groups
,
metas
):
s
.
seq_group
.
update_num_computed_tokens
(
meta
.
token_chunk_size
)
s
.
seq_group
.
update_num_computed_tokens
(
meta
.
token_chunk_size
)
return
metas
,
out
return
metas
,
out
...
@@ -180,7 +180,7 @@ def test_maximal_decoding():
...
@@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
"""Verify decoding requests are prioritized."""
block_size
=
4
block_size
=
4
max_seqs
=
2
max_seqs
=
2
max_model_len
=
2
max_model_len
=
8
max_num_batched_tokens
=
2
max_num_batched_tokens
=
2
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
,
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
,
max_seqs
,
max_seqs
,
...
...
tests/core/utils.py
View file @
2eedede8
...
@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
...
@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
def
schedule_and_update_computed_tokens
(
scheduler
):
def
schedule_and_update_computed_tokens
(
scheduler
):
metas
,
out
=
scheduler
.
schedule
()
metas
,
out
,
_
=
scheduler
.
schedule
()
for
s
,
meta
in
zip
(
out
.
scheduled_seq_groups
,
metas
):
for
s
,
meta
in
zip
(
out
.
scheduled_seq_groups
,
metas
):
s
.
seq_group
.
update_num_computed_tokens
(
meta
.
token_chunk_size
)
s
.
seq_group
.
update_num_computed_tokens
(
meta
.
token_chunk_size
)
return
metas
,
out
return
metas
,
out
...
...
tests/engine/test_stop_strings.py
View file @
2eedede8
...
@@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
...
@@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
MODEL
=
"meta-llama/llama-2-7b-hf"
MODEL
=
"meta-llama/llama-2-7b-hf"
MAX_TOKENS
=
200
MAX_TOKENS
=
200
IS_ASYNC
=
False
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
vllm_model
(
vllm_runner
):
def
vllm_model
(
vllm_runner
):
...
@@ -14,99 +16,148 @@ def vllm_model(vllm_runner):
...
@@ -14,99 +16,148 @@ def vllm_model(vllm_runner):
yield
vllm_model
yield
vllm_model
@
pytest
.
mark
.
skip_global_cleanup
def
_test_stopping
(
llm_engine
:
LLMEngine
,
def
test_stop_basic
(
vllm_model
):
expected_output
:
str
,
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
expected_reason
:
Any
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
include_in_output
:
bool
=
False
,
use_async_output_proc
:
bool
=
False
)
->
None
:
llm_engine
.
add_request
(
"id"
,
"A story about vLLM:
\n
"
,
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
MAX_TOKENS
,
stop
=
stop
,
stop_token_ids
=
stop_token_ids
,
include_stop_str_in_output
=
include_in_output
,
),
None
)
output
:
Optional
[
CompletionOutput
]
=
None
output_text
=
""
stop_reason
=
None
if
use_async_output_proc
:
llm_engine
.
step
()
while
llm_engine
.
has_unfinished_requests
():
(
request_output
,
)
=
llm_engine
.
step
()
(
output
,
)
=
request_output
.
outputs
# Ensure we don't backtrack
assert
output
.
text
.
startswith
(
output_text
)
output_text
=
output
.
text
stop_reason
=
output
.
stop_reason
assert
output
is
not
None
assert
output_text
==
expected_output
assert
stop_reason
==
expected_reason
def
_set_async_mode
(
llm_engine
,
is_async
):
llm_engine
.
scheduler
[
0
].
use_async_output_proc
=
is_async
def
_stop_basic
(
llm_engine
,
is_async
):
_test_stopping
(
llm_engine
,
stop
=
[
"."
],
stop
=
[
"."
],
include_in_output
=
False
,
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_reason
=
"."
)
expected_reason
=
"."
,
use_async_output_proc
=
is_async
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
_test_stopping
(
llm_engine
,
stop
=
[
"."
],
stop
=
[
"."
],
include_in_output
=
True
,
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization."
,
expected_output
=
"VLLM is a 100% volunteer organization."
,
expected_reason
=
"."
)
expected_reason
=
"."
,
use_async_output_proc
=
is_async
)
@
pytest
.
mark
.
skip_global_cleanup
def
_stop_multi_tokens
(
llm_engine
,
is_async
):
def
test_stop_multi_tokens
(
vllm_model
):
_test_stopping
(
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
llm_engine
,
stop
=
[
"group of peo"
,
"short"
],
stop
=
[
"group of peo"
,
"short"
],
include_in_output
=
False
,
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer organization. We are a "
,
expected_output
=
"VLLM is a 100% volunteer organization. We are a "
,
expected_reason
=
"group of peo"
)
expected_reason
=
"group of peo"
,
use_async_output_proc
=
is_async
)
_test_stopping
(
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
llm_engine
,
stop
=
[
"group of peo"
,
"short"
],
stop
=
[
"group of peo"
,
"short"
],
include_in_output
=
True
,
include_in_output
=
True
,
expected_output
=
expected_output
=
"VLLM is a 100% volunteer organization. We are a group of peo"
,
"VLLM is a 100% volunteer organization. We are a group of peo"
,
expected_reason
=
"group of peo"
)
expected_reason
=
"group of peo"
,
use_async_output_proc
=
is_async
)
@
pytest
.
mark
.
skip_global_cleanup
def
_stop_partial_token
(
llm_engine
,
is_async
):
def
test_stop_partial_token
(
vllm_model
):
_test_stopping
(
llm_engine
,
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
stop
=
[
"gani"
],
stop
=
[
"gani"
],
include_in_output
=
False
,
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer or"
,
expected_output
=
"VLLM is a 100% volunteer or"
,
expected_reason
=
"gani"
)
expected_reason
=
"gani"
,
use_async_output_proc
=
is_async
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
_test_stopping
(
llm_engine
,
stop
=
[
"gani"
],
stop
=
[
"gani"
],
include_in_output
=
True
,
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organi"
,
expected_output
=
"VLLM is a 100% volunteer organi"
,
expected_reason
=
"gani"
)
expected_reason
=
"gani"
,
use_async_output_proc
=
is_async
)
@
pytest
.
mark
.
skip_global_cleanup
def
_stop_token_id
(
llm_engine
,
is_async
):
def
test_stop_token_id
(
vllm_model
):
# token id 13013 => " organization"
# token id 13013 => " organization"
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
_test_stopping
(
llm_engine
,
stop_token_ids
=
[
13013
],
stop_token_ids
=
[
13013
],
include_in_output
=
False
,
include_in_output
=
False
,
expected_output
=
"VLLM is a 100% volunteer"
,
expected_output
=
"VLLM is a 100% volunteer"
,
expected_reason
=
13013
)
expected_reason
=
13013
,
use_async_output_proc
=
is_async
)
_test_stopping
(
vllm_model
.
model
.
llm_engine
,
_test_stopping
(
llm_engine
,
stop_token_ids
=
[
13013
],
stop_token_ids
=
[
13013
],
include_in_output
=
True
,
include_in_output
=
True
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_output
=
"VLLM is a 100% volunteer organization"
,
expected_reason
=
13013
)
expected_reason
=
13013
,
use_async_output_proc
=
is_async
)
def
_test_stopping
(
llm_engine
:
LLMEngine
,
@
pytest
.
mark
.
skip_global_cleanup
expected_output
:
str
,
def
test_stop_basic
(
vllm_model
):
expected_reason
:
Any
,
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
stop
:
Optional
[
List
[
str
]]
=
None
,
_stop_basic
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
include_in_output
:
bool
=
False
)
->
None
:
llm_engine
.
add_request
(
"id"
,
"A story about vLLM:
\n
"
,
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
MAX_TOKENS
,
stop
=
stop
,
stop_token_ids
=
stop_token_ids
,
include_stop_str_in_output
=
include_in_output
,
),
None
)
output
:
Optional
[
CompletionOutput
]
=
None
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
output_text
=
""
_stop_basic
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
stop_reason
=
None
while
llm_engine
.
has_unfinished_requests
():
(
request_output
,
)
=
llm_engine
.
step
()
(
output
,
)
=
request_output
.
outputs
# Ensure we don't backtrack
assert
output
.
text
.
startswith
(
output_text
)
output_text
=
output
.
text
stop_reason
=
output
.
stop_reason
assert
output
is
not
None
@
pytest
.
mark
.
skip_global_cleanup
assert
output_text
==
expected_output
def
test_stop_multi_tokens
(
vllm_model
):
assert
stop_reason
==
expected_reason
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_multi_tokens
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_multi_tokens
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_partial_token
(
vllm_model
):
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_partial_token
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_partial_token
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_stop_token_id
(
vllm_model
):
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
True
)
_stop_token_id
(
vllm_model
.
model
.
llm_engine
,
is_async
=
True
)
_set_async_mode
(
vllm_model
.
model
.
llm_engine
,
False
)
_stop_token_id
(
vllm_model
.
model
.
llm_engine
,
is_async
=
False
)
tests/multi_step/test_correctness_async_llm.py
View file @
2eedede8
...
@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
...
@@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args
=
DEFAULT_SERVER_ARGS
+
\
ms_server_args
=
DEFAULT_SERVER_ARGS
+
\
[
"--num-scheduler-steps"
,
f
"
{
num_scheduler_steps
}
"
]
[
"--num-scheduler-steps"
,
f
"
{
num_scheduler_steps
}
"
]
# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args
+=
[
"--disable-async-output-proc"
]
if
eager_mode
:
if
eager_mode
:
ms_server_args
.
append
(
"--enforce-eager"
)
ms_server_args
.
append
(
"--enforce-eager"
)
...
...
vllm/config.py
View file @
2eedede8
...
@@ -140,6 +140,7 @@ class ModelConfig:
...
@@ -140,6 +140,7 @@ class ModelConfig:
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
True
,
)
->
None
:
)
->
None
:
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
...
@@ -172,6 +173,7 @@ class ModelConfig:
...
@@ -172,6 +173,7 @@ class ModelConfig:
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
model
,
revision
)
self
.
model
,
revision
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
use_async_output_proc
=
use_async_output_proc
# Choose a default enforce_eager value if the user did not specify
# Choose a default enforce_eager value if the user did not specify
# a value (enforce_eager is None)
# a value (enforce_eager is None)
...
@@ -326,6 +328,49 @@ class ModelConfig:
...
@@ -326,6 +328,49 @@ class ModelConfig:
self
.
max_seq_len_to_capture
=
min
(
self
.
max_seq_len_to_capture
,
self
.
max_seq_len_to_capture
=
min
(
self
.
max_seq_len_to_capture
,
self
.
max_model_len
)
self
.
max_model_len
)
def
verify_async_output_proc
(
self
,
parallel_config
,
speculative_config
,
device_config
)
->
None
:
if
not
self
.
use_async_output_proc
:
# Nothing to check
return
if
parallel_config
.
pipeline_parallel_size
>
1
:
logger
.
warning
(
"Async output processing can not be enabled "
"with pipeline parallel"
)
self
.
use_async_output_proc
=
False
return
if
device_config
.
device_type
!=
"cuda"
:
logger
.
warning
(
"Async output processing is only supported for CUDA."
" Disabling it for other platforms."
)
self
.
use_async_output_proc
=
False
return
if
envs
.
VLLM_USE_RAY_SPMD_WORKER
:
logger
.
warning
(
"Async output processing can not be enabled with ray spmd"
)
self
.
use_async_output_proc
=
False
return
if
self
.
enforce_eager
:
logger
.
warning
(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used"
)
self
.
use_async_output_proc
=
not
self
.
enforce_eager
return
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if
self
.
embedding_mode
:
self
.
use_async_output_proc
=
False
if
speculative_config
:
logger
.
warning
(
"Async output processing is not supported with"
" speculative decoding currently."
)
self
.
use_async_output_proc
=
False
def
verify_with_parallel_config
(
def
verify_with_parallel_config
(
self
,
self
,
parallel_config
:
"ParallelConfig"
,
parallel_config
:
"ParallelConfig"
,
...
@@ -358,6 +403,11 @@ class ModelConfig:
...
@@ -358,6 +403,11 @@ class ModelConfig:
"fallback to the eager mode."
)
"fallback to the eager mode."
)
self
.
enforce_eager
=
True
self
.
enforce_eager
=
True
if
pipeline_parallel_size
>
1
and
self
.
use_async_output_proc
:
logger
.
warning
(
"Async output processor is not supported with "
"pipeline parallelism currently. Disabling it."
)
self
.
use_async_output_proc
=
False
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled."""
"""Get the sliding window size, or None if disabled."""
...
@@ -1769,6 +1819,9 @@ class EngineConfig:
...
@@ -1769,6 +1819,9 @@ class EngineConfig:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""Verify configs are valid & consistent with each other.
"""
"""
self
.
model_config
.
verify_async_output_proc
(
self
.
parallel_config
,
self
.
speculative_config
,
self
.
device_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
...
vllm/core/scheduler.py
View file @
2eedede8
...
@@ -4,7 +4,8 @@ import random
...
@@ -4,7 +4,8 @@ import random
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
(
Callable
,
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
...
@@ -299,6 +300,7 @@ class Scheduler:
...
@@ -299,6 +300,7 @@ class Scheduler:
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
pipeline_parallel_size
:
int
=
1
,
pipeline_parallel_size
:
int
=
1
,
output_proc_callback_fn
:
Optional
[
Callable
]
=
None
,
)
->
None
:
)
->
None
:
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
...
@@ -364,10 +366,36 @@ class Scheduler:
...
@@ -364,10 +366,36 @@ class Scheduler:
self
.
num_cumulative_preemption
:
int
=
0
self
.
num_cumulative_preemption
:
int
=
0
# Used to cache python objects
# Used to cache python objects
self
.
_scheduler_running_outputs_cache
:
PyObjectCache
=
PyObjectCache
(
self
.
_seq_group_metadata_cache
:
List
[
PyObjectCache
]
=
[]
scheduler_running_outputs_builder
)
self
.
_scheduler_running_outputs_cache
:
List
[
PyObjectCache
]
=
[]
self
.
_scheduled_seq_group_cache
:
PyObjectCache
=
PyObjectCache
(
self
.
_scheduled_seq_group_cache
:
List
[
PyObjectCache
]
=
[]
scheduled_seq_group_builder
)
# For async output processing, we need to swap cache buffers between
# iterations. I.e. since the output processing is lagged one step,
# we cannot reuse the cached objects immediately when the schedule()
# is called again, but only when schedule() is called the second time.
self
.
output_proc_callback_fn
=
output_proc_callback_fn
self
.
use_async_output_proc
=
self
.
output_proc_callback_fn
is
not
None
self
.
num_cache_iters
=
2
if
self
.
use_async_output_proc
else
1
self
.
cache_id
=
0
for
i
in
range
(
self
.
num_cache_iters
):
self
.
_seq_group_metadata_cache
.
append
(
PyObjectCache
(
seq_group_metadata_builder
))
self
.
_scheduler_running_outputs_cache
.
append
(
PyObjectCache
(
scheduler_running_outputs_builder
))
self
.
_scheduled_seq_group_cache
.
append
(
PyObjectCache
(
scheduled_seq_group_builder
))
# For async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during schedule() call and added to this stop list
# for processing and deallocation by the free_finished_seq_groups()
self
.
_async_stopped
:
List
[
SequenceGroup
]
=
[]
@
property
def
next_cache_id
(
self
):
return
(
self
.
cache_id
+
1
)
%
self
.
num_cache_iters
@
property
@
property
def
lora_enabled
(
self
)
->
bool
:
def
lora_enabled
(
self
)
->
bool
:
...
@@ -483,7 +511,7 @@ class Scheduler:
...
@@ -483,7 +511,7 @@ class Scheduler:
SchedulerRunningOutputs.
SchedulerRunningOutputs.
"""
"""
ret
:
SchedulerRunningOutputs
=
\
ret
:
SchedulerRunningOutputs
=
\
self
.
_scheduler_running_outputs_cache
.
get_object
()
self
.
_scheduler_running_outputs_cache
[
self
.
cache_id
]
.
get_object
()
ret
.
blocks_to_swap_out
.
clear
()
ret
.
blocks_to_swap_out
.
clear
()
ret
.
blocks_to_copy
.
clear
()
ret
.
blocks_to_copy
.
clear
()
ret
.
decode_seq_groups
.
clear
()
ret
.
decode_seq_groups
.
clear
()
...
@@ -510,8 +538,12 @@ class Scheduler:
...
@@ -510,8 +538,12 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# to keep all the sequence groups in the RUNNING state.
running_queue
=
self
.
running
# Store original running requests for the case of async + preemption
if
self
.
use_async_output_proc
:
orig_running
=
self
.
running
.
copy
()
running_queue
=
self
.
running
assert
len
(
self
.
_async_stopped
)
==
0
while
running_queue
:
while
running_queue
:
seq_group
=
running_queue
[
0
]
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
num_running_tokens
=
self
.
_get_num_new_tokens
(
...
@@ -521,6 +553,28 @@ class Scheduler:
...
@@ -521,6 +553,28 @@ class Scheduler:
break
break
running_queue
.
popleft
()
running_queue
.
popleft
()
# With async postprocessor, an extra decode run is done
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if
self
.
use_async_output_proc
and
seq_group
.
seqs
[
0
].
get_len
(
)
>
self
.
scheduler_config
.
max_model_len
:
self
.
_async_stopped
.
append
(
seq_group
)
continue
# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if
self
.
use_async_output_proc
and
not
self
.
_can_append_slots
(
seq_group
):
tmp
=
self
.
running
self
.
running
=
orig_running
assert
self
.
output_proc_callback_fn
is
not
None
self
.
output_proc_callback_fn
(
is_async
=
True
)
self
.
running
=
tmp
while
not
self
.
_can_append_slots
(
seq_group
):
while
not
self
.
_can_append_slots
(
seq_group
):
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
budget
.
subtract_num_batched_tokens
(
seq_group
.
request_id
,
num_running_tokens
)
num_running_tokens
)
...
@@ -556,7 +610,7 @@ class Scheduler:
...
@@ -556,7 +610,7 @@ class Scheduler:
is_prefill
=
seq_group
.
is_prefill
()
is_prefill
=
seq_group
.
is_prefill
()
scheduled_seq_group
:
ScheduledSequenceGroup
=
\
scheduled_seq_group
:
ScheduledSequenceGroup
=
\
self
.
_scheduled_seq_group_cache
.
get_object
()
self
.
_scheduled_seq_group_cache
[
self
.
cache_id
]
.
get_object
()
scheduled_seq_group
.
seq_group
=
seq_group
scheduled_seq_group
.
seq_group
=
seq_group
if
is_prefill
:
if
is_prefill
:
scheduled_seq_group
.
token_chunk_size
=
num_running_tokens
scheduled_seq_group
.
token_chunk_size
=
num_running_tokens
...
@@ -579,8 +633,8 @@ class Scheduler:
...
@@ -579,8 +633,8 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
curr_loras
.
add
(
seq_group
.
lora_int_id
)
self
.
_scheduler_running_outputs_cache
.
reset
()
self
.
_scheduler_running_outputs_cache
[
self
.
next_cache_id
]
.
reset
()
self
.
_scheduled_seq_group_cache
.
reset
()
self
.
_scheduled_seq_group_cache
[
self
.
next_cache_id
]
.
reset
()
return
ret
return
ret
...
@@ -1031,17 +1085,31 @@ class Scheduler:
...
@@ -1031,17 +1085,31 @@ class Scheduler:
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
)
)
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
def
_allow_async_output_proc
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
no_beam_search
=
(
seq_group
.
sampling_params
.
best_of
==
1
and
not
seq_group
.
sampling_params
.
use_beam_search
)
return
no_beam_search
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
bool
]:
# Schedule sequence groups.
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time
=
time
.
perf_counter
()
scheduler_start_time
=
time
.
perf_counter
()
scheduler_outputs
=
self
.
_schedule
()
scheduler_outputs
=
self
.
_schedule
()
now
=
time
.
time
()
now
=
time
.
time
()
if
not
self
.
cache_config
.
enable_prefix_caching
:
if
not
self
.
cache_config
.
enable_prefix_caching
:
common_computed_block_nums
=
[]
common_computed_block_nums
=
[]
# TODO: Combine multi-step and async postprocessor
allow_async_output_proc
:
bool
=
(
self
.
use_async_output_proc
and
not
self
.
scheduler_config
.
is_multi_step
)
# Create input data structures.
# Create input data structures.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
i
,
scheduled_seq_group
in
enumerate
(
for
i
,
scheduled_seq_group
in
enumerate
(
...
@@ -1050,6 +1118,11 @@ class Scheduler:
...
@@ -1050,6 +1118,11 @@ class Scheduler:
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
token_chunk_size
=
scheduled_seq_group
.
token_chunk_size
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group
.
maybe_set_first_scheduled_time
(
now
)
seq_group_metadata
=
self
.
_seq_group_metadata_cache
[
self
.
cache_id
].
get_object
()
seq_group_metadata
.
seq_data
.
clear
()
seq_group_metadata
.
block_tables
.
clear
()
# seq_id -> SequenceData
# seq_id -> SequenceData
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
# seq_id -> physical block numbers
# seq_id -> physical block numbers
...
@@ -1139,6 +1212,10 @@ class Scheduler:
...
@@ -1139,6 +1212,10 @@ class Scheduler:
)
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
if
allow_async_output_proc
:
allow_async_output_proc
=
self
.
_allow_async_output_proc
(
seq_group
)
# Now that the batch has been created, we can assume all blocks in the
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# This is because the engine assumes that a failure in model execution
...
@@ -1147,6 +1224,8 @@ class Scheduler:
...
@@ -1147,6 +1224,8 @@ class Scheduler:
self
.
block_manager
.
mark_blocks_as_computed
(
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
scheduled_seq_group
.
seq_group
)
self
.
_seq_group_metadata_cache
[
self
.
next_cache_id
].
reset
()
scheduler_time
=
time
.
perf_counter
()
-
scheduler_start_time
scheduler_time
=
time
.
perf_counter
()
-
scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
# running. This will help estimate if the scheduler is a significant
...
@@ -1158,7 +1237,12 @@ class Scheduler:
...
@@ -1158,7 +1237,12 @@ class Scheduler:
else
:
else
:
seq_group
.
metrics
.
scheduler_time
=
scheduler_time
seq_group
.
metrics
.
scheduler_time
=
scheduler_time
return
seq_group_metadata_list
,
scheduler_outputs
# Move to next cache (if exists)
self
.
cache_id
=
self
.
next_cache_id
# Return results
return
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
def
fork_seq
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
def
fork_seq
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
self
.
block_manager
.
fork
(
parent_seq
,
child_seq
)
self
.
block_manager
.
fork
(
parent_seq
,
child_seq
)
...
@@ -1167,6 +1251,12 @@ class Scheduler:
...
@@ -1167,6 +1251,12 @@ class Scheduler:
"""Free a sequence from a block table."""
"""Free a sequence from a block table."""
self
.
block_manager
.
free
(
seq
)
self
.
block_manager
.
free
(
seq
)
def
_free_finished_seqs
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
"""Free finished seqs in a sequence group."""
for
seq
in
seq_group
.
get_seqs
():
if
seq
.
is_finished
():
self
.
free_seq
(
seq
)
def
free_finished_seq_groups
(
self
)
->
None
:
def
free_finished_seq_groups
(
self
)
->
None
:
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
remaining
:
Deque
[
SequenceGroup
]
=
deque
()
for
seq_group
in
self
.
running
:
for
seq_group
in
self
.
running
:
...
@@ -1179,8 +1269,24 @@ class Scheduler:
...
@@ -1179,8 +1269,24 @@ class Scheduler:
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
else
:
else
:
remaining
.
append
(
seq_group
)
remaining
.
append
(
seq_group
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
running
=
remaining
self
.
running
=
remaining
# Handle async stopped sequence groups
# (ones that reached max model len)
if
self
.
_async_stopped
:
for
seq_group
in
self
.
_async_stopped
:
self
.
_free_seq_group_cross_attn_blocks
(
seq_group
)
self
.
_finished_requests_ids
.
append
(
seq_group
.
request_id
)
# Free finished seqs
self
.
_free_finished_seqs
(
seq_group
)
self
.
_async_stopped
.
clear
()
def
_allocate_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
def
_allocate_and_set_running
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
self
.
block_manager
.
allocate
(
seq_group
)
self
.
block_manager
.
allocate
(
seq_group
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
...
...
vllm/engine/arg_utils.py
View file @
2eedede8
...
@@ -147,6 +147,7 @@ class EngineArgs:
...
@@ -147,6 +147,7 @@ class EngineArgs:
otlp_traces_endpoint
:
Optional
[
str
]
=
None
otlp_traces_endpoint
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
...
@@ -733,6 +734,12 @@ class EngineArgs:
...
@@ -733,6 +734,12 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact."
)
"operations and hence might have a performance impact."
)
parser
.
add_argument
(
'--disable-async-output-proc'
,
action
=
'store_true'
,
default
=
EngineArgs
.
disable_async_output_proc
,
help
=
"Disable async output processing. This may result in "
"lower performance."
)
return
parser
return
parser
@
classmethod
@
classmethod
...
@@ -792,6 +799,7 @@ class EngineArgs:
...
@@ -792,6 +799,7 @@ class EngineArgs:
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
served_model_name
=
self
.
served_model_name
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
)
)
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
...
...
vllm/engine/async_llm_engine.py
View file @
2eedede8
...
@@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# skip the scheduler if there are any remaining steps in the seq groups.
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
(
seq_group_metadata_list
,
scheduler_outputs
,
virtual_engine
].
schedule
()
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
# If the current scheduler iteration does not allow async output
# processing, we first need to drain the pending async postprocessor
# before moving forward
if
not
allow_async_output_proc
and
len
(
self
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
is_async
=
True
)
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
self
.
_cache_scheduler_outputs_for_multi_step
(
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
)
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
scheduler_outputs
is
not
None
assert
not
(
self
.
scheduler_config
.
is_multi_step
and
\
allow_async_output_proc
)
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
virtual_engine
].
get_and_reset_finished_requests_ids
()
...
@@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
# We use ExecuteModelRequest to pass the last sampled_token_ids
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
output_proc_callback_fn
=
\
self
.
_process_model_outputs
# Execute the model.
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
output
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
execute_model_req
)
...
@@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
else
:
if
len
(
self
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
)
output
=
[]
output
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
...
@@ -337,11 +358,21 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -337,11 +358,21 @@ class _AsyncLLMEngine(LLMEngine):
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
virtual_engine
]
=
SchedulerOutputState
()
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
# Cache results in engine
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
self
.
output_queue
.
append
(
else
:
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
request_outputs
=
[]
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
"Multi step decoding does not work with async output processing."
# noqa: E501
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
is_async
=
False
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
@@ -349,7 +380,10 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -349,7 +380,10 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
return
request_outputs
else
:
self
.
request_outputs
=
[]
return
self
.
request_outputs
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
"""Stop the remote worker execution loop."""
"""Stop the remote worker execution loop."""
...
...
vllm/engine/llm_engine.py
View file @
2eedede8
import
time
import
time
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Iterable
,
List
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
Mapping
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
from
typing
import
Set
,
Tuple
,
Type
,
Union
...
@@ -38,9 +39,8 @@ from vllm.pooling_params import PoolingParams
...
@@ -38,9 +39,8 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
,
Sequence
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.config
import
try_get_generation_config
...
@@ -82,9 +82,10 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
...
@@ -82,9 +82,10 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
@
dataclass
@
dataclass
class
SchedulerOutputState
:
class
SchedulerOutputState
:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output
:
Optional
[
SamplerOutput
]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
allow_async_output_proc
:
bool
=
False
last_output
:
Optional
[
SamplerOutput
]
=
None
class
LLMEngine
:
class
LLMEngine
:
...
@@ -190,6 +191,9 @@ class LLMEngine:
...
@@ -190,6 +191,9 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
# To improve performance, only final request outputs may be required.
# If this is set to true, then no intermediate outputs will be returned.
step_return_finished_only
:
bool
=
False
,
)
->
None
:
)
->
None
:
logger
.
info
(
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
"Initializing an LLM engine (v%s) with config: "
...
@@ -204,7 +208,8 @@ class LLMEngine:
...
@@ -204,7 +208,8 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s)"
,
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)"
,
VLLM_VERSION
,
VLLM_VERSION
,
model_config
.
model
,
model_config
.
model
,
speculative_config
,
speculative_config
,
...
@@ -235,6 +240,7 @@ class LLMEngine:
...
@@ -235,6 +240,7 @@ class LLMEngine:
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
num_scheduler_steps
,
cache_config
.
enable_prefix_caching
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
)
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
from
vllm.plugins
import
load_general_plugins
...
@@ -253,6 +259,7 @@ class LLMEngine:
...
@@ -253,6 +259,7 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
)
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
step_return_finished_only
=
step_return_finished_only
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
@@ -340,8 +347,11 @@ class LLMEngine:
...
@@ -340,8 +347,11 @@ class LLMEngine:
# NOTE: the cache_config here have been updated with the numbers of
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# GPU and CPU blocks, which are profiled in the distributed executor.
self
.
scheduler
=
[
self
.
scheduler
=
[
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
,
Scheduler
(
parallel_config
.
pipeline_parallel_size
)
scheduler_config
,
cache_config
,
lora_config
,
parallel_config
.
pipeline_parallel_size
,
self
.
_process_model_outputs
if
model_config
.
use_async_output_proc
else
None
)
for
_
in
range
(
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
parallel_config
.
pipeline_parallel_size
)
]
]
...
@@ -396,6 +406,13 @@ class LLMEngine:
...
@@ -396,6 +406,13 @@ class LLMEngine:
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
# Async output processing pointers
self
.
output_queue
:
Deque
[
Tuple
[
List
[
SamplerOutput
],
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -1197,34 +1214,66 @@ class LLMEngine:
...
@@ -1197,34 +1214,66 @@ class LLMEngine:
return
return
def
_process_model_outputs
(
def
_process_model_outputs
(
self
,
self
,
is_async
:
bool
,
output
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]],
clear_outputs
:
bool
=
True
)
->
None
:
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
clear_outputs: Sometimes existing outputs need to be combined
with outputs of this call. This happens for postprocessor
draining at the final stage (like when sequences are finished)
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
# Organize outputs by [sequence group][step] instead of
if
clear_outputs
:
# [step][sequence group].
self
.
request_outputs
.
clear
()
output_by_sequence_group
=
create_output_by_sequence_group
(
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
if
len
(
self
.
output_queue
)
==
0
:
return
None
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
)
=
self
.
output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
scheduler_outputs
.
scheduled_seq_groups
)
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
if
len
(
outputs
)
>
1
:
outputs_by_sequence_group
=
create_output_by_sequence_group
(
outputs
,
num_seq_groups
=
len
(
seq_group_metadata_list
))
else
:
outputs_by_sequence_group
=
outputs
finished_before
:
List
[
int
]
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
# Update the scheduled sequence groups with the model outputs.
for
scheduled_seq_group
,
outputs
,
seq_group_meta
in
zip
(
scheduled_seq_groups
,
output_by_sequence_group
,
seq_group_metadata_list
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
finished_before
.
append
(
i
)
continue
if
len
(
outputs
)
>
1
:
output
=
outputs_by_sequence_group
[
i
]
else
:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
if
not
is_async
:
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
if
output
is
not
None
and
len
(
output
)
>
0
:
for
o
in
output
:
if
outputs
:
for
o
in
outputs
:
if
(
isinstance
(
o
,
SamplerOutput
)
if
(
isinstance
(
o
,
SamplerOutput
)
and
seq_group
.
metrics
is
not
None
):
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
...
@@ -1239,30 +1288,75 @@ class LLMEngine:
...
@@ -1239,30 +1288,75 @@ class LLMEngine:
else
:
else
:
seq_group
.
metrics
.
model_execute_time
=
(
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
o
.
model_execute_time
)
if
self
.
model_config
.
embedding_mode
:
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
s
)
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
continue
continue
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
s
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
# Free the finished sequence groups.
# Free the finished sequence groups.
for
scheduler
in
self
.
scheduler
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
scheduler
.
free_finished_seq_groups
()
# Create the outputs.
# Create the outputs.
request_outputs
:
List
[
Union
[
RequestOutput
,
for
i
,
_
in
enumerate
(
seq_group_metadata_list
):
EmbeddingRequestOutput
]]
=
[]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
for
scheduled_seq_group
in
scheduled_seq_groups
:
if
i
in
finished_before
:
continue
# Avoids double processing
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
if
self
.
step_return_finished_only
else
True
):
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
self
.
request_outputs
.
append
(
request_output
)
for
seq_group
in
ignored_seq_groups
:
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
self
.
request_outputs
.
append
(
request_output
)
return
request_outputs
if
is_async
:
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
return
None
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done inside output processor, but it is
required if the worker is to perform async forward pass to next step.
"""
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
,
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
continue
seq_group
.
update_num_computed_tokens
(
seq_group_metadata
.
token_chunk_size
)
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)"
)
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
...
@@ -1325,24 +1419,32 @@ class LLMEngine:
...
@@ -1325,24 +1419,32 @@ class LLMEngine:
cached_outputs
=
self
.
cached_scheduler_outputs
[
0
]
cached_outputs
=
self
.
cached_scheduler_outputs
[
0
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
allow_async_output_proc
=
cached_outputs
.
allow_async_output_proc
# Skip the scheduler if there are any remaining steps in the seq groups.
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# This ensures that the scheduler is only called again when the current
# batch has completed.
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
(
seq_group_metadata_list
,
scheduler_outputs
,
0
].
schedule
()
allow_async_output_proc
)
=
self
.
scheduler
[
0
].
schedule
()
if
not
allow_async_output_proc
and
len
(
self
.
output_queue
)
>
0
:
self
.
_process_model_outputs
(
is_async
=
True
)
if
(
self
.
scheduler_config
.
is_multi_step
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
self
.
_cache_scheduler_outputs_for_multi_step
(
0
,
seq_group_metadata_list
,
scheduler_outputs
)
0
,
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
assert
scheduler_outputs
is
not
None
assert
not
(
self
.
scheduler_config
.
is_multi_step
and
\
allow_async_output_proc
)
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
finished_requests_ids
=
self
.
scheduler
[
finished_requests_ids
=
self
.
scheduler
[
0
].
get_and_reset_finished_requests_ids
()
0
].
get_and_reset_finished_requests_ids
()
...
@@ -1366,6 +1468,10 @@ class LLMEngine:
...
@@ -1366,6 +1468,10 @@ class LLMEngine:
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
output_proc_callback_fn
=
\
self
.
_process_model_outputs
output
=
self
.
model_executor
.
execute_model
(
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
...
@@ -1374,6 +1480,9 @@ class LLMEngine:
...
@@ -1374,6 +1480,9 @@ class LLMEngine:
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
0
,
output
)
self
.
_update_cached_scheduler_output
(
0
,
output
)
else
:
else
:
if
len
(
self
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
)
output
=
[]
output
=
[]
# Finish the current step for all the sequence groups.
# Finish the current step for all the sequence groups.
...
@@ -1382,23 +1491,41 @@ class LLMEngine:
...
@@ -1382,23 +1491,41 @@ class LLMEngine:
seq_group
.
finish_step
()
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps
# clear the cache if we have finished all the steps
.
if
self
.
scheduler_config
.
is_multi_step
:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
else
:
# Add results to the output_queue
request_outputs
=
[]
# (for async or non-async postprocessing)
self
.
output_queue
.
append
(
(
output
,
seq_group_metadata_list
,
scheduler_outputs
))
if
output
and
allow_async_output_proc
:
assert
len
(
output
)
==
1
,
(
"Multi step decoding does not work "
"with async output processing."
)
self
.
_advance_to_next_step
(
output
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
is_async
=
False
)
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
)
else
:
self
.
request_outputs
=
[]
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor
if
len
(
self
.
output_queue
)
>
0
:
assert
not
self
.
scheduler_config
.
is_multi_step
self
.
_process_model_outputs
(
is_async
=
True
,
clear_outputs
=
False
)
assert
len
(
self
.
output_queue
)
==
0
# Stop the execute model loop in parallel workers until there are
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# torch.distributed ops which may otherwise timeout, and unblocks
...
@@ -1406,7 +1533,7 @@ class LLMEngine:
...
@@ -1406,7 +1533,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
# queued control plane messages, such as add/remove lora adapters.
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
request_outputs
return
self
.
request_outputs
def
_has_remaining_steps
(
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
...
@@ -1431,12 +1558,14 @@ class LLMEngine:
...
@@ -1431,12 +1558,14 @@ class LLMEngine:
def
_cache_scheduler_outputs_for_multi_step
(
def
_cache_scheduler_outputs_for_multi_step
(
self
,
virtual_engine
:
int
,
self
,
virtual_engine
:
int
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
scheduler_outputs
:
SchedulerOutputs
)
->
None
:
scheduler_outputs
:
SchedulerOutputs
,
self
.
cached_scheduler_outputs
[
allow_async_output_proc
:
bool
)
->
None
:
virtual_engine
].
seq_group_metadata_list
=
seq_group_metadata_list
co
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
self
.
cached_scheduler_outputs
[
virtual_engine
].
scheduler_outputs
=
\
scheduler_outputs
co
.
seq_group_metadata_list
=
seq_group_metadata_list
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
=
None
co
.
scheduler_outputs
=
scheduler_outputs
co
.
allow_async_output_proc
=
allow_async_output_proc
co
.
last_output
=
None
def
_update_cached_scheduler_output
(
def
_update_cached_scheduler_output
(
self
,
virtual_engine
:
int
,
self
,
virtual_engine
:
int
,
...
@@ -1472,20 +1601,21 @@ class LLMEngine:
...
@@ -1472,20 +1601,21 @@ class LLMEngine:
raise
KeyError
(
f
"Logger with name
{
logger_name
}
does not exist."
)
raise
KeyError
(
f
"Logger with name
{
logger_name
}
does not exist."
)
del
self
.
stat_loggers
[
logger_name
]
del
self
.
stat_loggers
[
logger_name
]
def
do_log_stats
(
def
do_log_stats
(
self
,
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
finished_before
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
"""Forced log when no requests active."""
"""Forced log when no requests active."""
if
self
.
log_stats
:
if
self
.
log_stats
:
stats
=
self
.
_get_stats
(
scheduler_outputs
,
model_output
)
stats
=
self
.
_get_stats
(
scheduler_outputs
,
model_output
,
finished_before
)
for
logger
in
self
.
stat_loggers
.
values
():
for
logger
in
self
.
stat_loggers
.
values
():
logger
.
log
(
stats
)
logger
.
log
(
stats
)
def
_get_stats
(
def
_get_stats
(
self
,
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
],
scheduler_outputs
:
Optional
[
SchedulerOutputs
],
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
Stats
:
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
finished_before
:
Optional
[
List
[
int
]]
=
None
)
->
Stats
:
"""Get Stats to be Logged to Prometheus.
"""Get Stats to be Logged to Prometheus.
Args:
Args:
...
@@ -1550,6 +1680,10 @@ class LLMEngine:
...
@@ -1550,6 +1680,10 @@ class LLMEngine:
# NOTE: This loop assumes prefill seq_groups are before
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
# decode seq_groups in scheduled_seq_groups.
if
scheduler_outputs
is
not
None
:
if
scheduler_outputs
is
not
None
:
# For async postprocessor, already finished sequences need to be
# not counted (to avoid double counting)
actual_num_batched_tokens
=
scheduler_outputs
.
num_batched_tokens
# type: ignore
num_generation_tokens_from_prefill_groups
=
0.
num_generation_tokens_from_prefill_groups
=
0.
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# the len of scheduler_outputs.scheduled_seq_groups is !=
# the len of scheduler_outputs.scheduled_seq_groups is !=
...
@@ -1558,6 +1692,11 @@ class LLMEngine:
...
@@ -1558,6 +1692,11 @@ class LLMEngine:
for
idx
,
scheduled_seq_group
in
enumerate
(
for
idx
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
scheduler_outputs
.
scheduled_seq_groups
):
# Skip double logging when using async output proc
if
finished_before
and
idx
in
finished_before
:
actual_num_batched_tokens
-=
1
continue
group_was_prefill
=
idx
<
scheduler_outputs
.
num_prefill_groups
group_was_prefill
=
idx
<
scheduler_outputs
.
num_prefill_groups
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
@@ -1592,7 +1731,6 @@ class LLMEngine:
...
@@ -1592,7 +1731,6 @@ class LLMEngine:
# Latency timings
# Latency timings
time_e2e_requests
.
append
(
now
-
time_e2e_requests
.
append
(
now
-
seq_group
.
metrics
.
arrival_time
)
seq_group
.
metrics
.
arrival_time
)
# Metadata
# Metadata
num_prompt_tokens_requests
.
append
(
num_prompt_tokens_requests
.
append
(
len
(
seq_group
.
prompt_token_ids
))
len
(
seq_group
.
prompt_token_ids
))
...
@@ -1616,7 +1754,7 @@ class LLMEngine:
...
@@ -1616,7 +1754,7 @@ class LLMEngine:
# + num_generation_tokens_from_prefill_groups (since we generate
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter
=
(
num_generation_tokens_iter
=
(
scheduler_outputs
.
num_batched_tokens
-
num_prompt_tokens_iter
+
actual_
num_batched_tokens
-
num_prompt_tokens_iter
+
num_generation_tokens_from_prefill_groups
)
num_generation_tokens_from_prefill_groups
)
# Spec decode, if enabled, emits specialized metrics from the worker in
# Spec decode, if enabled, emits specialized metrics from the worker in
...
...
vllm/engine/output_processor/interfaces.py
View file @
2eedede8
...
@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
...
@@ -40,13 +40,9 @@ class SequenceGroupOutputProcessor(ABC):
# Importing here to avoid cycle.
# Importing here to avoid cycle.
from
vllm.engine.output_processor.single_step
import
(
from
vllm.engine.output_processor.single_step
import
(
SingleStepOutputProcessor
)
SingleStepOutputProcessor
)
return
SingleStepOutputProcessor
(
return
SingleStepOutputProcessor
(
scheduler_config
,
detokenizer
,
scheduler_config
,
scheduler
,
seq_counter
,
detokenizer
,
stop_checker
)
scheduler
,
seq_counter
,
stop_checker
,
)
else
:
else
:
# Importing here to avoid cycle.
# Importing here to avoid cycle.
from
vllm.engine.output_processor.multi_step
import
(
from
vllm.engine.output_processor.multi_step
import
(
...
@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
...
@@ -61,7 +57,8 @@ class SequenceGroupOutputProcessor(ABC):
@
abstractmethod
@
abstractmethod
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
)
->
None
:
"""Process new token ids for the sequence group. Handles logic such as
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
scheduler.
...
...
vllm/engine/output_processor/multi_step.py
View file @
2eedede8
...
@@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -57,17 +57,28 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"Prompt logprob is not supported by multi step workers. "
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers)."
)
"(e.g., speculative decode uses multi step workers)."
)
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
=
False
)
->
None
:
"""Append new tokens in the outputs to sequences in the sequence group.
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
This only supports sequence groups of size 1. It supports greater than
one new token per sequence.
one new token per sequence.
This applies logic like stop condition checking and detokenization,
This applies logic like stop condition checking and detokenization.
including freeing finished sequences. It also handles cases where there
It also handles cases where there are tokens emitted after
are tokens emitted after the EOS token.
the EOS token.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
"""
# TODO: Add support for async if necessary
assert
not
is_async
seqs
=
sequence_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
seqs
=
sequence_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
assert
seqs
,
"expected running sequences"
assert
seqs
,
"expected running sequences"
...
@@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -138,7 +149,3 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
)
)
if
seq
.
is_finished
():
if
seq
.
is_finished
():
break
break
if
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
vllm/engine/output_processor/single_step.py
View file @
2eedede8
...
@@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
that is currently difficult to schedule multiple steps ahead of time.
that is currently difficult to schedule multiple steps ahead of time.
"""
"""
def
__init__
(
def
__init__
(
self
,
scheduler_config
:
SchedulerConfig
,
self
,
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
scheduler_config
:
SchedulerConfig
,
seq_counter
:
Counter
,
stop_checker
:
StopChecker
):
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
seq_counter
:
Counter
,
stop_checker
:
StopChecker
,
):
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
detokenizer
=
detokenizer
self
.
detokenizer
=
detokenizer
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
...
@@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -44,16 +39,24 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self
.
stop_checker
=
stop_checker
self
.
stop_checker
=
stop_checker
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
)
->
None
:
"""Append all new tokens to sequences in the sequence group. Fork any
"""Append all new tokens to sequences in the sequence group. Fork any
surviving beam candidates; free any unsurviving ones.
surviving beam candidates; free any unsurviving ones.
Invokes detokenizer to detokenize new tokens, and also marks sequences
Invokes detokenizer to detokenize new tokens, and also marks sequences
as finished if they meet stop conditions.
as finished if they meet stop conditions.
is_async - Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
"""
"""
assert
(
len
(
outputs
)
==
1
assert
(
len
(
outputs
)
==
1
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
])
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
],
is_async
)
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
...
@@ -80,13 +83,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -80,13 +83,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
outputs
:
SequenceGroupOutput
,
is_async
:
bool
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
n
==
1
and
not
sampling_params
.
use_beam_search
:
if
sampling_params
.
n
==
1
and
not
sampling_params
.
use_beam_search
:
# only have one output sample
# only have one output sample
sample
=
outputs
.
samples
[
0
]
sample
=
outputs
.
samples
[
0
]
# only have one sequence
# only have one sequence
seq
=
seq_group
.
seqs
[
0
]
seq
=
seq_group
.
seqs
[
0
]
if
not
is_async
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
if
sampling_params
.
detokenize
and
self
.
detokenizer
:
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
new_char_count
=
self
.
detokenizer
.
decode_sequence_inplace
(
...
@@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -104,6 +109,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler
.
free_seq
(
seq
)
scheduler
.
free_seq
(
seq
)
return
return
# TODO: Add support for async for beam search
assert
not
is_async
# Process samples
# Process samples
samples
=
outputs
.
samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
...
...
vllm/entrypoints/llm.py
View file @
2eedede8
...
@@ -129,6 +129,7 @@ class LLM:
...
@@ -129,6 +129,7 @@ class LLM:
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
int
=
8192
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
'''
'''
...
@@ -170,6 +171,7 @@ class LLM:
...
@@ -170,6 +171,7 @@ class LLM:
max_context_len_to_capture
=
max_context_len_to_capture
,
max_context_len_to_capture
=
max_context_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
max_seq_len_to_capture
=
max_seq_len_to_capture
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_custom_all_reduce
=
disable_custom_all_reduce
,
disable_async_output_proc
=
disable_async_output_proc
,
**
kwargs
,
**
kwargs
,
)
)
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
...
@@ -603,7 +605,6 @@ class LLM:
...
@@ -603,7 +605,6 @@ class LLM:
inputs
=
[
inputs
]
inputs
=
[
inputs
]
num_requests
=
len
(
inputs
)
num_requests
=
len
(
inputs
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
"must be the same."
)
...
@@ -678,6 +679,10 @@ class LLM:
...
@@ -678,6 +679,10 @@ class LLM:
postfix
=
(
f
"est. speed input:
{
0
:.
2
f
}
toks/s, "
postfix
=
(
f
"est. speed input:
{
0
:.
2
f
}
toks/s, "
f
"output:
{
0
:.
2
f
}
toks/s"
),
f
"output:
{
0
:.
2
f
}
toks/s"
),
)
)
# In the loop below, only finished outputs are used
self
.
llm_engine
.
step_return_finished_only
=
True
# Run the engine.
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
total_in_toks
=
0
total_in_toks
=
0
...
@@ -700,6 +705,10 @@ class LLM:
...
@@ -700,6 +705,10 @@ class LLM:
f
"est. speed input:
{
in_spd
:.
2
f
}
toks/s, "
f
"est. speed input:
{
in_spd
:.
2
f
}
toks/s, "
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
pbar
.
update
(
1
)
pbar
.
update
(
1
)
# Restore original behavior
self
.
llm_engine
.
step_return_finished_only
=
False
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
# Sort the outputs by request ID.
# Sort the outputs by request ID.
...
...
vllm/executor/distributed_gpu_executor.py
View file @
2eedede8
...
@@ -65,7 +65,8 @@ class DistributedGPUExecutor(GPUExecutor):
...
@@ -65,7 +65,8 @@ class DistributedGPUExecutor(GPUExecutor):
def
execute_model
(
def
execute_model
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
if
self
.
parallel_worker_tasks
is
None
:
if
self
.
parallel_worker_tasks
is
None
:
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
"start_worker_execution_loop"
,
"start_worker_execution_loop"
,
...
@@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
...
@@ -188,7 +189,7 @@ class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
@
abstractmethod
@
abstractmethod
async
def
_driver_execute_model_async
(
async
def
_driver_execute_model_async
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
"""Execute the model asynchronously in the driver worker.
"""Execute the model asynchronously in the driver worker.
...
...
vllm/executor/gpu_executor.py
View file @
2eedede8
...
@@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
...
@@ -176,5 +176,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
)(
execute_model_req
=
execute_model_req
)
return
output
return
output
vllm/sequence.py
View file @
2eedede8
...
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
...
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Tuple
,
Union
,
cast
)
Optional
,
Set
,
Tuple
,
Union
,
cast
)
import
msgspec
import
msgspec
import
torch
import
torch
...
@@ -474,11 +474,8 @@ class Sequence:
...
@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_state_for_recompute
()
self
.
data
.
reset_state_for_recompute
()
def
append_token_id
(
def
append_token_id
(
self
,
token_id
:
int
,
logprobs
:
Dict
[
int
,
self
,
Logprob
])
->
None
:
token_id
:
int
,
logprobs
:
Dict
[
int
,
Logprob
],
)
->
None
:
assert
token_id
in
logprobs
assert
token_id
in
logprobs
self
.
output_logprobs
.
append
(
logprobs
)
self
.
output_logprobs
.
append
(
logprobs
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
...
@@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
...
@@ -1293,6 +1290,8 @@ class ExecuteModelRequest(
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async postprocessor
output_proc_callback_fn
:
Optional
[
Callable
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
...
@@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
...
@@ -1338,4 +1337,5 @@ class ExecuteModelRequest(
num_steps
=
self
.
num_steps
,
num_steps
=
self
.
num_steps
,
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
)
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
output_proc_callback_fn
=
self
.
output_proc_callback_fn
)
vllm/worker/model_runner.py
View file @
2eedede8
...
@@ -6,8 +6,8 @@ import time
...
@@ -6,8 +6,8 @@ import time
import
warnings
import
warnings
import
weakref
import
weakref
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
TypeVar
,
Union
)
Tuple
,
Type
,
TypeVar
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -90,6 +90,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
request_ids_to_seq_ids
:
Optional
[
Dict
[
str
,
List
[
int
]]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
output_proc_callback_fn
:
Optional
[
Callable
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1327,7 +1328,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
)
->
ModelInputForGPUWithSamplingMetadata
:
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
metadata for the sampling step.
...
@@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1451,6 +1452,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
not
self
.
is_driver_worker
:
if
not
self
.
is_driver_worker
:
return
[]
return
[]
if
model_input
.
output_proc_callback_fn
is
not
None
:
model_input
.
output_proc_callback_fn
(
is_async
=
True
)
# Sample the next token.
# Sample the next token.
output
:
SamplerOutput
=
self
.
model
.
sample
(
output
:
SamplerOutput
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment