Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
551ce010
Unverified
Commit
551ce010
authored
Sep 12, 2024
by
Nick Hill
Committed by
GitHub
Sep 12, 2024
Browse files
[Core] Add engine option to return only deltas or final output (#7381)
parent
a6c0f365
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
371 additions
and
137 deletions
+371
-137
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+147
-14
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+13
-11
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+8
-15
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+6
-1
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+72
-53
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+19
-13
vllm/outputs.py
vllm/outputs.py
+55
-24
vllm/sampling_params.py
vllm/sampling_params.py
+16
-1
vllm/sequence.py
vllm/sequence.py
+34
-5
No files found.
.buildkite/test-pipeline.yaml
View file @
551ce010
...
@@ -50,6 +50,7 @@ steps:
...
@@ -50,6 +50,7 @@ steps:
-
tests/worker
-
tests/worker
commands
:
commands
:
-
pytest -v -s async_engine
# Async Engine
-
pytest -v -s async_engine
# Async Engine
-
NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
-
pytest -v -s test_inputs.py
-
pytest -v -s test_inputs.py
-
pytest -v -s multimodal
-
pytest -v -s multimodal
-
pytest -v -s test_utils.py
# Utils
-
pytest -v -s test_utils.py
# Utils
...
...
tests/async_engine/test_async_llm_engine.py
View file @
551ce010
import
asyncio
import
asyncio
import
os
import
uuid
from
asyncio
import
CancelledError
from
asyncio
import
CancelledError
from
copy
import
copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
List
,
Optional
import
pytest
import
pytest
import
pytest_asyncio
import
pytest_asyncio
...
@@ -11,6 +14,7 @@ from vllm import SamplingParams
...
@@ -11,6 +14,7 @@ from vllm import SamplingParams
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.outputs
import
RequestOutput
as
RealRequestOutput
from
vllm.outputs
import
RequestOutput
as
RealRequestOutput
from
vllm.sampling_params
import
RequestOutputKind
from
..conftest
import
cleanup
from
..conftest
import
cleanup
from
..utils
import
wait_for_gpu_memory_to_clear
from
..utils
import
wait_for_gpu_memory_to_clear
...
@@ -122,8 +126,17 @@ def start_engine():
...
@@ -122,8 +126,17 @@ def start_engine():
timeout_s
=
60
,
timeout_s
=
60
,
)
)
num_scheduler_steps
=
int
(
os
.
getenv
(
"NUM_SCHEDULER_STEPS"
,
"1"
))
print
(
f
"Starting engine with num_scheduler_steps=
{
num_scheduler_steps
}
"
)
return
AsyncLLMEngine
.
from_engine_args
(
return
AsyncLLMEngine
.
from_engine_args
(
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
))
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
num_scheduler_steps
=
num_scheduler_steps
))
def
uid
()
->
str
:
return
str
(
uuid
.
uuid4
())
@
pytest_asyncio
.
fixture
(
scope
=
"module"
)
@
pytest_asyncio
.
fixture
(
scope
=
"module"
)
...
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
...
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_asyncio_run
(
async_engine
):
async
def
test_asyncio_run
(
async_engine
):
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
num_scheduler_steps
=
scheduler_config
.
num_scheduler_steps
async
def
run
(
prompt
:
str
):
async
def
run
(
prompt
:
str
):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
temperature
=
0
,
max_tokens
=
32
,
max_tokens
=
32
,
min_tokens
=
32
,
)
)
output_count
=
0
final_output
=
None
async
for
output
in
async_engine
.
generate
(
prompt
,
async
for
output
in
async_engine
.
generate
(
prompt
,
sampling_params
,
sampling_params
,
request_id
=
prompt
):
request_id
=
uid
()):
output_count
+=
1
final_output
=
output
final_output
=
output
return
final_output
return
final_output
,
output_count
results
=
await
asyncio
.
gather
(
results
=
await
asyncio
.
gather
(
run
(
"test0"
),
run
(
"test0"
),
run
(
"test
1
"
),
run
(
"test
0
"
),
)
)
assert
len
(
results
)
==
2
assert
len
(
results
)
==
2
first
,
second
=
results
# remove nondeterministic fields for comparison
first
[
0
].
metrics
=
None
second
[
0
].
metrics
=
None
first
[
0
].
request_id
=
None
second
[
0
].
request_id
=
None
assert
str
(
first
)
==
str
(
second
)
output_count
=
results
[
0
][
1
]
if
num_scheduler_steps
==
1
:
assert
output_count
==
32
else
:
assert
1
<
output_count
<
32
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_output_kinds
(
async_engine
):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
num_scheduler_steps
=
scheduler_config
.
num_scheduler_steps
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
,
min_tokens
=
32
,
)
async
def
run
(
prompt
:
str
,
kind
:
RequestOutputKind
):
params
=
copy
(
sampling_params
)
params
.
output_kind
=
kind
output_count
=
0
final_output
=
None
async
for
output
in
async_engine
.
generate
(
prompt
,
params
,
request_id
=
uid
()):
output_count
+=
1
final_output
=
output
assert
final_output
is
not
None
return
(
final_output
.
prompt_token_ids
,
final_output
.
outputs
[
0
].
token_ids
,
final_output
.
outputs
[
0
].
text
,
output_count
)
async
def
run_deltas
(
prompt
:
str
):
params
=
copy
(
sampling_params
)
params
.
output_kind
=
RequestOutputKind
.
DELTA
prompt_tokens
=
None
output_tokens
:
List
[
int
]
=
[]
output_text
=
""
output_count
=
0
async
for
output
in
async_engine
.
generate
(
prompt
,
params
,
request_id
=
uid
()):
token_ids
=
output
.
outputs
[
0
].
token_ids
text
=
output
.
outputs
[
0
].
text
# Ensure we get prompt ids iff we haven't yet received output tokens
if
output_tokens
:
assert
1
<=
len
(
token_ids
)
<=
num_scheduler_steps
assert
text
assert
not
output
.
prompt_token_ids
else
:
assert
output
.
prompt_token_ids
prompt_tokens
=
output
.
prompt_token_ids
output_tokens
.
extend
(
token_ids
)
output_text
+=
text
output_count
+=
1
return
prompt_tokens
,
output_tokens
,
output_text
,
output_count
results
=
await
asyncio
.
gather
(
run
(
"common input prompt"
,
RequestOutputKind
.
CUMULATIVE
),
run
(
"common input prompt"
,
RequestOutputKind
.
FINAL_ONLY
),
run_deltas
(
"common input prompt"
))
# Make sure outputs are the same
prompt_set
=
set
(
tuple
(
prompt_ids
)
for
prompt_ids
,
_
,
_
,
_
in
results
)
assert
len
(
prompt_set
)
==
1
text_set
=
set
(
text
for
_
,
_
,
text
,
_
in
results
)
assert
len
(
text_set
)
==
1
tokens_set
=
set
(
tuple
(
ids
)
for
_
,
ids
,
_
,
_
in
results
)
assert
len
(
tokens_set
)
==
1
cumulative
,
final
,
deltas
=
results
# output message counts
assert
cumulative
[
3
]
==
deltas
[
3
]
if
num_scheduler_steps
==
1
:
assert
cumulative
[
3
]
==
32
else
:
assert
1
<
cumulative
[
3
]
<
32
assert
final
[
3
]
==
1
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_cancellation
(
async_engine
):
async
def
test_cancellation
(
async_engine
):
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
num_scheduler_steps
=
scheduler_config
.
num_scheduler_steps
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
temperature
=
0
,
min_tokens
=
1
0
,
min_tokens
=
1
3
,
max_tokens
=
1
0
,
max_tokens
=
1
3
,
)
)
stop_at
=
5
if
num_scheduler_steps
==
1
else
1
request_id
=
uid
()
i
=
0
i
=
0
with
pytest
.
raises
(
CancelledError
):
with
pytest
.
raises
(
CancelledError
):
async
for
output
in
async_engine
.
generate
(
"test2"
,
async
for
output
in
async_engine
.
generate
(
"test2"
,
sampling_params
,
sampling_params
,
request_id
=
"test2"
):
request_id
=
request_id
):
assert
not
output
.
finished
assert
not
output
.
finished
i
+=
1
i
+=
1
if
i
==
5
:
if
i
==
stop_at
:
await
async_engine
.
abort
(
"test2"
)
await
async_engine
.
abort
(
request_id
)
assert
i
==
5
assert
i
==
stop_at
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_delayed_generator
(
async_engine
):
async
def
test_delayed_generator
(
async_engine
):
scheduler_config
=
await
async_engine
.
get_scheduler_config
()
if
scheduler_config
.
num_scheduler_steps
!=
1
:
pytest
.
skip
(
"no need to test this one with multistep"
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
temperature
=
0
,
min_tokens
=
10
,
min_tokens
=
10
,
max_tokens
=
10
,
max_tokens
=
10
,
)
)
stream
=
async_engine
.
generate
(
"test3"
,
stream
=
async_engine
.
generate
(
"test3"
,
sampling_params
,
request_id
=
uid
())
sampling_params
,
request_id
=
"test3"
)
i
=
0
i
=
0
final_output
:
Optional
[
RealRequestOutput
]
=
None
final_output
:
Optional
[
RealRequestOutput
]
=
None
async
for
output
in
stream
:
async
for
output
in
stream
:
...
...
vllm/engine/llm_engine.py
View file @
551ce010
...
@@ -39,7 +39,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
...
@@ -39,7 +39,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory
)
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceStatus
)
...
@@ -225,9 +225,6 @@ class LLMEngine:
...
@@ -225,9 +225,6 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only
:
bool
=
False
,
)
->
None
:
)
->
None
:
logger
.
info
(
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
"Initializing an LLM engine (v%s) with config: "
...
@@ -295,7 +292,6 @@ class LLMEngine:
...
@@ -295,7 +292,6 @@ class LLMEngine:
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
self
.
observability_config
=
observability_config
or
ObservabilityConfig
(
)
)
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
step_return_finished_only
=
step_return_finished_only
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
@@ -1273,7 +1269,7 @@ class LLMEngine:
...
@@ -1273,7 +1269,7 @@ class LLMEngine:
ctx: The virtual engine context to work on
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
request_id: If provided, then only this request is going to be processed
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
...
@@ -1378,7 +1374,8 @@ class LLMEngine:
...
@@ -1378,7 +1374,8 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
# and invoke the request output callback (if there was final output)
...
@@ -1415,14 +1412,19 @@ class LLMEngine:
...
@@ -1415,14 +1412,19 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
if
(
seq_group
.
is_finished
()
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
if
self
.
step_return_finished_only
else
True
):
if
request_output
:
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
)
ctx
.
request_outputs
.
append
(
request_output
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
# Immediately process request outputs here (if callback is given)
# Immediately process request outputs here (if callback is given)
if
(
ctx
.
request_outputs
if
(
ctx
.
request_outputs
...
...
vllm/entrypoints/llm.py
View file @
551ce010
...
@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
...
@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
@@ -642,14 +642,12 @@ class LLM:
...
@@ -642,14 +642,12 @@ class LLM:
raise
ValueError
(
"The lengths of prompts and lora_request "
raise
ValueError
(
"The lengths of prompts and lora_request "
"must be the same."
)
"must be the same."
)
if
isinstance
(
params
,
list
):
for
sp
in
params
if
isinstance
(
params
,
list
)
else
(
params
,
):
params
=
[
if
isinstance
(
sp
,
SamplingParams
):
self
.
_add_guided_processor
(
param
,
guided_options
)
self
.
_add_guided_processor
(
sp
,
guided_options
)
if
isinstance
(
param
,
SamplingParams
)
else
param
for
param
in
params
# We only care about the final output
]
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
elif
isinstance
(
params
,
SamplingParams
):
params
=
self
.
_add_guided_processor
(
params
,
guided_options
)
# Add requests to the engine.
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inputs
):
for
i
,
request_inputs
in
enumerate
(
inputs
):
...
@@ -709,9 +707,6 @@ class LLM:
...
@@ -709,9 +707,6 @@ class LLM:
f
"output:
{
0
:.
2
f
}
toks/s"
),
f
"output:
{
0
:.
2
f
}
toks/s"
),
)
)
# In the loop below, only finished outputs are used
self
.
llm_engine
.
step_return_finished_only
=
True
# Run the engine.
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
total_in_toks
=
0
total_in_toks
=
0
...
@@ -724,6 +719,7 @@ class LLM:
...
@@ -724,6 +719,7 @@ class LLM:
if
use_tqdm
:
if
use_tqdm
:
if
isinstance
(
output
,
RequestOutput
):
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
# Calculate tokens only for RequestOutput
assert
output
.
prompt_token_ids
is
not
None
total_in_toks
+=
len
(
output
.
prompt_token_ids
)
total_in_toks
+=
len
(
output
.
prompt_token_ids
)
in_spd
=
total_in_toks
/
pbar
.
format_dict
[
"elapsed"
]
in_spd
=
total_in_toks
/
pbar
.
format_dict
[
"elapsed"
]
total_out_toks
+=
sum
(
total_out_toks
+=
sum
(
...
@@ -735,9 +731,6 @@ class LLM:
...
@@ -735,9 +731,6 @@ class LLM:
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
pbar
.
update
(
1
)
pbar
.
update
(
1
)
# Restore original behavior
self
.
llm_engine
.
step_return_finished_only
=
False
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
# Sort the outputs by request ID.
# Sort the outputs by request ID.
...
...
vllm/entrypoints/openai/protocol.py
View file @
551ce010
...
@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
...
@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sampling_params
import
(
LogitsProcessor
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
...
@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty
=
self
.
length_penalty
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
,
)
)
@
model_validator
(
mode
=
"before"
)
@
model_validator
(
mode
=
"before"
)
...
@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty
=
self
.
length_penalty
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
,
)
)
@
model_validator
(
mode
=
"before"
)
@
model_validator
(
mode
=
"before"
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
551ce010
...
@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def
get_chat_request_role
(
self
,
request
:
ChatCompletionRequest
)
->
str
:
def
get_chat_request_role
(
self
,
request
:
ChatCompletionRequest
)
->
str
:
if
request
.
add_generation_prompt
:
if
request
.
add_generation_prompt
:
return
self
.
response_role
return
self
.
response_role
else
:
return
request
.
messages
[
-
1
][
"role"
]
return
request
.
messages
[
-
1
][
"role"
]
async
def
chat_completion_stream_generator
(
async
def
chat_completion_stream_generator
(
self
,
self
,
...
@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
# Send response for each token for each request.n (index)
# Send response for each token for each request.n (index)
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_texts
=
[
""
]
*
num_choices
previous_num_tokens
=
[
0
]
*
num_choices
previous_num_tokens
=
[
0
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
num_prompt_tokens
=
0
tool_parser
:
Optional
[
ToolParser
]
=
self
.
tool_parser
(
tool_parser
:
Optional
[
ToolParser
]
=
self
.
tool_parser
(
tokenizer
)
if
self
.
tool_parser
else
None
tokenizer
)
if
self
.
tool_parser
else
None
if
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
tool_choice_function_name
=
request
.
tool_choice
.
function
.
name
else
:
tool_choice_function_name
=
None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto
=
(
not
tool_choice_function_name
and
self
.
_should_stream_with_auto_tool_parsing
(
request
))
all_previous_token_ids
:
Optional
[
List
[
List
[
int
]]]
if
tool_choice_auto
:
# These are only required in "auto" tool choice case
previous_texts
=
[
""
]
*
num_choices
all_previous_token_ids
=
[[]]
*
num_choices
else
:
previous_texts
,
all_previous_token_ids
=
None
,
None
try
:
try
:
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
res
.
prompt_token_ids
is
not
None
:
num_prompt_tokens
=
len
(
res
.
prompt_token_ids
)
# We need to do it here, because if there are exceptions in
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
# response (by the try...catch).
...
@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
# if continuous usage stats are requested, add it
# if continuous usage stats are requested, add it
if
request
.
stream_options
.
continuous_usage_stats
:
if
request
.
stream_options
.
continuous_usage_stats
:
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
0
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
total_tokens
=
num_
prompt_tokens
)
chunk
.
usage
=
usage
chunk
.
usage
=
usage
# otherwise don't
# otherwise don't
else
:
else
:
...
@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request
.
stream_options
.
include_usage
):
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
if
(
request
.
stream_options
.
continuous_usage_stats
):
continuous_usage_stats
):
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
0
,
completion_tokens
=
0
,
total_tokens
=
prompt_tokens
)
total_tokens
=
num_
prompt_tokens
)
chunk
.
usage
=
usage
chunk
.
usage
=
usage
else
:
else
:
chunk
.
usage
=
None
chunk
.
usage
=
None
...
@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration
=
False
first_iteration
=
False
for
output
in
res
.
outputs
:
for
output
in
res
.
outputs
:
i
=
output
.
index
i
=
output
.
index
if
finish_reason_sent
[
i
]:
if
finish_reason_sent
[
i
]:
continue
continue
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
out_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
and
request
.
top_logprobs
is
not
None
:
if
request
.
logprobs
and
request
.
top_logprobs
is
not
None
:
assert
out
_
logprobs
is
not
None
,
(
assert
out
put
.
logprobs
is
not
None
,
(
"Did not output logprobs"
)
"Did not output logprobs"
)
logprobs
=
self
.
_create_chat_logprobs
(
logprobs
=
self
.
_create_chat_logprobs
(
token_ids
=
delta_
token_ids
,
token_ids
=
output
.
token_ids
,
top_logprobs
=
out
_
logprobs
,
top_logprobs
=
out
put
.
logprobs
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
num_output_top_logprobs
=
request
.
top_logprobs
,
num_output_top_logprobs
=
request
.
top_logprobs
,
)
)
else
:
else
:
logprobs
=
None
logprobs
=
None
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_text
=
output
.
text
delta_message
:
Optional
[
DeltaMessage
]
=
None
delta_message
:
Optional
[
DeltaMessage
]
# handle streaming deltas for tools with named tool_choice
# handle streaming deltas for tools with named tool_choice
if
(
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
if
tool_choice_function_name
:
ChatCompletionNamedToolChoiceParam
):
delta_message
=
DeltaMessage
(
tool_calls
=
[
delta_message
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
function
=
DeltaFunctionCall
(
DeltaToolCall
(
function
=
DeltaFunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
name
=
tool_choice
_
function
_
name
,
arguments
=
delta_text
),
arguments
=
delta_text
),
index
=
i
)
index
=
i
)
])
])
# handle streaming deltas for tools with "auto" tool choice
# handle streaming deltas for tools with "auto" tool choice
elif
(
self
.
_should_stream_with_auto_tool_parsing
(
request
)
elif
tool_choice_auto
:
and
tool_parser
):
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
assert
tool_parser
is
not
None
#TODO optimize manipulation of these lists
previous_text
=
previous_texts
[
i
]
previous_token_ids
=
all_previous_token_ids
[
i
]
current_text
=
previous_text
+
delta_text
current_token_ids
=
previous_token_ids
+
list
(
output
.
token_ids
)
delta_message
=
(
delta_message
=
(
tool_parser
.
extract_tool_calls_streaming
(
tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
s
[
i
]
,
previous_text
=
previous_text
,
current_text
=
output
.
text
,
current_text
=
current_
text
,
delta_text
=
delta_text
,
delta_text
=
delta_text
,
previous_token_ids
=
\
previous_token_ids
=
previous_token_ids
,
output
.
token_ids
[
current_token_ids
=
current_token_ids
,
:
-
1
*
len
(
delta_token_ids
)
delta_token_ids
=
output
.
token_ids
))
],
current_token_ids
=
output
.
token_ids
,
# update the previous values for the next iteration
delta_token_ids
=
delta_token_ids
previous_texts
[
i
]
=
current_text
)
all_previous_token_ids
[
i
]
=
current_token_ids
)
# handle streaming just a content delta
# handle streaming just a content delta
else
:
else
:
delta_message
=
DeltaMessage
(
content
=
delta_text
)
delta_message
=
DeltaMessage
(
content
=
delta_text
)
# set the previous values for the next iteration
# set the previous values for the next iteration
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
+=
len
(
output
.
token_ids
)
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
# if the message delta is None (e.g. because it was a
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
# "control token" for tool calls or the parser otherwise
...
@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous
# handle usage stats if requested & if continuous
if
(
request
.
stream_options
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
if
request
.
stream_options
.
continuous_usage_stats
:
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
completion_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
total_tokens
=
num_
prompt_tokens
+
completion_tokens
,
completion_tokens
,
)
)
chunk
.
usage
=
usage
chunk
.
usage
=
usage
...
@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
"arguments"
,
{}))
"arguments"
,
{}))
# get what we've streamed so f
o
r for arguments
# get what we've streamed so f
a
r for arguments
# for the current tool
# for the current tool
actual_call
=
tool_parser
.
streamed_args_for_tool
[
actual_call
=
tool_parser
.
streamed_args_for_tool
[
index
]
index
]
...
@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
])
])
# Send the finish response for each request.n only once
# Send the finish response for each request.n only once
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
choice_data
=
ChatCompletionResponseStreamChoice
(
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
index
=
i
,
delta
=
delta_message
,
delta
=
delta_message
,
...
@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model
=
model_name
)
model
=
model_name
)
if
(
request
.
stream_options
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
if
request
.
stream_options
.
continuous_usage_stats
:
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
completion_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
total_tokens
=
num_
prompt_tokens
+
completion_tokens
,
completion_tokens
,
)
)
chunk
.
usage
=
usage
chunk
.
usage
=
usage
...
@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
# is sent, send the usage
if
(
request
.
stream_options
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
completion_tokens
=
previous_num_tokens
[
i
]
final_usage
=
UsageInfo
(
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
num_
prompt_tokens
,
completion_tokens
=
previous_num
_tokens
[
i
]
,
completion_tokens
=
completion
_tokens
,
total_tokens
=
prompt_tokens
+
previous_num
_tokens
[
i
]
,
total_tokens
=
num_
prompt_tokens
+
completion
_tokens
,
)
)
final_usage_chunk
=
ChatCompletionStreamResponse
(
final_usage_chunk
=
ChatCompletionStreamResponse
(
...
@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or
""
)
or
""
)
choice
.
message
.
content
=
full_message
choice
.
message
.
content
=
full_message
assert
final_res
.
prompt_token_ids
is
not
None
num_prompt_tokens
=
len
(
final_res
.
prompt_token_ids
)
num_prompt_tokens
=
len
(
final_res
.
prompt_token_ids
)
num_generated_tokens
=
sum
(
num_generated_tokens
=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
...
@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return
bool
(
return
bool
(
# if there is a delta message that includes tool calls which
# if there is a delta message that includes tool calls which
# include a function that has arguments
# include a function that has arguments
self
.
enable_auto_tools
and
self
.
tool_parser
and
delta_message
output
.
finish_reason
is
not
None
and
self
.
enable_auto_tools
and
self
.
tool_parser
and
delta_message
and
delta_message
.
tool_calls
and
delta_message
.
tool_calls
[
0
]
and
delta_message
.
tool_calls
and
delta_message
.
tool_calls
[
0
]
and
delta_message
.
tool_calls
[
0
].
function
and
delta_message
.
tool_calls
[
0
].
function
and
delta_message
.
tool_calls
[
0
].
function
.
arguments
is
not
None
and
delta_message
.
tool_calls
[
0
].
function
.
arguments
is
not
None
and
output
.
finish_reason
is
not
None
)
)
vllm/entrypoints/openai/serving_completion.py
View file @
551ce010
...
@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_texts
=
[
""
]
*
num_choices
*
num_prompts
previous_text
_len
s
=
[
0
]
*
num_choices
*
num_prompts
previous_num_tokens
=
[
0
]
*
num_choices
*
num_prompts
previous_num_tokens
=
[
0
]
*
num_choices
*
num_prompts
has_echoed
=
[
False
]
*
num_choices
*
num_prompts
has_echoed
=
[
False
]
*
num_choices
*
num_prompts
num_prompt_tokens
=
[
0
]
*
num_prompts
try
:
try
:
async
for
prompt_idx
,
res
in
result_generator
:
async
for
prompt_idx
,
res
in
result_generator
:
...
@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs
=
res
.
prompt_logprobs
prompt_logprobs
=
res
.
prompt_logprobs
prompt_text
=
res
.
prompt
prompt_text
=
res
.
prompt
# Prompt details are excluded from later streamed outputs
if
res
.
prompt_token_ids
is
not
None
:
num_prompt_tokens
[
prompt_idx
]
=
len
(
res
.
prompt_token_ids
)
delta_token_ids
:
GenericSequence
[
int
]
delta_token_ids
:
GenericSequence
[
int
]
out_logprobs
:
Optional
[
GenericSequence
[
Optional
[
Dict
[
out_logprobs
:
Optional
[
GenericSequence
[
Optional
[
Dict
[
int
,
Logprob
]]]]
int
,
Logprob
]]]]
...
@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert
request
.
max_tokens
is
not
None
assert
request
.
max_tokens
is
not
None
if
request
.
echo
and
request
.
max_tokens
==
0
:
if
request
.
echo
and
request
.
max_tokens
==
0
:
assert
prompt_token_ids
is
not
None
assert
prompt_text
is
not
None
assert
prompt_text
is
not
None
# only return the prompt
# only return the prompt
delta_text
=
prompt_text
delta_text
=
prompt_text
...
@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed
[
i
]
=
True
has_echoed
[
i
]
=
True
elif
(
request
.
echo
and
request
.
max_tokens
>
0
elif
(
request
.
echo
and
request
.
max_tokens
>
0
and
not
has_echoed
[
i
]):
and
not
has_echoed
[
i
]):
assert
prompt_token_ids
is
not
None
assert
prompt_text
is
not
None
assert
prompt_text
is
not
None
assert
prompt_logprobs
is
not
None
assert
prompt_logprobs
is
not
None
# echo the prompt and first token
# echo the prompt and first token
...
@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed
[
i
]
=
True
has_echoed
[
i
]
=
True
else
:
else
:
# return just the delta
# return just the delta
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_text
=
output
.
text
delta_token_ids
=
output
.
token_ids
[
delta_token_ids
=
output
.
token_ids
previous_num_tokens
[
i
]:]
out_logprobs
=
output
.
logprobs
out_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
if
request
.
logprobs
is
not
None
:
assert
out_logprobs
is
not
None
,
(
assert
out_logprobs
is
not
None
,
(
...
@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs
=
out_logprobs
,
top_logprobs
=
out_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
initial_text_offset
=
len
(
previous_texts
[
i
]
)
,
initial_text_offset
=
previous_text
_len
s
[
i
],
)
)
else
:
else
:
logprobs
=
None
logprobs
=
None
previous_texts
[
i
]
=
output
.
text
previous_text
_len
s
[
i
]
+
=
len
(
output
.
text
)
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
previous_num_tokens
[
i
]
+
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
finish_reason
=
output
.
finish_reason
stop_reason
=
output
.
stop_reason
stop_reason
=
output
.
stop_reason
...
@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and
request
.
stream_options
.
include_usage
):
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
if
(
request
.
stream_options
.
continuous_usage_stats
or
output
.
finish_reason
is
not
None
):
or
output
.
finish_reason
is
not
None
):
prompt_tokens
=
len
(
prompt_token_id
s
)
prompt_tokens
=
num_
prompt_token
s
[
prompt
_id
x
]
completion_tokens
=
len
(
output
.
token_ids
)
completion_tokens
=
previous_num_tokens
[
i
]
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
completion_tokens
=
completion_tokens
,
...
@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for
final_res
in
final_res_batch
:
for
final_res
in
final_res_batch
:
prompt_token_ids
=
final_res
.
prompt_token_ids
prompt_token_ids
=
final_res
.
prompt_token_ids
assert
prompt_token_ids
is
not
None
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
prompt_text
=
final_res
.
prompt
...
@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
)
)
choices
.
append
(
choice_data
)
choices
.
append
(
choice_data
)
num_generated_tokens
+=
len
(
output
.
token_ids
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_generated_tokens
+=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
prompt_tokens
=
num_prompt_tokens
,
...
...
vllm/outputs.py
View file @
551ce010
...
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
...
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Union
from
typing
import
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
SequenceGroup
,
SequenceStatus
)
SequenceGroup
,
SequenceStatus
)
...
@@ -92,7 +93,7 @@ class RequestOutput:
...
@@ -92,7 +93,7 @@ class RequestOutput:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
Optional
[
List
[
int
]
]
,
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
outputs
:
List
[
CompletionOutput
],
finished
:
bool
,
finished
:
bool
,
...
@@ -113,19 +114,26 @@ class RequestOutput:
...
@@ -113,19 +114,26 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
@
classmethod
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
"RequestOutput"
:
def
from_seq_group
(
cls
,
if
seq_group
.
sampling_params
is
None
:
seq_group
:
SequenceGroup
)
->
Optional
[
"RequestOutput"
]:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
is
None
:
raise
ValueError
(
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
"Sampling parameters are missing for a CompletionRequest."
)
finished
=
seq_group
.
is_finished
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
not
finished
):
return
None
seqs
=
seq_group
.
get_seqs
()
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
top_n_seqs
=
seqs
else
:
else
:
# Get the top-n sequences.
# Get the top-n sequences.
n
=
seq_group
.
sampling_params
.
n
n
=
sampling_params
.
n
if
seq_group
.
sampling_params
.
use_beam_search
:
if
sampling_params
.
use_beam_search
:
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
seq_group
.
sampling_params
.
length_penalty
)
sampling_params
.
length_penalty
)
else
:
else
:
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
...
@@ -135,26 +143,49 @@ class RequestOutput:
...
@@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
# logprobs are not requested.
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
include_logprobs
=
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
text_buffer_length
=
sampling_params
.
output_text_buffer_length
outputs
=
[
delta
=
sampling_params
.
output_kind
==
RequestOutputKind
.
DELTA
CompletionOutput
(
seqs
.
index
(
seq
),
outputs
=
[]
seq
.
get_output_text_to_return
(
text_buffer_length
),
include_prompt
=
True
seq
.
data
.
_output_token_ids
,
for
seq
in
top_n_seqs
:
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_text
=
seq
.
get_output_text_to_return
(
seq
.
output_logprobs
if
include_logprobs
else
None
,
text_buffer_length
,
delta
)
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
seq
.
stop_reason
)
for
seq
in
top_n_seqs
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
]
if
delta
:
# Slice logprobs delta if applicable
if
output_logprobs
:
output_logprobs
=
output_logprobs
[
-
len
(
output_token_ids
):]
# Don't include prompt if this is after the first output
# containing decode token ids
if
include_prompt
and
seq
.
get_output_len
()
>
len
(
output_token_ids
):
include_prompt
=
False
outputs
.
append
(
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_logprobs
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
))
# Every sequence in the sequence group should have the same prompt.
# Every sequence in the sequence group should have the same prompt.
prompt
=
seq_group
.
prompt
if
include_prompt
:
prompt_token_ids
=
seq_group
.
prompt_token_ids
prompt
=
seq_group
.
prompt
encoder_prompt
=
seq_group
.
encoder_prompt
prompt_token_ids
=
seq_group
.
prompt_token_ids
encoder_prompt_token_ids
=
seq_group
.
encoder_prompt_token_ids
encoder_prompt
=
seq_group
.
encoder_prompt
prompt_logprobs
=
seq_group
.
prompt_logprobs
encoder_prompt_token_ids
=
seq_group
.
encoder_prompt_token_ids
finished
=
seq_group
.
is_finished
()
prompt_logprobs
=
seq_group
.
prompt_logprobs
else
:
prompt
=
None
prompt_token_ids
=
None
encoder_prompt
=
None
encoder_prompt_token_ids
=
None
prompt_logprobs
=
None
finished_time
=
time
.
time
()
if
finished
else
None
finished_time
=
time
.
time
()
if
finished
else
None
seq_group
.
set_finished_time
(
finished_time
)
seq_group
.
set_finished_time
(
finished_time
)
return
cls
(
seq_group
.
request_id
,
return
cls
(
seq_group
.
request_id
,
...
...
vllm/sampling_params.py
View file @
551ce010
"""Sampling parameters for text generation."""
"""Sampling parameters for text generation."""
import
copy
import
copy
from
enum
import
IntEnum
from
enum
import
Enum
,
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
...
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
...
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""
to sample from."""
class
RequestOutputKind
(
Enum
):
# Return entire output so far in every RequestOutput
CUMULATIVE
=
0
# Return only deltas in each RequestOutput
DELTA
=
1
# Do not return intermediate RequestOuputs
FINAL_ONLY
=
2
class
SamplingParams
(
class
SamplingParams
(
msgspec
.
Struct
,
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
...
@@ -147,6 +156,7 @@ class SamplingParams(
...
@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors
:
Optional
[
Any
]
=
None
logits_processors
:
Optional
[
Any
]
=
None
include_stop_str_in_output
:
bool
=
False
include_stop_str_in_output
:
bool
=
False
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
# The below fields are not supposed to be used as an input.
# The below fields are not supposed to be used as an input.
# They are set in post_init.
# They are set in post_init.
...
@@ -182,6 +192,7 @@ class SamplingParams(
...
@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
,
)
->
"SamplingParams"
:
)
->
"SamplingParams"
:
return
SamplingParams
(
return
SamplingParams
(
n
=
1
if
n
is
None
else
n
,
n
=
1
if
n
is
None
else
n
,
...
@@ -213,6 +224,7 @@ class SamplingParams(
...
@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens
=
spaces_between_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
logits_processors
=
logits_processors
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
output_kind
=
output_kind
,
)
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
...
@@ -317,6 +329,9 @@ class SamplingParams(
...
@@ -317,6 +329,9 @@ class SamplingParams(
raise
ValueError
(
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
"Set detokenize=True to use stop."
)
if
self
.
best_of
!=
self
.
n
and
self
.
output_kind
==
(
RequestOutputKind
.
DELTA
):
raise
ValueError
(
"best_of must equal n to use output_kind=DELTA"
)
def
_verify_beam_search
(
self
)
->
None
:
def
_verify_beam_search
(
self
)
->
None
:
if
self
.
best_of
==
1
:
if
self
.
best_of
==
1
:
...
...
vllm/sequence.py
View file @
551ce010
...
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
...
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
Optional
,
Set
,
Tuple
,
Union
,
cast
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
,
cast
import
msgspec
import
msgspec
import
torch
import
torch
...
@@ -407,6 +408,10 @@ class Sequence:
...
@@ -407,6 +408,10 @@ class Sequence:
self
.
status
=
SequenceStatus
.
WAITING
self
.
status
=
SequenceStatus
.
WAITING
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
# These are used to keep track of delta outputs
self
.
_last_token_ids_offset
:
int
=
0
self
.
_last_output_text_offset
:
int
=
0
# Used for incremental detokenization
# Used for incremental detokenization
self
.
prefix_offset
=
0
self
.
prefix_offset
=
0
self
.
read_offset
=
0
self
.
read_offset
=
0
...
@@ -462,11 +467,35 @@ class Sequence:
...
@@ -462,11 +467,35 @@ class Sequence:
return
self
.
prompt_adapter_request
.
prompt_adapter_id
\
return
self
.
prompt_adapter_request
.
prompt_adapter_id
\
if
self
.
prompt_adapter_request
else
0
if
self
.
prompt_adapter_request
else
0
def
get_output_text_to_return
(
self
,
buffer_length
:
int
):
def
get_output_text_to_return
(
self
,
buffer_length
:
int
,
delta
:
bool
)
->
str
:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
# We return the full output text if the sequence is finished.
truncate
=
buffer_length
and
not
self
.
is_finished
()
truncate
=
buffer_length
and
not
self
.
is_finished
()
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
if
not
delta
:
self
.
output_text
)
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
self
.
output_text
)
length
=
len
(
self
.
output_text
)
-
buffer_length
last_offset
=
self
.
_last_output_text_offset
if
last_offset
<
length
:
self
.
_last_output_text_offset
=
length
return
self
.
output_text
[
last_offset
:
length
]
return
""
def
get_output_token_ids_to_return
(
self
,
delta
:
bool
)
->
GenericSequence
[
int
]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
get_output_token_ids
()
length
=
self
.
get_output_len
()
last_offset
=
self
.
_last_token_ids_offset
if
last_offset
<
length
:
self
.
_last_token_ids_offset
=
length
return
self
.
data
.
_output_token_ids
[
last_offset
:]
return
()
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
# TODO This can produce incorrect hash when block size > prompt size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment