Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4113b03
Unverified
Commit
a4113b03
authored
Jul 03, 2025
by
Gabriel Marinho
Committed by
GitHub
Jul 04, 2025
Browse files
[Platform] Add custom default max tokens (#18557)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
parent
7e1665b0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing 5 changed files with 59 additions and 60 deletions
+59
-60
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+10
-49
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+13
-5
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+12
-4
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+20
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+4
-0
No files found.
vllm/entrypoints/openai/protocol.py
View file @
a4113b03
...
...
@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias
:
Optional
[
dict
[
str
,
float
]]
=
None
logprobs
:
Optional
[
bool
]
=
False
top_logprobs
:
Optional
[
int
]
=
0
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
:
Optional
[
int
]
=
Field
(
default
=
None
,
deprecated
=
...
...
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
}
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
BeamSearchParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
self
,
max_tokens
:
int
,
default_sampling_params
:
dict
)
->
BeamSearchParams
:
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
...
...
@@ -465,21 +451,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
def
to_sampling_params
(
self
,
default_
max_tokens
:
int
,
max_tokens
:
int
,
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
,
default_sampling_params
:
dict
,
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
self
.
max_completion_tokens
or
self
.
max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
...
...
@@ -899,21 +874,14 @@ class CompletionRequest(OpenAIBaseModel):
def
to_beam_search_params
(
self
,
default_
max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
,
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
n
=
self
.
n
if
self
.
n
is
not
None
else
1
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
1.0
)
...
...
@@ -928,21 +896,14 @@ class CompletionRequest(OpenAIBaseModel):
def
to_sampling_params
(
self
,
default_
max_tokens
:
int
,
max_tokens
:
int
,
logits_processor_pattern
:
Optional
[
str
],
default_sampling_params
:
Optional
[
dict
]
=
None
,
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Use minimum of context window, user request & server limit.
max_tokens
=
min
(
val
for
val
in
(
default_max_tokens
,
max_tokens
,
default_sampling_params
.
get
(
"max_tokens"
,
None
))
if
val
is
not
None
)
# Default parameters
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
...
...
@@ -1813,7 +1774,7 @@ class TranscriptionRequest(OpenAIBaseModel):
self
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
...
...
@@ -2029,7 +1990,7 @@ class TranslationRequest(OpenAIBaseModel):
self
,
default_max_tokens
:
int
,
default_sampling_params
:
Optional
[
dict
]
=
None
)
->
SamplingParams
:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens
=
default_max_tokens
if
default_sampling_params
is
None
:
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
a4113b03
...
...
@@ -34,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
MistralToolCall
)
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
...
...
@@ -233,15 +234,22 @@ class OpenAIServingChat(OpenAIServing):
try
:
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
engine_prompt
[
"prompt_token_ids"
])
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
request
=
request
,
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
]),
default_sampling_params
=
self
.
default_sampling_params
)
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
default_
max_tokens
,
self
.
default_sampling_params
)
max_tokens
,
self
.
default_sampling_params
)
else
:
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
self
.
model_config
.
logits_processor_pattern
,
max_tokens
,
self
.
model_config
.
logits_processor_pattern
,
self
.
default_sampling_params
)
self
.
_log_inputs
(
request_id
,
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
a4113b03
...
...
@@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
is_text_tokens_prompt
)
# yapf: enable
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.utils
import
get_max_tokens
from
vllm.inputs.data
import
(
EmbedsPrompt
,
TokensPrompt
,
is_embeds_prompt
,
is_tokens_prompt
)
from
vllm.logger
import
init_logger
...
...
@@ -160,15 +161,22 @@ class OpenAIServingCompletion(OpenAIServing):
input_length
=
len
(
engine_prompt
[
"prompt_token_ids"
])
else
:
assert_never
(
engine_prompt
)
default_max_tokens
=
self
.
max_model_len
-
input_length
if
self
.
default_sampling_params
is
None
:
self
.
default_sampling_params
=
{}
max_tokens
=
get_max_tokens
(
max_model_len
=
self
.
max_model_len
,
request
=
request
,
input_length
=
input_length
,
default_sampling_params
=
self
.
default_sampling_params
)
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
default_
max_tokens
,
self
.
default_sampling_params
)
max_tokens
,
self
.
default_sampling_params
)
else
:
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
,
self
.
model_config
.
logits_processor_pattern
,
max_tokens
,
self
.
model_config
.
logits_processor_pattern
,
self
.
default_sampling_params
)
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/utils.py
View file @
a4113b03
...
...
@@ -5,13 +5,17 @@ import argparse
import
asyncio
import
functools
import
os
from
typing
import
Any
,
Optional
import
sys
from
typing
import
Any
,
Optional
,
Union
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
starlette.background
import
BackgroundTask
,
BackgroundTasks
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
@@ -181,7 +185,6 @@ def _validate_truncation_size(
def
show_filtered_argument_or_group_from_help
(
parser
:
argparse
.
ArgumentParser
,
subcommand_name
:
list
[
str
]):
import
sys
# Only handle --help=<keyword> for the current subcommand.
# Since subparser_init() runs for all subcommands during CLI setup,
...
...
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
print
(
f
"
\n
No group or parameter matching '
{
search_keyword
}
'"
)
print
(
"Tip: use `--help=listgroup` to view all groups."
)
sys
.
exit
(
1
)
def get_max_tokens(max_model_len: int,
                   request: Union[ChatCompletionRequest, CompletionRequest],
                   input_length: int,
                   default_sampling_params: Optional[dict] = None) -> int:
    """Resolve the output-token budget to use for ``request``.

    The result is the smallest of every limit that is actually set:

    * the room left in the context window
      (``max_model_len - input_length``),
    * the limit asked for by the request,
    * the limit reported by ``current_platform.get_max_output_tokens``,
    * the ``"max_tokens"`` entry of ``default_sampling_params``.

    Args:
        max_model_len: Context length of the served model, in tokens.
        request: The incoming request. ``max_completion_tokens`` is read
            when the request defines it and it is truthy; otherwise
            ``max_tokens`` is used.
        input_length: Number of prompt tokens already consumed.
        default_sampling_params: Server-side sampling defaults. ``None``
            is treated like an empty dict, i.e. no server-side limit.

    Returns:
        The minimum of all limits that are not ``None``.
    """
    # Tolerate a missing defaults dict instead of failing on ``None.get``.
    server_defaults = default_sampling_params or {}

    # ``getattr`` is used because the request may not define
    # ``max_completion_tokens``. A falsy value (``None`` or ``0``) falls
    # through to ``max_tokens``.
    requested = (getattr(request, "max_completion_tokens", None)
                 or request.max_tokens)

    remaining_context = max_model_len - input_length
    platform_cap = current_platform.get_max_output_tokens(input_length)

    candidates = (
        remaining_context,
        requested,
        platform_cap,
        server_defaults.get("max_tokens"),
    )
    # Unset limits are ``None`` and must not take part in the comparison.
    return min(limit for limit in candidates if limit is not None)
vllm/platforms/interface.py
View file @
a4113b03
...
...
@@ -4,6 +4,7 @@ import enum
import
os
import
platform
import
random
import
sys
from
datetime
import
timedelta
from
platform
import
uname
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
,
Union
...
...
@@ -164,6 +165,9 @@ class Platform:
def is_out_of_tree(self) -> bool:
    """Return True when this platform's ``_enum`` equals ``PlatformEnum.OOT``."""
    matches_oot = self._enum == PlatformEnum.OOT
    return matches_oot
def get_max_output_tokens(self, prompt_len: int) -> int:
    """Return the platform's cap on generated tokens for one request.

    This implementation ignores ``prompt_len`` and returns
    ``sys.maxsize``, so it never becomes the smallest limit in practice.
    NOTE(review): presumably meant to be overridden by platform
    subclasses that need a real cap -- confirm.
    """
    no_platform_limit = sys.maxsize
    return no_platform_limit
def is_cuda_alike(self) -> bool:
    """Stateless version of [torch.cuda.is_available][]."""
    # True when ``_enum`` is either the CUDA or the ROCM member.
    cuda_like_members = (PlatformEnum.CUDA, PlatformEnum.ROCM)
    return self._enum in cuda_like_members
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment