Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22de4523
Unverified
Commit
22de4523
authored
Mar 04, 2024
by
Antoni Baum
Committed by
GitHub
Mar 04, 2024
Browse files
Push logprob generation to LLMEngine (#3065)
Co-authored-by:
Avnish Narayan
<
avnish@anyscale.com
>
parent
76e8a704
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
555 additions
and
335 deletions
+555
-335
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+58
-3
tests/samplers/test_logprobs.py
tests/samplers/test_logprobs.py
+36
-6
tests/worker/spec_decode/utils.py
tests/worker/spec_decode/utils.py
+7
-5
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+9
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+24
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+40
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+128
-108
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+203
-188
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+19
-4
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+11
-4
vllm/sequence.py
vllm/sequence.py
+17
-8
vllm/worker/spec_decode/multi_step_worker.py
vllm/worker/spec_decode/multi_step_worker.py
+1
-1
No files found.
tests/entrypoints/test_openai_server.py
View file @
22de4523
...
@@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
...
@@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
10
,
max_tokens
=
10
,
logprobs
=
True
,
logprobs
=
True
,
top_logprobs
=
10
)
top_logprobs
=
5
)
assert
chat_completion
.
id
is
not
None
assert
chat_completion
.
id
is
not
None
assert
chat_completion
.
choices
is
not
None
and
len
(
assert
chat_completion
.
choices
is
not
None
and
len
(
chat_completion
.
choices
)
==
1
chat_completion
.
choices
)
==
1
assert
chat_completion
.
choices
[
0
].
message
is
not
None
assert
chat_completion
.
choices
[
0
].
message
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
is
not
None
assert
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
is
not
None
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
])
==
10
assert
len
(
chat_completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
])
==
5
message
=
chat_completion
.
choices
[
0
].
message
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
assert
message
.
role
==
"assistant"
...
@@ -229,7 +229,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
...
@@ -229,7 +229,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
# test multi-turn dialogue
# test multi-turn dialogue
messages
.
append
({
"role"
:
"user"
,
"content"
:
"express your result in json"
})
messages
.
append
({
"role"
:
"user"
,
"content"
:
"express your result in json"
})
chat_completion
=
await
client
.
chat
.
completions
.
create
(
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
model_name
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
10
,
max_tokens
=
10
,
)
)
...
@@ -237,6 +237,61 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
...
@@ -237,6 +237,61 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
async
def
test_too_many_logprobs
(
server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}]
# Default max_logprobs is 5, so this should raise an error
with
pytest
.
raises
((
openai
.
BadRequestError
,
openai
.
APIError
)):
stream
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_tokens
=
10
,
logprobs
=
True
,
top_logprobs
=
10
,
stream
=
True
)
async
for
chunk
in
stream
:
...
with
pytest
.
raises
(
openai
.
BadRequestError
):
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_tokens
=
10
,
logprobs
=
True
,
top_logprobs
=
10
,
stream
=
False
)
with
pytest
.
raises
((
openai
.
BadRequestError
,
openai
.
APIError
)):
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Test"
,
max_tokens
=
10
,
logprobs
=
10
,
stream
=
True
)
async
for
chunk
in
stream
:
...
with
pytest
.
raises
(
openai
.
BadRequestError
):
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Test"
,
max_tokens
=
10
,
logprobs
=
10
,
stream
=
False
)
# the server should still work afterwards
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
max_tokens
=
10
,
stream
=
False
)
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
# just test 1 lora hereafter
# just test 1 lora hereafter
"model_name"
,
"model_name"
,
...
...
tests/samplers/test_logprobs.py
View file @
22de4523
import
pytest
import
pytest
import
torch
import
torch
from
tests.conftest
import
VllmRunner
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
...
@@ -16,6 +17,7 @@ def test_get_prompt_logprobs(
...
@@ -16,6 +17,7 @@ def test_get_prompt_logprobs(
example_prompts
,
example_prompts
,
):
):
max_tokens
=
5
max_tokens
=
5
num_top_logprobs
=
6
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
hf_logprobs
=
hf_model
.
generate_greedy_logprobs
(
example_prompts
,
example_prompts
,
...
@@ -23,19 +25,32 @@ def test_get_prompt_logprobs(
...
@@ -23,19 +25,32 @@ def test_get_prompt_logprobs(
)
)
del
hf_model
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
max_logprobs
=
num_top_logprobs
)
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
vllm_sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
logprobs
=
5
,
logprobs
=
num_top_logprobs
,
prompt_logprobs
=
5
,
prompt_logprobs
=
5
,
temperature
=
0.0
)
temperature
=
0.0
)
vllm_results
=
vllm_model
.
model
.
generate
(
vllm_results
=
vllm_model
.
model
.
generate
(
example_prompts
,
sampling_params
=
vllm_sampling_params
)
example_prompts
,
sampling_params
=
vllm_sampling_params
)
del
vllm_model
# Test whether logprobs are included in the results.
# Test whether logprobs are included in the results.
for
result
in
vllm_results
:
for
result
in
vllm_results
:
assert
result
.
prompt_logprobs
is
not
None
assert
result
.
prompt_logprobs
is
not
None
assert
result
.
outputs
[
0
].
logprobs
is
not
None
assert
result
.
outputs
[
0
].
logprobs
is
not
None
assert
len
(
result
.
outputs
[
0
].
logprobs
)
==
max_tokens
for
logprobs
in
result
.
outputs
[
0
].
logprobs
:
assert
len
(
logprobs
)
==
num_top_logprobs
output_text
=
result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens
=
[]
for
top_logprobs
in
result
.
outputs
[
0
].
logprobs
:
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens
.
append
(
top_logprob
.
decoded_token
)
output_string_from_most_likely_tokens
=
""
.
join
(
output_string_from_most_likely_tokens
)
assert
output_text
==
output_string_from_most_likely_tokens
,
(
"The output text from the top logprob for each token position "
"should be the same as the output text in the result."
)
# Test whether prompt logprobs are consistent with HF
# Test whether prompt logprobs are consistent with HF
for
vllm_result
,
hf_logprob
in
zip
(
vllm_results
,
hf_logprobs
):
for
vllm_result
,
hf_logprob
in
zip
(
vllm_results
,
hf_logprobs
):
...
@@ -43,14 +58,29 @@ def test_get_prompt_logprobs(
...
@@ -43,14 +58,29 @@ def test_get_prompt_logprobs(
vllm_prompt_logprobs
=
vllm_result
.
prompt_logprobs
[
1
:]
vllm_prompt_logprobs
=
vllm_result
.
prompt_logprobs
[
1
:]
for
i
,
vllm_prompt_logprob_dict
in
enumerate
(
vllm_prompt_logprobs
):
for
i
,
vllm_prompt_logprob_dict
in
enumerate
(
vllm_prompt_logprobs
):
for
token_id
,
logprob
in
vllm_prompt_logprob_dict
.
items
():
for
token_id
,
logprob
in
vllm_prompt_logprob_dict
.
items
():
torch
.
testing
.
assert_close
(
logprob
,
torch
.
testing
.
assert_close
(
logprob
.
logprob
,
hf_logprob
[
0
][
i
][
token_id
].
item
(),
hf_logprob
[
0
][
i
][
token_id
].
item
(),
atol
=
1e-2
,
atol
=
1e-2
,
rtol
=
1e-2
)
rtol
=
1e-2
)
vllm_sample_logprobs
=
vllm_result
.
outputs
[
0
].
logprobs
vllm_sample_logprobs
=
vllm_result
.
outputs
[
0
].
logprobs
for
i
,
vllm_sample_logprob_dict
in
enumerate
(
vllm_sample_logprobs
):
for
i
,
top_logprobs
in
enumerate
(
vllm_sample_logprobs
):
for
token_id
,
logprob
in
vllm_sample_logprob_dict
.
items
():
for
token_id
,
sample_logprob
in
top_logprobs
.
items
():
logprob
=
sample_logprob
.
logprob
torch
.
testing
.
assert_close
(
logprob
,
torch
.
testing
.
assert_close
(
logprob
,
hf_logprob
[
i
][
-
1
][
token_id
].
item
(),
hf_logprob
[
i
][
-
1
][
token_id
].
item
(),
atol
=
1e-2
,
atol
=
1e-2
,
rtol
=
1e-2
)
rtol
=
1e-2
)
assert
isinstance
(
sample_logprob
.
decoded_token
,
str
),
\
(
"The token should be decoded by the time it is returned "
" to the user."
)
def
test_max_logprobs
():
runner
=
VllmRunner
(
"facebook/opt-125m"
,
max_logprobs
=
1
)
vllm_sampling_params
=
SamplingParams
(
logprobs
=
1
)
# should pass
runner
.
generate
([
"Hello world"
],
sampling_params
=
vllm_sampling_params
)
bad_sampling_params
=
SamplingParams
(
logprobs
=
2
)
with
pytest
.
raises
(
ValueError
):
runner
.
generate
([
"Hello world"
],
sampling_params
=
bad_sampling_params
)
tests/worker/spec_decode/utils.py
View file @
22de4523
...
@@ -4,7 +4,7 @@ from typing import List, Optional, Dict
...
@@ -4,7 +4,7 @@ from typing import List, Optional, Dict
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
SequenceGroupMetadata
,
SequenceData
from
vllm.sequence
import
Logprob
,
SequenceGroupMetadata
,
SequenceData
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
...
@@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(
...
@@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(
def
assert_logprobs_dict_allclose
(
def
assert_logprobs_dict_allclose
(
actual_logprobs
:
List
[
Dict
[
int
,
float
]],
actual_logprobs
:
List
[
Dict
[
int
,
Logprob
]],
expected_logprobs
:
List
[
Dict
[
int
,
float
]])
->
None
:
expected_logprobs
:
List
[
Dict
[
int
,
Logprob
]])
->
None
:
for
single_step_actual_logprobs
,
single_step_expected_logprobs
in
zip
(
for
single_step_actual_logprobs
,
single_step_expected_logprobs
in
zip
(
actual_logprobs
,
expected_logprobs
):
actual_logprobs
,
expected_logprobs
):
assert
set
(
single_step_actual_logprobs
.
keys
())
==
set
(
assert
set
(
single_step_actual_logprobs
.
keys
())
==
set
(
single_step_expected_logprobs
.
keys
())
single_step_expected_logprobs
.
keys
())
for
token_id
in
single_step_actual_logprobs
:
for
token_id
in
single_step_actual_logprobs
:
actual
=
torch
.
tensor
(
single_step_actual_logprobs
[
token_id
])
actual
=
torch
.
tensor
(
expected
=
torch
.
tensor
(
single_step_expected_logprobs
[
token_id
])
single_step_actual_logprobs
[
token_id
].
logprob
)
expected
=
torch
.
tensor
(
single_step_expected_logprobs
[
token_id
].
logprob
)
assert
torch
.
allclose
(
actual
,
expected
)
assert
torch
.
allclose
(
actual
,
expected
)
vllm/config.py
View file @
22de4523
...
@@ -79,6 +79,7 @@ class ModelConfig:
...
@@ -79,6 +79,7 @@ class ModelConfig:
quantization
:
Optional
[
str
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
enforce_eager
:
bool
=
False
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
5
,
)
->
None
:
)
->
None
:
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
...
@@ -93,6 +94,7 @@ class ModelConfig:
...
@@ -93,6 +94,7 @@ class ModelConfig:
self
.
quantization
=
quantization
self
.
quantization
=
quantization
self
.
enforce_eager
=
enforce_eager
self
.
enforce_eager
=
enforce_eager
self
.
max_context_len_to_capture
=
max_context_len_to_capture
self
.
max_context_len_to_capture
=
max_context_len_to_capture
self
.
max_logprobs
=
max_logprobs
if
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
:
if
os
.
environ
.
get
(
"VLLM_USE_MODELSCOPE"
,
"False"
).
lower
()
==
"true"
:
# download model from ModelScope hub,
# download model from ModelScope hub,
...
...
vllm/engine/arg_utils.py
View file @
22de4523
...
@@ -31,6 +31,7 @@ class EngineArgs:
...
@@ -31,6 +31,7 @@ class EngineArgs:
max_num_batched_tokens
:
Optional
[
int
]
=
None
max_num_batched_tokens
:
Optional
[
int
]
=
None
max_num_seqs
:
int
=
256
max_num_seqs
:
int
=
256
max_paddings
:
int
=
256
max_paddings
:
int
=
256
max_logprobs
:
int
=
5
# OpenAI default value
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
None
revision
:
Optional
[
str
]
=
None
code_revision
:
Optional
[
str
]
=
None
code_revision
:
Optional
[
str
]
=
None
...
@@ -212,6 +213,12 @@ class EngineArgs:
...
@@ -212,6 +213,12 @@ class EngineArgs:
type
=
int
,
type
=
int
,
default
=
EngineArgs
.
max_paddings
,
default
=
EngineArgs
.
max_paddings
,
help
=
'maximum number of paddings in a batch'
)
help
=
'maximum number of paddings in a batch'
)
parser
.
add_argument
(
'--max-logprobs'
,
type
=
int
,
default
=
EngineArgs
.
max_logprobs
,
help
=
(
'max number of log probs to return logprobs is specified in'
' SamplingParams'
))
parser
.
add_argument
(
'--disable-log-stats'
,
parser
.
add_argument
(
'--disable-log-stats'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'disable logging statistics'
)
help
=
'disable logging statistics'
)
...
@@ -300,7 +307,8 @@ class EngineArgs:
...
@@ -300,7 +307,8 @@ class EngineArgs:
self
.
trust_remote_code
,
self
.
download_dir
,
self
.
load_format
,
self
.
trust_remote_code
,
self
.
download_dir
,
self
.
load_format
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
code_revision
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
quantization
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
quantization
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
)
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
max_logprobs
)
cache_config
=
CacheConfig
(
self
.
block_size
,
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
...
...
vllm/engine/async_llm_engine.py
View file @
22de4523
...
@@ -47,7 +47,7 @@ class AsyncStream:
...
@@ -47,7 +47,7 @@ class AsyncStream:
self
.
_queue
=
asyncio
.
Queue
()
self
.
_queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
self
.
_finished
=
False
def
put
(
self
,
item
:
RequestOutput
)
->
None
:
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Exception
]
)
->
None
:
if
self
.
_finished
:
if
self
.
_finished
:
return
return
self
.
_queue
.
put_nowait
(
item
)
self
.
_queue
.
put_nowait
(
item
)
...
@@ -110,6 +110,17 @@ class RequestTracker:
...
@@ -110,6 +110,17 @@ class RequestTracker:
logger
.
info
(
f
"Finished request
{
request_id
}
."
)
logger
.
info
(
f
"Finished request
{
request_id
}
."
)
self
.
abort_request
(
request_id
)
self
.
abort_request
(
request_id
)
def
process_exception
(
self
,
request_id
:
str
,
exception
:
Exception
,
*
,
verbose
:
bool
=
False
)
->
None
:
"""Propagate an exception from the engine."""
self
.
_request_streams
[
request_id
].
put
(
exception
)
if
verbose
:
logger
.
info
(
f
"Finished request
{
request_id
}
."
)
self
.
abort_request
(
request_id
)
def
add_request
(
self
,
request_id
:
str
,
def
add_request
(
self
,
request_id
:
str
,
**
engine_add_request_kwargs
)
->
AsyncStream
:
**
engine_add_request_kwargs
)
->
AsyncStream
:
"""Add a request to be sent to the engine on the next background
"""Add a request to be sent to the engine on the next background
...
@@ -377,10 +388,18 @@ class AsyncLLMEngine:
...
@@ -377,10 +388,18 @@ class AsyncLLMEngine:
for
new_request
in
new_requests
:
for
new_request
in
new_requests
:
# Add the request into the vLLM engine's waiting queue.
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
# TODO: Maybe add add_request_batch to reduce Ray overhead
if
self
.
engine_use_ray
:
try
:
await
self
.
engine
.
add_request
.
remote
(
**
new_request
)
if
self
.
engine_use_ray
:
else
:
await
self
.
engine
.
add_request
.
remote
(
**
new_request
)
await
self
.
engine
.
add_request_async
(
**
new_request
)
else
:
await
self
.
engine
.
add_request_async
(
**
new_request
)
except
ValueError
as
e
:
# TODO: use a vLLM specific error for failed validation
self
.
_request_tracker
.
process_exception
(
new_request
[
"request_id"
],
e
,
verbose
=
self
.
log_requests
,
)
if
finished_requests
:
if
finished_requests
:
await
self
.
_engine_abort
(
finished_requests
)
await
self
.
_engine_abort
(
finished_requests
)
...
...
vllm/engine/llm_engine.py
View file @
22de4523
...
@@ -18,7 +18,7 @@ from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
...
@@ -18,7 +18,7 @@ from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
TokenizerGroup
)
TokenizerGroup
)
...
@@ -473,6 +473,13 @@ class LLMEngine:
...
@@ -473,6 +473,13 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
max_logprobs
=
self
.
get_model_config
().
max_logprobs
if
(
sampling_params
.
logprobs
and
sampling_params
.
logprobs
>
max_logprobs
)
or
(
sampling_params
.
prompt_logprobs
and
sampling_params
.
prompt_logprobs
>
max_logprobs
):
raise
ValueError
(
f
"Cannot request more than "
f
"
{
max_logprobs
}
logprobs."
)
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
monotonic
()
arrival_time
=
time
.
monotonic
()
prompt_token_ids
=
self
.
encode_request
(
prompt_token_ids
=
self
.
encode_request
(
...
@@ -583,6 +590,13 @@ class LLMEngine:
...
@@ -583,6 +590,13 @@ class LLMEngine:
# Process prompt logprobs
# Process prompt logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
if
prompt_logprobs
is
not
None
:
if
prompt_logprobs
is
not
None
:
# We can pick any sequence for the prompt.
seq
=
next
(
iter
(
seq_group
.
seqs_dict
.
values
()))
all_token_ids
=
seq
.
get_token_ids
()
for
i
,
prompt_logprobs_for_token
in
enumerate
(
prompt_logprobs
):
self
.
_decode_logprobs
(
seq
,
seq_group
.
sampling_params
,
prompt_logprobs_for_token
,
all_token_ids
[:
i
])
seq_group
.
prompt_logprobs
=
prompt_logprobs
seq_group
.
prompt_logprobs
=
prompt_logprobs
# Process samples
# Process samples
...
@@ -930,12 +944,36 @@ class LLMEngine:
...
@@ -930,12 +944,36 @@ class LLMEngine:
time_e2e_requests
=
time_e2e_requests
,
time_e2e_requests
=
time_e2e_requests
,
)
)
def
_decode_logprobs
(
self
,
seq
:
Sequence
,
prms
:
SamplingParams
,
logprobs
:
Dict
[
int
,
Logprob
],
all_input_ids
:
List
[
int
])
->
None
:
if
not
logprobs
:
return
for
token_id
,
sample_logprob
in
logprobs
.
items
():
if
(
sample_logprob
.
decoded_token
is
None
and
token_id
!=
-
1
):
all_input_ids_with_logprob
=
all_input_ids
[:
-
1
]
+
[
token_id
]
_
,
new_text
,
prefix_offset
,
read_offset
=
detokenize_incrementally
(
self
.
get_tokenizer_for_seq
(
seq
),
all_input_ids
=
all_input_ids_with_logprob
,
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
read_offset
=
seq
.
read_offset
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
)
sample_logprob
.
decoded_token
=
new_text
def
_decode_sequence
(
self
,
seq
:
Sequence
,
prms
:
SamplingParams
)
->
None
:
def
_decode_sequence
(
self
,
seq
:
Sequence
,
prms
:
SamplingParams
)
->
None
:
"""Decodes the new token for a sequence."""
"""Decodes the new token for a sequence."""
all_input_ids
=
seq
.
get_token_ids
()
self
.
_decode_logprobs
(
seq
,
prms
,
seq
.
output_logprobs
[
-
1
],
all_input_ids
)
(
new_tokens
,
new_output_text
,
prefix_offset
,
(
new_tokens
,
new_output_text
,
prefix_offset
,
read_offset
)
=
detokenize_incrementally
(
read_offset
)
=
detokenize_incrementally
(
self
.
get_tokenizer_for_seq
(
seq
),
self
.
get_tokenizer_for_seq
(
seq
),
all_input_ids
=
seq
.
get_token
_ids
()
,
all_input_ids
=
all_input
_ids
,
prev_tokens
=
seq
.
tokens
,
prev_tokens
=
seq
.
tokens
,
prefix_offset
=
seq
.
prefix_offset
,
prefix_offset
=
seq
.
prefix_offset
,
read_offset
=
seq
.
read_offset
,
read_offset
=
seq
.
read_offset
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
22de4523
...
@@ -82,8 +82,12 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -82,8 +82,12 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
chat_completion_stream_generator
(
return
self
.
chat_completion_stream_generator
(
request
,
result_generator
,
request_id
)
request
,
result_generator
,
request_id
)
else
:
else
:
return
await
self
.
chat_completion_full_generator
(
try
:
request
,
raw_request
,
result_generator
,
request_id
)
return
await
self
.
chat_completion_full_generator
(
request
,
raw_request
,
result_generator
,
request_id
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
def
get_chat_request_role
(
self
,
request
:
ChatCompletionRequest
)
->
str
:
def
get_chat_request_role
(
self
,
request
:
ChatCompletionRequest
)
->
str
:
if
request
.
add_generation_prompt
:
if
request
.
add_generation_prompt
:
...
@@ -99,117 +103,133 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -99,117 +103,133 @@ class OpenAIServingChat(OpenAIServing):
model_name
=
request
.
model
model_name
=
request
.
model
created_time
=
int
(
time
.
monotonic
())
created_time
=
int
(
time
.
monotonic
())
chunk_object_type
=
"chat.completion.chunk"
chunk_object_type
=
"chat.completion.chunk"
first_iteration
=
True
# Send first response for each request.n (index) with the role
role
=
self
.
get_chat_request_role
(
request
)
for
i
in
range
(
request
.
n
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
role
=
role
),
logprobs
=
None
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# Send response to echo the input portion of the last message
if
request
.
echo
:
last_msg_content
=
""
if
request
.
messages
and
isinstance
(
request
.
messages
,
list
)
and
request
.
messages
[
-
1
].
get
(
"content"
)
and
request
.
messages
[
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
request
.
messages
[
-
1
][
"content"
]
if
last_msg_content
:
for
i
in
range
(
request
.
n
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
last_msg_content
),
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
logprobs
=
None
,
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# Send response for each token for each request.n (index)
# Send response for each token for each request.n (index)
previous_texts
=
[
""
]
*
request
.
n
previous_texts
=
[
""
]
*
request
.
n
previous_num_tokens
=
[
0
]
*
request
.
n
previous_num_tokens
=
[
0
]
*
request
.
n
finish_reason_sent
=
[
False
]
*
request
.
n
finish_reason_sent
=
[
False
]
*
request
.
n
async
for
res
in
result_generator
:
try
:
res
:
RequestOutput
async
for
res
in
result_generator
:
for
output
in
res
.
outputs
:
res
:
RequestOutput
i
=
output
.
index
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
if
finish_reason_sent
[
i
]:
# response (by the try...catch).
continue
if
first_iteration
:
# Send first response for each request.n (index) with the role
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
role
=
self
.
get_chat_request_role
(
request
)
top_logprobs
=
output
.
logprobs
[
for
i
in
range
(
request
.
n
):
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
if
request
.
logprobs
:
delta
=
DeltaMessage
(
role
=
role
),
logprobs
=
self
.
_create_logprobs
(
logprobs
=
None
,
token_ids
=
delta_token_ids
,
finish_reason
=
None
)
top_logprobs
=
top_logprobs
,
chunk
=
ChatCompletionStreamResponse
(
num_output_top_logprobs
=
request
.
logprobs
,
id
=
request_id
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
object
=
chunk_object_type
,
)
created
=
created_time
,
else
:
choices
=
[
choice_data
],
logprobs
=
None
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
yield
f
"data:
{
data
}
\n\n
"
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
# Send response to echo the input portion of the last message
if
output
.
finish_reason
is
None
:
if
request
.
echo
:
# Send token-by-token response for each request.n
last_msg_content
=
""
choice_data
=
ChatCompletionResponseStreamChoice
(
if
request
.
messages
and
isinstance
(
index
=
i
,
request
.
messages
,
delta
=
DeltaMessage
(
content
=
delta_text
),
list
)
and
request
.
messages
[
-
1
].
get
(
logprobs
=
logprobs
,
"content"
)
and
request
.
messages
[
-
1
].
get
(
finish_reason
=
None
)
"role"
)
==
role
:
chunk
=
ChatCompletionStreamResponse
(
last_msg_content
=
request
.
messages
[
-
1
][
"content"
]
id
=
request_id
,
object
=
chunk_object_type
,
if
last_msg_content
:
created
=
created_time
,
for
i
in
range
(
request
.
n
):
choices
=
[
choice_data
],
choice_data
=
ChatCompletionResponseStreamChoice
(
model
=
model_name
)
index
=
i
,
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
delta
=
DeltaMessage
(
yield
f
"data:
{
data
}
\n\n
"
content
=
last_msg_content
),
else
:
finish_reason
=
None
)
# Send the finish response for each request.n only once
chunk
=
ChatCompletionStreamResponse
(
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
id
=
request_id
,
final_usage
=
UsageInfo
(
object
=
chunk_object_type
,
prompt_tokens
=
prompt_tokens
,
created
=
created_time
,
completion_tokens
=
previous_num_tokens
[
i
],
choices
=
[
choice_data
],
total_tokens
=
prompt_tokens
+
previous_num_tokens
[
i
],
logprobs
=
None
,
)
model
=
model_name
)
choice_data
=
ChatCompletionResponseStreamChoice
(
data
=
chunk
.
model_dump_json
(
index
=
i
,
exclude_unset
=
True
)
delta
=
DeltaMessage
(
content
=
delta_text
),
yield
f
"data:
{
data
}
\n\n
"
logprobs
=
logprobs
,
first_iteration
=
False
finish_reason
=
output
.
finish_reason
)
chunk
=
ChatCompletionStreamResponse
(
for
output
in
res
.
outputs
:
id
=
request_id
,
i
=
output
.
index
object
=
chunk_object_type
,
created
=
created_time
,
if
finish_reason_sent
[
i
]:
choices
=
[
choice_data
],
continue
model
=
model_name
)
if
final_usage
is
not
None
:
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
chunk
.
usage
=
final_usage
top_logprobs
=
output
.
logprobs
[
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
,
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
exclude_none
=
True
)
yield
f
"data:
{
data
}
\n\n
"
if
request
.
logprobs
:
finish_reason_sent
[
i
]
=
True
logprobs
=
self
.
_create_logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
if
output
.
finish_reason
is
None
:
# Send token-by-token response for each request.n
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
else
:
# Send the finish response for each request.n only once
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
previous_num_tokens
[
i
],
total_tokens
=
prompt_tokens
+
previous_num_tokens
[
i
],
)
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
content
=
delta_text
),
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
if
final_usage
is
not
None
:
chunk
.
usage
=
final_usage
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
,
exclude_none
=
True
)
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
data
}
\n\n
"
# Send the final done message after all response.n are finished
# Send the final done message after all response.n are finished
yield
"data: [DONE]
\n\n
"
yield
"data: [DONE]
\n\n
"
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
22de4523
...
@@ -26,107 +26,6 @@ TypeCreateLogProbsFn = Callable[
...
@@ -26,107 +26,6 @@ TypeCreateLogProbsFn = Callable[
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
LogProbs
]
[
TypeTokenIDs
,
TypeTopLogProbs
,
Optional
[
int
],
int
],
LogProbs
]
async
def
completion_stream_generator
(
request
:
CompletionRequest
,
raw_request
:
Request
,
on_abort
,
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]],
create_logprobs_fn
:
TypeCreateLogProbsFn
,
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
num_prompts
:
int
,
)
->
AsyncGenerator
[
str
,
None
]:
previous_texts
=
[
""
]
*
request
.
n
*
num_prompts
previous_num_tokens
=
[
0
]
*
request
.
n
*
num_prompts
has_echoed
=
[
False
]
*
request
.
n
*
num_prompts
async
for
prompt_idx
,
res
in
result_generator
:
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
on_abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
request
.
n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
if
request
.
echo
and
request
.
max_tokens
==
0
:
# only return the prompt
delta_text
=
res
.
prompt
delta_token_ids
=
res
.
prompt_token_ids
top_logprobs
=
res
.
prompt_logprobs
has_echoed
[
i
]
=
True
elif
request
.
echo
and
request
.
max_tokens
>
0
and
not
has_echoed
[
i
]:
# echo the prompt and first token
delta_text
=
res
.
prompt
+
output
.
text
delta_token_ids
=
res
.
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
res
.
prompt_logprobs
+
(
output
.
logprobs
or
[])
has_echoed
[
i
]
=
True
else
:
# return just the delta
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
top_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
assert
top_logprobs
is
not
None
,
"top_logprobs must be provided when logprobs is requested"
logprobs
=
create_logprobs_fn
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
delta_text
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
,
)
]).
model_dump_json
()
yield
f
"data:
{
response_json
}
\n\n
"
if
output
.
finish_reason
is
not
None
:
# return final usage
logprobs
=
LogProbs
()
if
request
.
logprobs
is
not
None
else
None
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
""
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
)
],
usage
=
final_usage
,
).
model_dump_json
()
yield
f
"data:
{
response_json
}
\n\n
"
yield
"data: [DONE]
\n\n
"
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
def
parse_prompt_format
(
prompt
)
->
Tuple
[
bool
,
list
]:
# get the prompt, openai supports the following
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
# "a string, array of strings, array of tokens, or array of token arrays."
...
@@ -151,73 +50,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
...
@@ -151,73 +50,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return
prompt_is_tokens
,
prompts
return
prompt_is_tokens
,
prompts
def
request_output_to_completion_response
(
final_res_batch
:
List
[
RequestOutput
],
request
:
CompletionRequest
,
create_logprobs_fn
:
TypeCreateLogProbsFn
,
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
CompletionResponse
:
choices
=
[]
num_prompt_tokens
=
0
num_generated_tokens
=
0
for
final_res
in
final_res_batch
:
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
for
output
in
final_res
.
outputs
:
if
request
.
echo
and
request
.
max_tokens
==
0
:
token_ids
=
prompt_token_ids
top_logprobs
=
prompt_logprobs
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
prompt_logprobs
+
output
.
logprobs
output_text
=
prompt_text
+
output
.
text
else
:
token_ids
=
output
.
token_ids
top_logprobs
=
output
.
logprobs
output_text
=
output
.
text
if
request
.
logprobs
is
not
None
:
logprobs
=
create_logprobs_fn
(
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
)
else
:
logprobs
=
None
choice_data
=
CompletionResponseChoice
(
index
=
len
(
choices
),
text
=
output_text
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
)
choices
.
append
(
choice_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_generated_tokens
+=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
return
CompletionResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
choices
,
usage
=
usage
,
)
def
merge_async_iterators
(
*
iterators
):
def
merge_async_iterators
(
*
iterators
):
"""Merge multiple asynchronous iterators into a single iterator.
"""Merge multiple asynchronous iterators into a single iterator.
...
@@ -230,8 +62,11 @@ def merge_async_iterators(*iterators):
...
@@ -230,8 +62,11 @@ def merge_async_iterators(*iterators):
finished
=
[
False
]
*
len
(
iterators
)
finished
=
[
False
]
*
len
(
iterators
)
async
def
producer
(
i
,
iterator
):
async
def
producer
(
i
,
iterator
):
async
for
item
in
iterator
:
try
:
await
queue
.
put
((
i
,
item
))
async
for
item
in
iterator
:
await
queue
.
put
((
i
,
item
))
except
Exception
as
e
:
await
queue
.
put
(
e
)
finished
[
i
]
=
True
finished
[
i
]
=
True
_tasks
=
[
_tasks
=
[
...
@@ -242,6 +77,8 @@ def merge_async_iterators(*iterators):
...
@@ -242,6 +77,8 @@ def merge_async_iterators(*iterators):
async
def
consumer
():
async
def
consumer
():
while
not
all
(
finished
)
or
not
queue
.
empty
():
while
not
all
(
finished
)
or
not
queue
.
empty
():
item
=
await
queue
.
get
()
item
=
await
queue
.
get
()
if
isinstance
(
item
,
Exception
):
raise
item
yield
item
yield
item
await
asyncio
.
gather
(
*
_tasks
)
await
asyncio
.
gather
(
*
_tasks
)
...
@@ -312,6 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -312,6 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_token_ids
=
input_ids
,
prompt_token_ids
=
input_ids
,
lora_request
=
lora_request
))
lora_request
=
lora_request
))
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
result_generator
:
AsyncIterator
[
Tuple
[
...
@@ -325,27 +163,28 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -325,27 +163,28 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
# Streaming response
if
stream
:
if
stream
:
return
completion_stream_generator
(
request
,
return
self
.
completion_stream_generator
(
request
,
raw_request
,
raw_request
,
self
.
engine
.
abort
,
result_generator
,
result_generator
,
request_id
,
self
.
_create_logprobs
,
created_time
,
request_id
,
model_name
,
created_time
,
num_prompts
=
len
(
prompts
))
model_name
,
num_prompts
=
len
(
prompts
))
# Non-streaming response
# Non-streaming response
final_res_batch
:
RequestOutput
=
[
None
]
*
len
(
prompts
)
final_res_batch
:
RequestOutput
=
[
None
]
*
len
(
prompts
)
async
for
i
,
res
in
result_generator
:
try
:
if
await
raw_request
.
is_disconnected
():
async
for
i
,
res
in
result_generator
:
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
# Abort the request if the client disconnects.
return
self
.
create_error_response
(
"Client disconnected"
)
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
final_res_batch
[
i
]
=
res
return
self
.
create_error_response
(
"Client disconnected"
)
response
=
request_output_to_completion_response
(
final_res_batch
[
i
]
=
res
final_res_batch
,
request
,
self
.
_create_logprobs
,
request_id
,
response
=
self
.
request_output_to_completion_response
(
created_time
,
model_name
)
final_res_batch
,
request
,
request_id
,
created_time
,
model_name
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
# When user requests streaming but we don't stream, we still need to
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
# return a streaming response with a single event.
...
@@ -359,3 +198,179 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -359,3 +198,179 @@ class OpenAIServingCompletion(OpenAIServing):
return
fake_stream_generator
()
return
fake_stream_generator
()
return
response
return
response
async
def
completion_stream_generator
(
self
,
request
:
CompletionRequest
,
raw_request
:
Request
,
result_generator
:
AsyncIterator
[
Tuple
[
int
,
RequestOutput
]],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
num_prompts
:
int
,
)
->
AsyncGenerator
[
str
,
None
]:
previous_texts
=
[
""
]
*
request
.
n
*
num_prompts
previous_num_tokens
=
[
0
]
*
request
.
n
*
num_prompts
has_echoed
=
[
False
]
*
request
.
n
*
num_prompts
try
:
async
for
prompt_idx
,
res
in
result_generator
:
# Abort the request if the client disconnects.
if
await
raw_request
.
is_disconnected
():
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
prompt_idx
}
"
)
raise
StopAsyncIteration
()
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
request
.
n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
if
request
.
echo
and
request
.
max_tokens
==
0
:
# only return the prompt
delta_text
=
res
.
prompt
delta_token_ids
=
res
.
prompt_token_ids
top_logprobs
=
res
.
prompt_logprobs
has_echoed
[
i
]
=
True
elif
request
.
echo
and
request
.
max_tokens
>
0
and
not
has_echoed
[
i
]:
# echo the prompt and first token
delta_text
=
res
.
prompt
+
output
.
text
delta_token_ids
=
res
.
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
res
.
prompt_logprobs
+
(
output
.
logprobs
or
[])
has_echoed
[
i
]
=
True
else
:
# return just the delta
delta_text
=
output
.
text
[
len
(
previous_texts
[
i
]):]
delta_token_ids
=
output
.
token_ids
[
previous_num_tokens
[
i
]:]
top_logprobs
=
output
.
logprobs
[
previous_num_tokens
[
i
]:]
if
output
.
logprobs
else
None
if
request
.
logprobs
is
not
None
:
assert
top_logprobs
is
not
None
,
"top_logprobs must be provided when logprobs is requested"
logprobs
=
self
.
_create_logprobs
(
token_ids
=
delta_token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
initial_text_offset
=
len
(
previous_texts
[
i
]),
)
else
:
logprobs
=
None
previous_texts
[
i
]
=
output
.
text
previous_num_tokens
[
i
]
=
len
(
output
.
token_ids
)
finish_reason
=
output
.
finish_reason
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
delta_text
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
,
)
]).
model_dump_json
()
yield
f
"data:
{
response_json
}
\n\n
"
if
output
.
finish_reason
is
not
None
:
# return final usage
logprobs
=
LogProbs
(
)
if
request
.
logprobs
is
not
None
else
None
prompt_tokens
=
len
(
res
.
prompt_token_ids
)
completion_tokens
=
len
(
output
.
token_ids
)
final_usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
response_json
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
CompletionResponseStreamChoice
(
index
=
i
,
text
=
""
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
)
],
usage
=
final_usage
,
).
model_dump_json
()
yield
f
"data:
{
response_json
}
\n\n
"
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
print
(
"yield"
,
f
"data:
{
data
}
\n\n
"
)
yield
f
"data:
{
data
}
\n\n
"
print
(
"yield"
,
"data: [DONE]
\n\n
"
)
yield
"data: [DONE]
\n\n
"
def
request_output_to_completion_response
(
self
,
final_res_batch
:
List
[
RequestOutput
],
request
:
CompletionRequest
,
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
CompletionResponse
:
choices
=
[]
num_prompt_tokens
=
0
num_generated_tokens
=
0
for
final_res
in
final_res_batch
:
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
prompt_logprobs
=
final_res
.
prompt_logprobs
prompt_text
=
final_res
.
prompt
for
output
in
final_res
.
outputs
:
if
request
.
echo
and
request
.
max_tokens
==
0
:
token_ids
=
prompt_token_ids
top_logprobs
=
prompt_logprobs
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
prompt_logprobs
+
output
.
logprobs
output_text
=
prompt_text
+
output
.
text
else
:
token_ids
=
output
.
token_ids
top_logprobs
=
output
.
logprobs
output_text
=
output
.
text
if
request
.
logprobs
is
not
None
:
logprobs
=
self
.
_create_logprobs
(
token_ids
=
token_ids
,
top_logprobs
=
top_logprobs
,
num_output_top_logprobs
=
request
.
logprobs
,
)
else
:
logprobs
=
None
choice_data
=
CompletionResponseChoice
(
index
=
len
(
choices
),
text
=
output_text
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
,
)
choices
.
append
(
choice_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
num_generated_tokens
+=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
return
CompletionResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
choices
,
usage
=
usage
,
)
vllm/entrypoints/openai/serving_engine.py
View file @
22de4523
import
asyncio
import
asyncio
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
...
@@ -11,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
...
@@ -11,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
ModelCard
,
ModelList
,
ModelCard
,
ModelList
,
ModelPermission
)
ModelPermission
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -83,7 +85,7 @@ class OpenAIServing:
...
@@ -83,7 +85,7 @@ class OpenAIServing:
def
_create_logprobs
(
def
_create_logprobs
(
self
,
self
,
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
float
]]]]
=
None
,
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
Logprob
]]]]
=
None
,
num_output_top_logprobs
:
Optional
[
int
]
=
None
,
num_output_top_logprobs
:
Optional
[
int
]
=
None
,
initial_text_offset
:
int
=
0
,
initial_text_offset
:
int
=
0
,
)
->
LogProbs
:
)
->
LogProbs
:
...
@@ -95,10 +97,10 @@ class OpenAIServing:
...
@@ -95,10 +97,10 @@ class OpenAIServing:
for
i
,
token_id
in
enumerate
(
token_ids
):
for
i
,
token_id
in
enumerate
(
token_ids
):
step_top_logprobs
=
top_logprobs
[
i
]
step_top_logprobs
=
top_logprobs
[
i
]
if
step_top_logprobs
is
not
None
:
if
step_top_logprobs
is
not
None
:
token_logprob
=
step_top_logprobs
[
token_id
]
token_logprob
=
step_top_logprobs
[
token_id
]
.
logprob
else
:
else
:
token_logprob
=
None
token_logprob
=
None
token
=
s
elf
.
tokenizer
.
convert_ids_to_tokens
(
token_id
)
token
=
s
tep_top_logprobs
[
token_id
].
decoded_token
logprobs
.
tokens
.
append
(
token
)
logprobs
.
tokens
.
append
(
token
)
logprobs
.
token_logprobs
.
append
(
token_logprob
)
logprobs
.
token_logprobs
.
append
(
token_logprob
)
if
len
(
logprobs
.
text_offset
)
==
0
:
if
len
(
logprobs
.
text_offset
)
==
0
:
...
@@ -110,7 +112,7 @@ class OpenAIServing:
...
@@ -110,7 +112,7 @@ class OpenAIServing:
if
num_output_top_logprobs
:
if
num_output_top_logprobs
:
logprobs
.
top_logprobs
.
append
({
logprobs
.
top_logprobs
.
append
({
self
.
tokenizer
.
convert_ids_to_tokens
(
i
):
p
p
.
decoded_token
:
p
.
logprob
for
i
,
p
in
step_top_logprobs
.
items
()
for
i
,
p
in
step_top_logprobs
.
items
()
}
if
step_top_logprobs
else
None
)
}
if
step_top_logprobs
else
None
)
return
logprobs
return
logprobs
...
@@ -124,6 +126,19 @@ class OpenAIServing:
...
@@ -124,6 +126,19 @@ class OpenAIServing:
type
=
err_type
,
type
=
err_type
,
code
=
status_code
.
value
)
code
=
status_code
.
value
)
def
create_streaming_error_response
(
self
,
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
)
->
str
:
json_str
=
json
.
dumps
({
"error"
:
self
.
create_error_response
(
message
=
message
,
err_type
=
err_type
,
status_code
=
status_code
).
model_dump
()
})
return
json_str
async
def
_check_model
(
self
,
request
)
->
Optional
[
ErrorResponse
]:
async
def
_check_model
(
self
,
request
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
==
self
.
served_model
:
if
request
.
model
==
self
.
served_model
:
return
return
...
...
vllm/model_executor/layers/sampler.py
View file @
22de4523
...
@@ -8,8 +8,9 @@ from vllm.model_executor.parallel_utils.communication_op import (
...
@@ -8,8 +8,9 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather
)
tensor_model_parallel_gather
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
(
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceData
,
SequenceGroupOutput
,
SequenceOutput
)
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.utils
import
is_neuron
from
vllm.utils
import
is_neuron
...
@@ -528,7 +529,10 @@ def _get_logprobs(
...
@@ -528,7 +529,10 @@ def _get_logprobs(
prompt_logprobs_dict
.
update
(
prompt_logprobs_dict
.
update
(
zip
(
top_token_ids
[
sample_idx
,
:
num_logprobs
].
tolist
(),
zip
(
top_token_ids
[
sample_idx
,
:
num_logprobs
].
tolist
(),
top_logprobs
[
sample_idx
,
:
num_logprobs
].
tolist
()))
top_logprobs
[
sample_idx
,
:
num_logprobs
].
tolist
()))
group_prompt_logprobs
.
append
(
prompt_logprobs_dict
)
group_prompt_logprobs
.
append
({
token_id
:
Logprob
(
logprob
)
for
token_id
,
logprob
in
prompt_logprobs_dict
.
items
()
})
sample_idx
+=
1
sample_idx
+=
1
query_result_idx
+=
1
query_result_idx
+=
1
result_prompt_logprobs
.
append
(
group_prompt_logprobs
)
result_prompt_logprobs
.
append
(
group_prompt_logprobs
)
...
@@ -553,7 +557,10 @@ def _get_logprobs(
...
@@ -553,7 +557,10 @@ def _get_logprobs(
parent_id
,
:
num_logprobs
].
tolist
(),
parent_id
,
:
num_logprobs
].
tolist
(),
top_logprobs
[
sample_idx
+
top_logprobs
[
sample_idx
+
parent_id
,
:
num_logprobs
].
tolist
()))
parent_id
,
:
num_logprobs
].
tolist
()))
group_sample_logprobs
.
append
(
sample_logprobs_dict
)
group_sample_logprobs
.
append
({
token_id
:
Logprob
(
logprob
)
for
token_id
,
logprob
in
sample_logprobs_dict
.
items
()
})
result_sample_logprobs
.
append
(
group_sample_logprobs
)
result_sample_logprobs
.
append
(
group_sample_logprobs
)
sample_idx
+=
len
(
seq_ids
)
sample_idx
+=
len
(
seq_ids
)
...
...
vllm/sequence.py
View file @
22de4523
...
@@ -8,8 +8,16 @@ from vllm.block import LogicalTokenBlock
...
@@ -8,8 +8,16 @@ from vllm.block import LogicalTokenBlock
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
PromptLogprobs
=
List
[
Optional
[
Dict
[
int
,
float
]]]
SampleLogprobs
=
List
[
Dict
[
int
,
float
]]
@
dataclass
class
Logprob
:
"""Infos for supporting OpenAI compatible logprobs."""
logprob
:
float
decoded_token
:
Optional
[
str
]
=
None
PromptLogprobs
=
List
[
Optional
[
Dict
[
int
,
Logprob
]]]
SampleLogprobs
=
List
[
Dict
[
int
,
Logprob
]]
class
SequenceStatus
(
enum
.
Enum
):
class
SequenceStatus
(
enum
.
Enum
):
...
@@ -196,12 +204,12 @@ class Sequence:
...
@@ -196,12 +204,12 @@ class Sequence:
def
append_token_id
(
def
append_token_id
(
self
,
self
,
token_id
:
int
,
token_id
:
int
,
logprobs
:
Dict
[
int
,
float
],
logprobs
:
Dict
[
int
,
Logprob
],
)
->
None
:
)
->
None
:
assert
token_id
in
logprobs
assert
token_id
in
logprobs
self
.
_append_tokens_to_blocks
([
token_id
])
self
.
_append_tokens_to_blocks
([
token_id
])
self
.
output_logprobs
.
append
(
logprobs
)
self
.
output_logprobs
.
append
(
logprobs
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
])
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
]
.
logprob
)
def
get_len
(
self
)
->
int
:
def
get_len
(
self
)
->
int
:
return
self
.
data
.
get_len
()
return
self
.
data
.
get_len
()
...
@@ -456,7 +464,7 @@ class SequenceOutput:
...
@@ -456,7 +464,7 @@ class SequenceOutput:
self
,
self
,
parent_seq_id
:
int
,
parent_seq_id
:
int
,
output_token
:
int
,
output_token
:
int
,
logprobs
:
Dict
[
int
,
float
],
logprobs
:
Dict
[
int
,
Logprob
],
)
->
None
:
)
->
None
:
self
.
parent_seq_id
=
parent_seq_id
self
.
parent_seq_id
=
parent_seq_id
self
.
output_token
=
output_token
self
.
output_token
=
output_token
...
@@ -470,9 +478,10 @@ class SequenceOutput:
...
@@ -470,9 +478,10 @@ class SequenceOutput:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
SequenceOutput
):
if
not
isinstance
(
other
,
SequenceOutput
):
raise
NotImplementedError
()
raise
NotImplementedError
()
return
(
self
.
parent_seq_id
==
other
.
parent_seq_id
equal
=
(
self
.
parent_seq_id
==
other
.
parent_seq_id
and
self
.
output_token
==
other
.
output_token
and
self
.
output_token
==
other
.
output_token
)
and
self
.
logprobs
==
other
.
logprobs
)
log_probs_equal
=
other
.
logprobs
==
self
.
logprobs
return
equal
and
log_probs_equal
class
SequenceGroupOutput
:
class
SequenceGroupOutput
:
...
...
vllm/worker/spec_decode/multi_step_worker.py
View file @
22de4523
...
@@ -77,7 +77,7 @@ class MultiStepWorker(Worker):
...
@@ -77,7 +77,7 @@ class MultiStepWorker(Worker):
token_id
=
seq_output
.
output_token
token_id
=
seq_output
.
output_token
token_logprob
=
seq_output
.
logprobs
[
token_id
]
token_logprob
=
seq_output
.
logprobs
[
token_id
]
seq
.
append_token_id
(
token_id
,
token_logprob
)
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
def
_shallow_copy_inputs
(
def
_shallow_copy_inputs
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment