Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c10591e
"docs/vscode:/vscode.git/clone" did not exist on "0d9eb99ddd616798c2582c72aaa375dc76fddbbd"
Unverified
Commit
3c10591e
authored
Jul 31, 2024
by
zifeitong
Committed by
GitHub
Jul 31, 2024
Browse files
[Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954)
parent
0437492e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
44 deletions
+92
-44
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+39
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+23
-7
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+8
-15
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+9
-18
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+13
-4
No files found.
tests/entrypoints/openai/test_serving_chat.py
View file @
3c10591e
import
asyncio
import
asyncio
from
contextlib
import
suppress
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
unittest.mock
import
MagicMock
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
MODEL_NAME
=
"openai-community/gpt2"
MODEL_NAME
=
"openai-community/gpt2"
CHAT_TEMPLATE
=
"Dummy chat template for testing {}"
CHAT_TEMPLATE
=
"Dummy chat template for testing {}"
...
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
...
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
def
test_async_serving_chat_init
():
def
test_async_serving_chat_init
():
serving_completion
=
asyncio
.
run
(
_async_serving_chat_init
())
serving_completion
=
asyncio
.
run
(
_async_serving_chat_init
())
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
def
test_serving_chat_should_set_correct_max_tokens
():
mock_engine
=
MagicMock
(
spec
=
AsyncLLMEngine
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
serving_chat
=
OpenAIServingChat
(
mock_engine
,
MockModelConfig
(),
served_model_names
=
[
MODEL_NAME
],
response_role
=
"assistant"
,
chat_template
=
CHAT_TEMPLATE
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
None
)
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"what is 1+1?"
}],
guided_decoding_backend
=
"outlines"
,
)
with
suppress
(
Exception
):
asyncio
.
run
(
serving_chat
.
create_chat_completion
(
req
))
# AsyncLLMEngine.generate(inputs, sampling_params, ...)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
req
.
max_tokens
=
10
with
suppress
(
Exception
):
asyncio
.
run
(
serving_chat
.
create_chat_completion
(
req
))
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
vllm/entrypoints/openai/protocol.py
View file @
3c10591e
...
@@ -11,7 +11,7 @@ from typing_extensions import Annotated
...
@@ -11,7 +11,7 @@ from typing_extensions import Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
...
@@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
,
def
to_sampling_params
(
tokenizer
:
PreTrainedTokenizer
)
->
SamplingParams
:
self
,
tokenizer
:
PreTrainedTokenizer
,
# We now allow logprobs being true without top_logrobs.
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
# We now allow logprobs being true without top_logrobs.
logits_processors
=
get_logits_processors
(
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
None
,
allowed_token_ids
=
None
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
)
)
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
return
SamplingParams
(
return
SamplingParams
(
n
=
self
.
n
,
n
=
self
.
n
,
...
@@ -241,7 +248,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -241,7 +248,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
prompt_logprobs
=
self
.
top_logprobs
if
self
.
echo
else
None
,
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
,
max_tokens
=
max_tokens
,
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
early_stopping
=
self
.
early_stopping
,
...
@@ -395,7 +402,14 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -395,7 +402,14 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
# doc: end-completion-extra-params
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
):
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
logits_processors
=
get_logits_processors
(
logits_processors
=
get_logits_processors
(
...
@@ -403,6 +417,8 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -403,6 +417,8 @@ class CompletionRequest(OpenAIBaseModel):
allowed_token_ids
=
self
.
allowed_token_ids
,
allowed_token_ids
=
self
.
allowed_token_ids
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
)
)
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
return
SamplingParams
(
return
SamplingParams
(
n
=
self
.
n
,
n
=
self
.
n
,
...
@@ -419,7 +435,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -419,7 +435,7 @@ class CompletionRequest(OpenAIBaseModel):
stop_token_ids
=
self
.
stop_token_ids
,
stop_token_ids
=
self
.
stop_token_ids
,
logprobs
=
self
.
logprobs
,
logprobs
=
self
.
logprobs
,
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
self
.
max_tokens
if
not
echo_without_generation
else
1
,
max_tokens
=
max_tokens
if
not
echo_without_generation
else
1
,
min_tokens
=
self
.
min_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
early_stopping
=
self
.
early_stopping
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
3c10591e
...
@@ -25,8 +25,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
...
@@ -25,8 +25,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath
)
PromptAdapterPath
)
from
vllm.inputs
import
PromptInputs
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
...
@@ -134,28 +132,23 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -134,28 +132,23 @@ class OpenAIServingChat(OpenAIServing):
request_id
=
f
"chat-
{
random_uuid
()
}
"
request_id
=
f
"chat-
{
random_uuid
()
}
"
try
:
try
:
sampling_params
=
request
.
to_sampling_params
(
tokenizer
)
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logits_processor
)
prompt_inputs
=
self
.
_tokenize_prompt_input
(
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt
,
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
)
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
prompt_inputs
,
prompt_inputs
,
params
=
sampling_params
,
params
=
sampling_params
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
3c10591e
...
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
...
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing
,
OpenAIServing
,
PromptAdapterPath
)
PromptAdapterPath
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
@@ -95,31 +93,24 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -95,31 +93,24 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
sampling_params
=
request
.
to_sampling_params
(
tokenizer
)
guided_decode_logits_processor
=
(
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
))
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
guided_decode_logit_processor
)
prompts
=
list
(
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
request
,
tokenizer
,
tokenizer
,
request
.
prompt
,
request
.
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
))
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
3c10591e
...
@@ -25,9 +25,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -25,9 +25,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer_group
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
AnyTokenizer
...
@@ -150,6 +152,15 @@ class OpenAIServing:
...
@@ -150,6 +152,15 @@ class OpenAIServing:
})
})
return
json_str
return
json_str
async
def
_guided_decode_logits_processor
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
tokenizer
:
AnyTokenizer
)
->
Optional
[
LogitsProcessor
]:
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
return
await
get_guided_decoding_logits_processor
(
guided_decoding_backend
,
request
,
tokenizer
)
async
def
_check_model
(
async
def
_check_model
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
...
@@ -254,9 +265,7 @@ class OpenAIServing:
...
@@ -254,9 +265,7 @@ class OpenAIServing:
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the messages, "
f
"
{
token_num
}
tokens in the messages, "
f
"Please reduce the length of the messages."
)
f
"Please reduce the length of the messages."
)
request
.
max_tokens
=
self
.
max_model_len
-
token_num
elif
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
if
token_num
+
request
.
max_tokens
>
self
.
max_model_len
:
raise
ValueError
(
raise
ValueError
(
f
"This model's maximum context length is "
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment