Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
551ce010
Unverified
Commit
551ce010
authored
Sep 12, 2024
by
Nick Hill
Committed by
GitHub
Sep 12, 2024
Browse files
[Core] Add engine option to return only deltas or final output (#7381)
parent
a6c0f365
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
371 additions
and
137 deletions
+371
-137
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+147
-14
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+13
-11
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+8
-15
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+6
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+72
-53
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+19
-13
vllm/outputs.py
vllm/outputs.py
+55
-24
vllm/sampling_params.py
vllm/sampling_params.py
+16
-1
vllm/sequence.py
vllm/sequence.py
+34
-5
No files found.
.buildkite/test-pipeline.yaml
View file @
551ce010
...
...
@@ -50,6 +50,7 @@ steps:
-
tests/worker
commands
:
-
pytest -v -s async_engine
# Async Engine
-
NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
-
pytest -v -s test_inputs.py
-
pytest -v -s multimodal
-
pytest -v -s test_utils.py
# Utils
...
...
tests/async_engine/test_async_llm_engine.py
View file @
551ce010
import
asyncio
import
os
import
uuid
from
asyncio
import
CancelledError
from
copy
import
copy
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
List
,
Optional
import
pytest
import
pytest_asyncio
...
...
@@ -11,6 +14,7 @@ from vllm import SamplingParams
from
vllm.config
import
ParallelConfig
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.outputs
import
RequestOutput
as
RealRequestOutput
from
vllm.sampling_params
import
RequestOutputKind
from
..conftest
import
cleanup
from
..utils
import
wait_for_gpu_memory_to_clear
...
...
@@ -122,8 +126,17 @@ def start_engine():
timeout_s
=
60
,
)
num_scheduler_steps
=
int
(
os
.
getenv
(
"NUM_SCHEDULER_STEPS"
,
"1"
))
print
(
f
"Starting engine with num_scheduler_steps=
{
num_scheduler_steps
}
"
)
return
AsyncLLMEngine
.
from_engine_args
(
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
))
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
num_scheduler_steps
=
num_scheduler_steps
))
def
uid
()
->
str
:
return
str
(
uuid
.
uuid4
())
@
pytest_asyncio
.
fixture
(
scope
=
"module"
)
...
...
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_asyncio_run
(
async_engine
):
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
num_scheduler_steps
=
scheduler_config
.
num_scheduler_steps
async
def
run
(
prompt
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
,
min_tokens
=
32
,
)
output_count
=
0
final_output
=
None
async
for
output
in
async_engine
.
generate
(
prompt
,
sampling_params
,
request_id
=
prompt
):
request_id
=
uid
()):
output_count
+=
1
final_output
=
output
return
final_output
return
final_output
,
output_count
results
=
await
asyncio
.
gather
(
run
(
"test0"
),
run
(
"test
1
"
),
run
(
"test
0
"
),
)
assert
len
(
results
)
==
2
first
,
second
=
results
# remove nondeterministic fields for comparison
first
[
0
].
metrics
=
None
second
[
0
].
metrics
=
None
first
[
0
].
request_id
=
None
second
[
0
].
request_id
=
None
assert
str
(
first
)
==
str
(
second
)
output_count
=
results
[
0
][
1
]
if
num_scheduler_steps
==
1
:
assert
output_count
==
32
else
:
assert
1
<
output_count
<
32
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_output_kinds
(
async_engine
):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
num_scheduler_steps
=
scheduler_config
.
num_scheduler_steps
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
,
min_tokens
=
32
,
)
async
def
run
(
prompt
:
str
,
kind
:
RequestOutputKind
):
params
=
copy
(
sampling_params
)
params
.
output_kind
=
kind
output_count
=
0
final_output
=
None
async
for
output
in
async_engine
.
generate
(
prompt
,
params
,
request_id
=
uid
()):
output_count
+=
1
final_output
=
output
assert
final_output
is
not
None
return
(
final_output
.
prompt_token_ids
,
final_output
.
outputs
[
0
].
token_ids
,
final_output
.
outputs
[
0
].
text
,
output_count
)
async
def
run_deltas
(
prompt
:
str
):
params
=
copy
(
sampling_params
)
params
.
output_kind
=
RequestOutputKind
.
DELTA
prompt_tokens
=
None
output_tokens
:
List
[
int
]
=
[]
output_text
=
""
output_count
=
0
async
for
output
in
async_engine
.
generate
(
prompt
,
params
,
request_id
=
uid
()):
token_ids
=
output
.
outputs
[
0
].
token_ids
text
=
output
.
outputs
[
0
].
text
# Ensure we get prompt ids iff we haven't yet received output tokens
if
output_tokens
:
assert
1
<=
len
(
token_ids
)
<=
num_scheduler_steps
assert
text
assert
not
output
.
prompt_token_ids
else
:
assert
output
.
prompt_token_ids
prompt_tokens
=
output
.
prompt_token_ids
output_tokens
.
extend
(
token_ids
)
output_text
+=
text
output_count
+=
1
return
prompt_tokens
,
output_tokens
,
output_text
,
output_count
results
=
await
asyncio
.
gather
(
run
(
"common input prompt"
,
RequestOutputKind
.
CUMULATIVE
),
run
(
"common input prompt"
,
RequestOutputKind
.
FINAL_ONLY
),
run_deltas
(
"common input prompt"
))
# Make sure outputs are the same
prompt_set
=
set
(
tuple
(
prompt_ids
)
for
prompt_ids
,
_
,
_
,
_
in
results
)
assert
len
(
prompt_set
)
==
1
text_set
=
set
(
text
for
_
,
_
,
text
,
_
in
results
)
assert
len
(
text_set
)
==
1
tokens_set
=
set
(
tuple
(
ids
)
for
_
,
ids
,
_
,
_
in
results
)
assert
len
(
tokens_set
)
==
1
cumulative
,
final
,
deltas
=
results
# output message counts
assert
cumulative
[
3
]
==
deltas
[
3
]
if
num_scheduler_steps
==
1
:
assert
cumulative
[
3
]
==
32
else
:
assert
1
<
cumulative
[
3
]
<
32
assert
final
[
3
]
==
1
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_cancellation
(
async_engine
):
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
num_scheduler_steps
=
scheduler_config
.
num_scheduler_steps
sampling_params
=
SamplingParams
(
temperature
=
0
,
min_tokens
=
1
0
,
max_tokens
=
1
0
,
min_tokens
=
1
3
,
max_tokens
=
1
3
,
)
stop_at
=
5
if
num_scheduler_steps
==
1
else
1
request_id
=
uid
()
i
=
0
with
pytest
.
raises
(
CancelledError
):
async
for
output
in
async_engine
.
generate
(
"test2"
,
sampling_params
,
request_id
=
"test2"
):
request_id
=
request_id
):
assert
not
output
.
finished
i
+=
1
if
i
==
5
:
await
async_engine
.
abort
(
"test2"
)
if
i
==
stop_at
:
await
async_engine
.
abort
(
request_id
)
assert
i
==
5
assert
i
==
stop_at
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_delayed_generator
(
async_engine
):
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
if
scheduler_config
.
num_scheduler_steps
!=
1
:
pytest
.
skip
(
"no need to test this one with multistep"
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
min_tokens
=
10
,
max_tokens
=
10
,
)
stream
=
async_engine
.
generate
(
"test3"
,
sampling_params
,
request_id
=
"test3"
)
stream
=
async_engine
.
generate
(
"test3"
,
sampling_params
,
request_id
=
uid
())
i
=
0
final_output
:
Optional
[
RealRequestOutput
]
=
None
async
for
output
in
stream
:
...
...
vllm/engine/llm_engine.py
View file @
551ce010
...
...
@@ -39,7 +39,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
...
...
@@ -225,9 +225,6 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only
:
bool
=
False
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
...
...
@@ -295,7 +292,6 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
self
.
log_stats
=
log_stats
self
.
step_return_finished_only
=
step_return_finished_only
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
...
@@ -1273,7 +1269,7 @@ class LLMEngine:
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
"""
now
=
time
.
time
()
...
...
@@ -1378,7 +1374,8 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
...
...
@@ -1415,14 +1412,19 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
if
self
.
step_return_finished_only
else
True
):
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
# Immediately process request outputs here (if callback is given)
if
(
ctx
.
request_outputs
...
...
vllm/entrypoints/llm.py
View file @
551ce010
...
...
@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
...
@@ -642,14 +642,12 @@ class LLM:
raise
ValueError
(
"The lengths of prompts and lora_request "
"must be the same."
)
if
isinstance
(
params
,
list
):
params
=
[
self
.
_add_guided_processor
(
param
,
guided_options
)
if
isinstance
(
param
,
SamplingParams
)
else
param
for
param
in
params
]
elif
isinstance
(
params
,
SamplingParams
):
params
=
self
.
_add_guided_processor
(
params
,
guided_options
)
for
sp
in
params
if
isinstance
(
params
,
list
)
else
(
params
,
):
if
isinstance
(
sp
,
SamplingParams
):
self
.
_add_guided_processor
(
sp
,
guided_options
)
# We only care about the final output
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inputs
):
...
...
@@ -709,9 +707,6 @@ class LLM:
f
"output:
{
0
:.
2
f
}
toks/s"
),
)
# In the loop below, only finished outputs are used
self
.
llm_engine
.
step_return_finished_only
=
True
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
total_in_toks
=
0
...
...
@@ -724,6 +719,7 @@ class LLM:
if
use_tqdm
:
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
assert
output
.
prompt_token_ids
is
not
None
total_in_toks
+=
len
(
output
.
prompt_token_ids
)
in_spd
=
total_in_toks
/
pbar
.
format_dict
[
"elapsed"
]
total_out_toks
+=
sum
(
...
...
@@ -735,9 +731,6 @@ class LLM:
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
pbar
.
update
(
1
)
# Restore original behavior
self
.
llm_engine
.
step_return_finished_only
=
False
if
use_tqdm
:
pbar
.
close
()
# Sort the outputs by request ID.
...
...
vllm/entrypoints/openai/protocol.py
View file @
551ce010
...
...
@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sampling_params
import
(
LogitsProcessor
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
...
...
@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
,
)
@
model_validator
(
mode
=
"before"
)
...
...
@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
,
)
@
model_validator
(
mode
=
"before"
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
551ce010
...
...
@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def
get_chat_request_role
(
self
,
request
:
ChatCompletionRequest
)
->
str
:
if
request
.
add_generation_prompt
:
return
self
.
response_role
else
:
return
request
.
messages
[
-
1
][
"role"
]
return
request
.
messages
[
-
1
][
"role"
]
async
def
chat_completion_stream_generator
(
self
,
...
...
@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
# Send response for each token for each request.n (index)
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_texts
=
[
""
]
*
num_choices
previous_num_tokens
=
[
0
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
num_prompt_tokens
=
0
tool_parser
:
Optional
[
ToolParser
]
=
self
.
tool_parser
(
tokenizer
)
if
self
.
tool_parser
else
None
if
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
tool_choice_function_name
=
request
.
tool_choice
.
function
.
name
else
:
tool_choice_function_name
=
None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto
=
(
not
tool_choice_function_name
and
self
.
_should_stream_with_auto_tool_parsing
(
request
))
all_previous_token_ids
:
Optional
[
List
[
List
[
int
]]]
if
tool_choice_auto
:
# These are only required in "auto" tool choice case
previous_texts
=
[
""
]
*
num_choices
all_previous_token_ids
=
[[]]
*
num_choices
else
:
previous_texts
,
all_previous_token_ids
=
None
,
None
try
:
async
for
res
in
result_generator
:
if
res
.
prompt_token_ids
is
not
None
:
num_prompt_tokens
=
len
(
res
.
prompt_token_ids
)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
...
...
@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and
request
.
stream_options
.
include_usage
):
# if continuous usage stats are requested, add it
if
request
.
stream_options
.
continuous_usage_stats
:
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
usage
=
UsageInfo
(
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
num_
prompt_tokens
)
chunk
.
usage
=
usage
# otherwise don't
else
:
...
...
@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
total_tokens
=
num_
prompt_tokens
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
...
...
@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration
=
False
for
output
in
res
.
outputs
:
i
=
output
.
index
if
finish_reason_sent
[
i
]:
continue
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
out_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
and
request
.
top_logprobs
is
not
None
:
assert
out
_
logprobs
is
not
None
,
(
assert
out
put
.
logprobs
is
not
None
,
(
"Did not output logprobs"
)
logprobs
=
self
.
_create_chat_logprobs
(
token_ids
=
delta_
token_ids
,
top_logprobs
=
out
_
logprobs
,
token_ids
=
output
.
token_ids
,
top_logprobs
=
out
put
.
logprobs
,
tokenizer
=
tokenizer
,
num_output_top_logprobs
=
request
.
top_logprobs
,
)
else
:
logprobs
=
None
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_message
:
Optional
[
DeltaMessage
]
=
None
delta_text
=
output
.
text
delta_message
:
Optional
[
DeltaMessage
]
# handle streaming deltas for tools with named tool_choice
if
(
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
):
if
tool_choice_function_name
:
delta_message
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
function
=
DeltaFunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
name
=
tool_choice
_
function
_
name
,
arguments
=
delta_text
),
index
=
i
)
])
# handle streaming deltas for tools with "auto" tool choice
elif
(
self
.
_should_stream_with_auto_tool_parsing
(
request
)
and
tool_parser
):
elif
tool_choice_auto
:
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
assert
tool_parser
is
not
None
#TODO optimize manipulation of these lists
previous_text
=
previous_texts
[
i
]
previous_token_ids
=
all_previous_token_ids
[
i
]
current_text
=
previous_text
+
delta_text
current_token_ids
=
previous_token_ids
+
list
(
output
.
token_ids
)
delta_message
=
(
tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
s
[
i
]
,
current_text
=
output
.
text
,
previous_text
=
previous_text
,
current_text
=
current_
text
,
delta_text
=
delta_text
,
previous_token_ids
=
\
output
.
token_ids
[
:
-
1
*
len
(
delta_token_ids
)
],
current_token_ids
=
output
.
token_ids
,
delta_token_ids
=
delta_token_ids
)
)
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
delta_token_ids
=
output
.
token_ids
))
# update the previous values for the next iteration
previous_texts
[
i
]
=
current_text
all_previous_token_ids
[
i
]
=
current_token_ids
# handle streaming just a content delta
else
:
delta_message
=
DeltaMessage
(
content
=
delta_text
)
# set the previous values for the next iteration
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
previous_num_tokens
[
i
]
+=
len
(
output
.
token_ids
)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
...
...
@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
if
request
.
stream_options
.
continuous_usage_stats
:
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
total_tokens
=
num_
prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
...
...
@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
"arguments"
,
{}))
# get what we've streamed so f
o
r for arguments
# get what we've streamed so f
a
r for arguments
# for the current tool
actual_call
=
tool_parser
.
streamed_args_for_tool
[
index
]
...
...
@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
])
# Send the finish response for each request.n only once
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
delta_message
,
...
...
@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
if
request
.
stream_options
.
continuous_usage_stats
:
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
total_tokens
=
num_
prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
...
...
@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
completion_tokens
=
previous_num_tokens
[
i
]
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
previous_num
_tokens
[
i
]
,
total_tokens
=
prompt_tokens
+
previous_num
_tokens
[
i
]
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
completion
_tokens
,
total_tokens
=
num_
prompt_tokens
+
completion
_tokens
,
)
final_usage_chunk
=
ChatCompletionStreamResponse
(
...
...
@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or
""
)
choice
.
message
.
content
=
full_message
assert
final_res
.
prompt_token_ids
is
not
None
num_prompt_tokens
=
len
(
final_res
.
prompt_token_ids
)
num_generated_tokens
=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
...
...
@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return
bool
(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self
.
enable_auto_tools
and
self
.
tool_parser
and
delta_message
output
.
finish_reason
is
not
None
and
self
.
enable_auto_tools
and
self
.
tool_parser
and
delta_message
and
delta_message
.
tool_calls
and
delta_message
.
tool_calls
[
0
]
and
delta_message
.
tool_calls
[
0
].
function
and
delta_message
.
tool_calls
[
0
].
function
.
arguments
is
not
None
and
output
.
finish_reason
is
not
None
)
vllm/entrypoints/openai/serving_completion.py
View file @
551ce010
...
...
@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer
:
AnyTokenizer
,
)
->
AsyncGenerator
[
str
,
None
]:
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_texts
=
[
""
]
*
num_choices
*
num_prompts
previous_text
_len
s
=
[
0
]
*
num_choices
*
num_prompts
previous_num_tokens
=
[
0
]
*
num_choices
*
num_prompts
has_echoed
=
[
False
]
*
num_choices
*
num_prompts
num_prompt_tokens
=
[
0
]
*
num_prompts
try
:
async
for
prompt_idx
,
res
in
result_generator
:
...
...
@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs
=
res
.
prompt_logprobs
prompt_text
=
res
.
prompt
# Prompt details are excluded from later streamed outputs
if
res
.
prompt_token_ids
is
not
None
:
num_prompt_tokens
[
prompt_idx
]
=
len
(
res
.
prompt_token_ids
)
delta_token_ids
:
GenericSequence
[
int
]
out_logprobs
:
Optional
[
GenericSequence
[
Optional
[
Dict
[
int
,
Logprob
]]]]
...
...
@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert
request
.
max_tokens
is
not
None
if
request
.
echo
and
request
.
max_tokens
==
0
:
assert
prompt_token_ids
is
not
None
assert
prompt_text
is
not
None
# only return the prompt
delta_text
=
prompt_text
...
...
@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed
[
i
]
=
True
elif
(
request
.
echo
and
request
.
max_tokens
>
0
and
not
has_echoed
[
i
]):
assert
prompt_token_ids
is
not
None
assert
prompt_text
is
not
None
assert
prompt_logprobs
is
not
None
# echo the prompt and first token
...
...
@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed
[
i
]
=
True
else
:
# return just the delta
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
out_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
delta_text
=
output
.
text
delta_token_ids
=
output
.
token_ids
out_logprobs
=
output
.
logprobs
if
request
.
logprobs
is
not
None
:
assert
out_logprobs
is
not
None
,
(
...
...
@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs
=
out_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
tokenizer
=
tokenizer
,
initial_text_offset
=
len
(
previous_texts
[
i
]
)
,
initial_text_offset
=
previous_text
_len
s
[
i
],
)
else
:
logprobs
=
None
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
previous_text
_len
s
[
i
]
+
=
len
(
output
.
text
)
previous_num_tokens
[
i
]
+
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
stop_reason
=
output
.
stop_reason
...
...
@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
or
output
.
finish_reason
is
not
None
):
prompt_tokens
=
len
(
prompt_token_id
s
)
completion_tokens
=
len
(
output
.
token_ids
)
prompt_tokens
=
num_
prompt_token
s
[
prompt
_id
x
]
completion_tokens
=
previous_num_tokens
[
i
]
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
...
...
@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for
final_res
in
final_res_batch
:
prompt_token_ids
=
final_res
.
prompt_token_ids
assert
prompt_token_ids
is
not
None
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
...
...
@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
)
choices
.
append
(
choice_data
)
num_generated_tokens
+=
len
(
output
.
token_ids
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_generated_tokens
+=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
...
...
vllm/outputs.py
View file @
551ce010
...
...
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
SequenceGroup
,
SequenceStatus
)
...
...
@@ -92,7 +93,7 @@ class RequestOutput:
self
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
Optional
[
List
[
int
]
]
,
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
finished
:
bool
,
...
...
@@ -113,19 +114,26 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
"RequestOutput"
:
if
seq_group
.
sampling_params
is
None
:
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
Optional
[
"RequestOutput"
]:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
is
None
:
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
finished
=
seq_group
.
is_finished
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
not
finished
):
return
None
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
else
:
# Get the top-n sequences.
n
=
seq_group
.
sampling_params
.
n
if
seq_group
.
sampling_params
.
use_beam_search
:
n
=
sampling_params
.
n
if
sampling_params
.
use_beam_search
:
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
seq_group
.
sampling_params
.
length_penalty
)
sampling_params
.
length_penalty
)
else
:
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
...
...
@@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
data
.
_output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
)
for
seq
in
top_n_seqs
]
include_logprobs
=
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
sampling_params
.
output_text_buffer_length
delta
=
sampling_params
.
output_kind
==
RequestOutputKind
.
DELTA
outputs
=
[]
include_prompt
=
True
for
seq
in
top_n_seqs
:
output_text
=
seq
.
get_output_text_to_return
(
text_buffer_length
,
delta
)
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
if
delta
:
# Slice logprobs delta if applicable
if
output_logprobs
:
output_logprobs
=
output_logprobs
[
-
len
(
output_token_ids
):]
# Don't include prompt if this is after the first output
# containing decode token ids
if
include_prompt
and
seq
.
get_output_len
()
>
len
(
output_token_ids
):
include_prompt
=
False
outputs
.
append
(
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_logprobs
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
))
# Every sequence in the sequence group should have the same prompt.
prompt
=
seq_group
.
prompt
prompt_token_ids
=
seq_group
.
prompt_token_ids
encoder_prompt
=
seq_group
.
encoder_prompt
encoder_prompt_token_ids
=
seq_group
.
encoder_prompt_token_ids
prompt_logprobs
=
seq_group
.
prompt_logprobs
finished
=
seq_group
.
is_finished
()
if
include_prompt
:
prompt
=
seq_group
.
prompt
prompt_token_ids
=
seq_group
.
prompt_token_ids
encoder_prompt
=
seq_group
.
encoder_prompt
encoder_prompt_token_ids
=
seq_group
.
encoder_prompt_token_ids
prompt_logprobs
=
seq_group
.
prompt_logprobs
else
:
prompt
=
None
prompt_token_ids
=
None
encoder_prompt
=
None
encoder_prompt_token_ids
=
None
prompt_logprobs
=
None
finished_time
=
time
.
time
()
if
finished
else
None
seq_group
.
set_finished_time
(
finished_time
)
return
cls
(
seq_group
.
request_id
,
...
...
vllm/sampling_params.py
View file @
551ce010
"""Sampling parameters for text generation."""
import
copy
from
enum
import
IntEnum
from
enum
import
Enum
,
IntEnum
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
...
...
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""
class
RequestOutputKind
(
Enum
):
# Return entire output so far in every RequestOutput
CUMULATIVE
=
0
# Return only deltas in each RequestOutput
DELTA
=
1
# Do not return intermediate RequestOuputs
FINAL_ONLY
=
2
class
SamplingParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
...
...
@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors
:
Optional
[
Any
]
=
None
include_stop_str_in_output
:
bool
=
False
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
# The below fields are not supposed to be used as an input.
# They are set in post_init.
...
...
@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
,
)
->
"SamplingParams"
:
return
SamplingParams
(
n
=
1
if
n
is
None
else
n
,
...
...
@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens
=
spaces_between_special_tokens
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
output_kind
=
output_kind
,
)
def
__post_init__
(
self
)
->
None
:
...
...
@@ -317,6 +329,9 @@ class SamplingParams(
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
if
self
.
best_of
!=
self
.
n
and
self
.
output_kind
==
(
RequestOutputKind
.
DELTA
):
raise
ValueError
(
"best_of must equal n to use output_kind=DELTA"
)
def
_verify_beam_search
(
self
)
->
None
:
if
self
.
best_of
==
1
:
...
...
vllm/sequence.py
View file @
551ce010
...
...
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from
array
import
array
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
,
cast
)
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
,
cast
import
msgspec
import
torch
...
...
@@ -407,6 +408,10 @@ class Sequence:
self
.
status
=
SequenceStatus
.
WAITING
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
# These are used to keep track of delta outputs
self
.
_last_token_ids_offset
:
int
=
0
self
.
_last_output_text_offset
:
int
=
0
# Used for incremental detokenization
self
.
prefix_offset
=
0
self
.
read_offset
=
0
...
...
@@ -462,11 +467,35 @@ class Sequence:
return
self
.
prompt_adapter_request
.
prompt_adapter_id
\
if
self
.
prompt_adapter_request
else
0
def
get_output_text_to_return
(
self
,
buffer_length
:
int
):
def
get_output_text_to_return
(
self
,
buffer_length
:
int
,
delta
:
bool
)
->
str
:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
truncate
=
buffer_length
and
not
self
.
is_finished
()
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
self
.
output_text
)
if
not
delta
:
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
self
.
output_text
)
length
=
len
(
self
.
output_text
)
-
buffer_length
last_offset
=
self
.
_last_output_text_offset
if
last_offset
<
length
:
self
.
_last_output_text_offset
=
length
return
self
.
output_text
[
last_offset
:
length
]
return
""
def
get_output_token_ids_to_return
(
self
,
delta
:
bool
)
->
GenericSequence
[
int
]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
get_output_token_ids
()
length
=
self
.
get_output_len
()
last_offset
=
self
.
_last_token_ids_offset
if
last_offset
<
length
:
self
.
_last_token_ids_offset
=
length
return
self
.
data
.
_output_token_ids
[
last_offset
:]
return
()
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment