Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a2aef3e
Unverified
Commit
1a2aef3e
authored
Sep 23, 2024
by
Alexander Matveev
Committed by
GitHub
Sep 23, 2024
Browse files
Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335)
parent
5f7bb584
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
142 additions
and
42 deletions
+142
-42
tests/entrypoints/openai/test_accuracy.py
tests/entrypoints/openai/test_accuracy.py
+5
-1
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+28
-9
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+8
-1
vllm/outputs.py
vllm/outputs.py
+74
-22
vllm/sequence.py
vllm/sequence.py
+19
-9
No files found.
tests/entrypoints/openai/test_accuracy.py
View file @
1a2aef3e
...
...
@@ -19,7 +19,11 @@ FILTER = "exact_match,strict-match"
RTOL
=
0.03
EXPECTED_VALUE
=
0.58
DEFAULT_ARGS
=
[
"--max-model-len"
,
"4096"
,
"--disable-log-requests"
]
MORE_ARGS_LIST
=
[[
"--enable-chunked-prefill"
],
[
"--num-scheduler-steps"
,
"8"
]]
MORE_ARGS_LIST
=
[
[
"--enable-chunked-prefill"
],
# Chunked
[
"--num-scheduler-steps"
,
"8"
],
# MS
[
"--num-scheduler-steps"
,
"8"
,
"--multi-step-stream-outputs"
]
# MS+Stream
]
@
pytest
.
mark
.
parametrize
(
"more_args"
,
MORE_ARGS_LIST
)
...
...
vllm/config.py
View file @
1a2aef3e
...
...
@@ -960,6 +960,7 @@ class SchedulerConfig:
is_multimodal_model
:
bool
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
,
multi_step_stream_outputs
:
bool
=
False
,
send_delta_data
:
bool
=
False
)
->
None
:
if
max_num_batched_tokens
is
None
:
if
enable_chunked_prefill
:
...
...
@@ -1000,6 +1001,7 @@ class SchedulerConfig:
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
self
.
send_delta_data
=
send_delta_data
self
.
_verify_args
()
...
...
vllm/engine/arg_utils.py
View file @
1a2aef3e
...
...
@@ -145,6 +145,7 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
num_scheduler_steps
:
int
=
1
multi_step_stream_outputs
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
num_gpu_blocks_override
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
...
...
@@ -595,6 +596,10 @@ class EngineArgs:
help
=
(
'Maximum number of forward steps per '
'scheduler call.'
))
parser
.
add_argument
(
'--multi-step-stream-outputs'
,
action
=
'store_true'
,
help
=
'If True, then multi-step will stream outputs for every step'
)
parser
.
add_argument
(
'--scheduler-delay-factor'
,
type
=
float
,
...
...
@@ -999,6 +1004,7 @@ class EngineArgs:
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
multi_step_stream_outputs
=
self
.
multi_step_stream_outputs
,
send_delta_data
=
(
envs
.
VLLM_USE_RAY_SPMD_WORKER
and
parallel_config
.
use_ray
),
)
...
...
vllm/engine/llm_engine.py
View file @
1a2aef3e
...
...
@@ -95,7 +95,7 @@ class OutputData(NamedTuple):
class
SchedulerContext
:
def
__init__
(
self
):
def
__init__
(
self
,
multi_step_stream_outputs
:
bool
=
False
):
self
.
output_queue
:
Deque
[
OutputData
]
=
deque
()
self
.
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
...
...
@@ -103,6 +103,8 @@ class SchedulerContext:
List
[
SequenceGroupMetadata
]]
=
None
self
.
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
self
.
multi_step_stream_outputs
:
bool
=
multi_step_stream_outputs
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
...
...
@@ -219,6 +221,7 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
use_cached_outputs
:
bool
=
False
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
...
...
@@ -234,8 +237,9 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s)"
,
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -266,8 +270,10 @@ class LLMEngine:
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
multi_step_stream_outputs
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
use_cached_outputs
,
model_config
.
mm_processor_kwargs
,
)
# TODO(woosuk): Print more configs in debug mode.
...
...
@@ -287,6 +293,7 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
...
@@ -379,7 +386,8 @@ class LLMEngine:
]
self
.
scheduler_contexts
=
[
SchedulerContext
()
SchedulerContext
(
multi_step_stream_outputs
=
self
.
scheduler_config
.
multi_step_stream_outputs
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
...
...
@@ -998,7 +1006,8 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
@@ -1019,8 +1028,8 @@ class LLMEngine:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step, do not create outputs each iteration
if
not
is_last_step
:
# For multi-step without streaming, don't create outputs each iteration
if
not
is_last_step
and
not
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
...
...
@@ -1037,17 +1046,27 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
# For multi-step with streaming, create outputs each iteration
if
not
is_last_step
and
ctx
.
multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
self
.
process_request_outputs_callback
is
not
None
:
self
.
process_request_outputs_callback
(
ctx
.
request_outputs
)
ctx
.
request_outputs
.
clear
()
return
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
vllm/engine/multiprocessing/engine.py
View file @
1a2aef3e
...
...
@@ -66,7 +66,14 @@ class MQLLMEngine:
*
args
,
log_requests
:
bool
=
True
,
**
kwargs
)
->
None
:
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the Python object to be reused again.
use_cached_outputs
=
True
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
...
...
vllm/outputs.py
View file @
1a2aef3e
...
...
@@ -114,17 +114,28 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
@
classmethod
def
from_seq_group
(
cls
,
se
q_group
:
SequenceGroup
)
->
Optional
[
"RequestOutput"
]:
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
,
u
se
_cache
:
bool
)
->
Optional
[
"RequestOutput"
]:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
is
None
:
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
finished
=
seq_group
.
is_finished
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
not
finished
):
return
None
# Init cache (if needed)
if
use_cache
and
seq_group
.
cached_request_output
is
None
:
seq_group
.
cached_request_output
=
RequestOutput
(
# type: ignore
request_id
=
""
,
prompt
=
None
,
prompt_token_ids
=
[],
prompt_logprobs
=
None
,
outputs
=
[],
finished
=
False
)
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
...
...
@@ -149,29 +160,66 @@ class RequestOutput:
outputs
=
[]
include_prompt
=
True
for
seq
in
top_n_seqs
:
for
i
,
seq
in
enumerate
(
top_n_seqs
)
:
output_text
=
seq
.
get_output_text_to_return
(
text_buffer_length
,
delta
)
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
num_output_tokens
=
1
if
isinstance
(
output_token_ids
,
int
)
else
len
(
output_token_ids
)
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
if
delta
:
# Slice logprobs delta if applicable
if
output_logprobs
:
output_logprobs
=
output_logprobs
[
-
len
(
output_token
_ids
)
:]
output_logprobs
=
output_logprobs
[
-
num_
output_token
s
:]
# Don't include prompt if this is after the first output
# containing decode token ids
if
include_prompt
and
seq
.
get_output_len
()
>
len
(
output_token_ids
):
if
include_prompt
and
seq
.
get_output_len
()
>
num_output_tokens
:
include_prompt
=
False
outputs
.
append
(
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
output_token_ids
,
if
use_cache
:
# Get cached output object
cached_outputs
=
seq_group
.
cached_request_output
.
outputs
# type: ignore
if
i
>=
len
(
cached_outputs
):
cached_outputs
.
append
(
CompletionOutput
(
index
=
i
,
text
=
""
,
token_ids
=
[],
cumulative_logprob
=
None
,
logprobs
=
None
,
finish_reason
=
None
,
stop_reason
=
None
))
output
=
cached_outputs
[
i
]
# Init cached output object
assert
output
.
index
==
i
output
.
text
=
output_text
if
isinstance
(
output_token_ids
,
int
):
output
.
token_ids
.
clear
()
output
.
token_ids
.
append
(
output_token_ids
)
else
:
output
.
token_ids
=
output_token_ids
output
.
cumulative_logprob
=
seq
.
get_cumulative_logprob
()
\
if
include_logprobs
else
None
output
.
logprobs
=
output_logprobs
output
.
finish_reason
=
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
output
.
stop_reason
=
seq
.
stop_reason
else
:
output
=
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
[
output_token_ids
]
if
isinstance
(
output_token_ids
,
int
)
else
output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_logprobs
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
))
seq
.
stop_reason
)
outputs
.
append
(
output
)
# Every sequence in the sequence group should have the same prompt.
if
include_prompt
:
...
...
@@ -188,16 +236,20 @@ class RequestOutput:
prompt_logprobs
=
None
finished_time
=
time
.
time
()
if
finished
else
None
seq_group
.
set_finished_time
(
finished_time
)
return
cls
(
seq_group
.
request_id
,
prompt
,
prompt_token_ids
,
prompt_logprobs
,
outputs
,
finished
,
seq_group
.
metrics
,
lora_request
=
seq_group
.
lora_request
,
encoder_prompt
=
encoder_prompt
,
encoder_prompt_token_ids
=
encoder_prompt_token_ids
)
init_args
=
(
seq_group
.
request_id
,
prompt
,
prompt_token_ids
,
prompt_logprobs
,
outputs
,
finished
,
seq_group
.
metrics
,
seq_group
.
lora_request
,
encoder_prompt
,
encoder_prompt_token_ids
)
if
use_cache
:
request_output
=
seq_group
.
cached_request_output
request_output
.
__init__
(
*
init_args
)
# type: ignore
else
:
request_output
=
cls
(
*
init_args
)
return
request_output
def
__repr__
(
self
)
->
str
:
return
(
f
"RequestOutput(request_id=
{
self
.
request_id
}
, "
...
...
@@ -261,10 +313,10 @@ class EmbeddingRequestOutput:
class
RequestOutputFactory
:
@
staticmethod
def
create
(
seq_group
):
def
create
(
seq_group
:
SequenceGroup
,
use_cache
:
bool
=
False
):
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
EmbeddingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
)
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
)
vllm/sequence.py
View file @
1a2aef3e
...
...
@@ -436,7 +436,7 @@ class Sequence:
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
# These are used to keep track of delta outputs
self
.
_last_token_ids_offset
:
int
=
0
self
.
_last_
output_
token_ids_offset
:
int
=
0
self
.
_last_output_text_offset
:
int
=
0
# Used for incremental detokenization
...
...
@@ -499,18 +499,26 @@ class Sequence:
return
self
.
output_text
[
last_offset
:
length
]
return
""
def
get_output_token_ids_to_return
(
self
,
delta
:
bool
)
->
GenericSequence
[
int
]:
def
get_output_token_ids_to_return
(
self
,
delta
:
bool
)
->
Union
[
GenericSequence
[
int
],
int
]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
get_output_token_ids
()
length
=
self
.
get_output_len
()
last_offset
=
self
.
_last_token_ids_offset
if
last_offset
<
length
:
self
.
_last_token_ids_offset
=
length
return
self
.
data
.
_output_token_ids
[
last_offset
:]
return
()
output_len
=
self
.
get_output_len
()
# Get the number of new tokens
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
self
.
_last_output_token_ids_offset
=
output_len
# Return new tokens
if
num_new_tokens
==
1
:
# Optimization for single decode token case
# (which is what we have most of the time)
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:]
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
...
...
@@ -671,6 +679,8 @@ class SequenceGroup:
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
self
.
cached_request_output
=
None
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
# All sequences in the group should have the same prompt.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment