Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4113b03
Unverified
Commit
a4113b03
authored
Jul 03, 2025
by
Gabriel Marinho
Committed by
GitHub
Jul 04, 2025
Browse files
[Platform] Add custom default max tokens (#18557)
Signed-off-by:
Gabriel Marinho
<
gmarinho@ibm.com
>
parent
7e1665b0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
59 additions
and
60 deletions
+59
-60
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+10
-49
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+13
-5
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+12
-4
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+20
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+4
-0
No files found.
vllm/entrypoints/openai/protocol.py
View file @
a4113b03
...
@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias
:
Optional
[
dict
[
str
,
float
]]
=
None
logit_bias
:
Optional
[
dict
[
str
,
float
]]
=
None
logprobs
:
Optional
[
bool
]
=
False
logprobs
:
Optional
[
bool
]
=
False
top_logprobs
:
Optional
[
int
]
=
0
top_logprobs
:
Optional
[
int
]
=
0
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
:
Optional
[
int
]
=
Field
(
max_tokens
:
Optional
[
int
]
=
Field
(
default
=
None
,
default
=
None
,
deprecated
=
deprecated
=
...
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
}
}
def
to_beam_search_params
(
def
to_beam_search_params
(
self
,
self
,
max_tokens
:
int
,
default_max_tokens
:
int
,
default_sampling_params
:
dict
)
->
BeamSearchParams
:
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
BeamSearchParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
n
=
self
.
n
if
self
.
n
is
not
None
else
1
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
...
@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
def
to_sampling_params
(
def
to_sampling_params
(
self
,
self
,
default_
max_tokens
:
int
,
max_tokens
:
int
,
logits_processor_pattern
:
Optional
[
str
],
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
,
default_sampling_params
:
dict
,
)
->
SamplingParams
:
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
# Default parameters
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
...
@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
}
}
def
to_beam_search_params
(
def
to_beam_search_params
(
self
,
self
,
default_
max_tokens
:
int
,
max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
default_sampling_params
:
Optional
[
dict
]
=
None
,
)
->
BeamSearchParams
:
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
default_sampling_params
is
None
:
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
n
=
self
.
n
if
self
.
n
is
not
None
else
1
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
1.0
)
temperature
=
default_sampling_params
.
get
(
"temperature"
,
1.0
)
...
@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
def
to_sampling_params
(
def
to_sampling_params
(
self
,
self
,
default_
max_tokens
:
int
,
max_tokens
:
int
,
logits_processor_pattern
:
Optional
[
str
],
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
,
default_sampling_params
:
Optional
[
dict
]
=
None
,
)
->
SamplingParams
:
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
default_sampling_params
is
None
:
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
default_sampling_params
=
{}
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
# Default parameters
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
repetition_penalty
=
default_sampling_params
.
get
(
...
@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
...
@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
self
,
self
,
default_max_tokens
:
int
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
default_max_tokens
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
if
default_sampling_params
is
None
:
...
@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
...
@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
self
,
self
,
default_max_tokens
:
int
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
default_max_tokens
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
if
default_sampling_params
is
None
:
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
a4113b03
...
@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
...
@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
MistralToolCall
)
MistralToolCall
)
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
...
@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
try
:
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
])
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
request
=
request
,
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
]),
default_sampling_params
=
self
.
default_sampling_params
)
if
request
.
use_beam_search
:
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
sampling_params
=
request
.
to_beam_search_params
(
default_
max_tokens
,
self
.
default_sampling_params
)
max_tokens
,
self
.
default_sampling_params
)
else
:
else
:
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
max_tokens
,
self
.
model_config
.
logits_processor_pattern
,
self
.
model_config
.
logits_processor_pattern
,
self
.
default_sampling_params
)
self
.
default_sampling_params
)
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
a4113b03
...
@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
...
@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
is_text_tokens_prompt
)
is_text_tokens_prompt
)
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.inputs.data
import
(
EmbedsPrompt
,
TokensPrompt
,
is_embeds_prompt
,
from
vllm.inputs.data
import
(
EmbedsPrompt
,
TokensPrompt
,
is_embeds_prompt
,
is_tokens_prompt
)
is_tokens_prompt
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
])
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
])
else
:
else
:
assert_never
(
engine_prompt
)
assert_never
(
engine_prompt
)
default_max_tokens
=
self
.
max_model_len
-
input_length
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
request
=
request
,
input_length
=
input_length
,
default_sampling_params
=
self
.
default_sampling_params
)
if
request
.
use_beam_search
:
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
sampling_params
=
request
.
to_beam_search_params
(
default_
max_tokens
,
self
.
default_sampling_params
)
max_tokens
,
self
.
default_sampling_params
)
else
:
else
:
sampling_params
=
request
.
to_sampling_params
(
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
max_tokens
,
self
.
model_config
.
logits_processor_pattern
,
self
.
model_config
.
logits_processor_pattern
,
self
.
default_sampling_params
)
self
.
default_sampling_params
)
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/utils.py
View file @
a4113b03
...
@@ -5,13 +5,17 @@ import argparse
...
@@ -5,13 +5,17 @@ import argparse
import
asyncio
import
asyncio
import
functools
import
functools
import
os
import
os
from
typing
import
Any
,
Optional
import
sys
from
typing
import
Any
,
Optional
,
Union
from
fastapi
import
Request
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
starlette.background
import
BackgroundTask
,
BackgroundTasks
from
starlette.background
import
BackgroundTask
,
BackgroundTasks
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -181,7 +185,6 @@ def _validate_truncation_size(
...
@@ -181,7 +185,6 @@ def _validate_truncation_size(
def
show_filtered_argument_or_group_from_help
(
parser
:
argparse
.
ArgumentParser
,
def
show_filtered_argument_or_group_from_help
(
parser
:
argparse
.
ArgumentParser
,
subcommand_name
:
list
[
str
]):
subcommand_name
:
list
[
str
]):
import
sys
# Only handle --help=<keyword> for the current subcommand.
# Only handle --help=<keyword> for the current subcommand.
# Since subparser_init() runs for all subcommands during CLI setup,
# Since subparser_init() runs for all subcommands during CLI setup,
...
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
...
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
print
(
f
"
\n
No group or parameter matching '
{
search_keyword
}
'"
)
print
(
f
"
\n
No group or parameter matching '
{
search_keyword
}
'"
)
print
(
"Tip: use `--help=listgroup` to view all groups."
)
print
(
"Tip: use `--help=listgroup` to view all groups."
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
def
get_max_tokens
(
max_model_len
:
int
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
input_length
:
int
,
default_sampling_params
:
dict
)
->
int
:
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
or
request
.
max_tokens
default_max_tokens
=
max_model_len
-
input_length
max_output_tokens
=
current_platform
.
get_max_output_tokens
(
input_length
)
return
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
max_output_tokens
,
default_sampling_params
.
get
(
"max_tokens"
))
if
val
is
not
None
)
vllm/platforms/interface.py
View file @
a4113b03
...
@@ -4,6 +4,7 @@ import enum
...
@@ -4,6 +4,7 @@ import enum
import
os
import
os
import
platform
import
platform
import
random
import
random
import
sys
from
datetime
import
timedelta
from
datetime
import
timedelta
from
platform
import
uname
from
platform
import
uname
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
,
Union
...
@@ -164,6 +165,9 @@ class Platform:
...
@@ -164,6 +165,9 @@ class Platform:
def
is_out_of_tree
(
self
)
->
bool
:
def
is_out_of_tree
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
OOT
return
self
.
_enum
==
PlatformEnum
.
OOT
def
get_max_output_tokens
(
self
,
prompt_len
:
int
)
->
int
:
return
sys
.
maxsize
def
is_cuda_alike
(
self
)
->
bool
:
def
is_cuda_alike
(
self
)
->
bool
:
"""Stateless version of [torch.cuda.is_available][]."""
"""Stateless version of [torch.cuda.is_available][]."""
return
self
.
_enum
in
(
PlatformEnum
.
CUDA
,
PlatformEnum
.
ROCM
)
return
self
.
_enum
in
(
PlatformEnum
.
CUDA
,
PlatformEnum
.
ROCM
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment