Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
58405177
Unverified
Commit
58405177
authored
Nov 19, 2025
by
bin.pan
Committed by
GitHub
Nov 18, 2025
Browse files
feat: Support a dynamic default max_tokens for VLLM backend (#4156)
Signed-off-by:
bin
<
bin.pan@daocloud.io
>
parent
4f2cbec0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
8 deletions
+48
-8
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+42
-7
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+6
-1
No files found.
components/src/dynamo/vllm/handlers.py
View file @
58405177
...
@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
...
@@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
def
build_sampling_params
(
def
build_sampling_params
(
request
:
Dict
[
str
,
Any
],
default_sampling_params
:
Dict
[
str
,
Any
]
request
:
Dict
[
str
,
Any
],
default_sampling_params
:
Dict
[
str
,
Any
],
model_max_len
:
int
|
None
=
None
,
)
->
SamplingParams
:
)
->
SamplingParams
:
"""
"""
Build SamplingParams from a PreprocessedRequest.
Build SamplingParams from a PreprocessedRequest.
...
@@ -59,6 +61,15 @@ def build_sampling_params(
...
@@ -59,6 +61,15 @@ def build_sampling_params(
continue
continue
setattr
(
sampling_params
,
key
,
value
)
setattr
(
sampling_params
,
key
,
value
)
# If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens
=
request
.
get
(
"stop_conditions"
,
{}).
get
(
"max_tokens"
,
None
)
token_ids
=
request
.
get
(
"token_ids"
,
[])
input_length
=
len
(
token_ids
)
if
model_max_len
is
not
None
and
(
provided_max_tokens
is
None
):
# Ensure at least 1 token generation by default when possible
dynamic_default
=
max
(
1
,
model_max_len
-
input_length
)
sampling_params
.
max_tokens
=
dynamic_default
return
sampling_params
return
sampling_params
...
@@ -67,7 +78,14 @@ class BaseWorkerHandler(ABC):
...
@@ -67,7 +78,14 @@ class BaseWorkerHandler(ABC):
Request handler for the generate and clear_kv_blocks endpoints.
Request handler for the generate and clear_kv_blocks endpoints.
"""
"""
def
__init__
(
self
,
runtime
,
component
,
engine
,
default_sampling_params
):
def
__init__
(
self
,
runtime
,
component
,
engine
,
default_sampling_params
,
model_max_len
:
int
|
None
=
None
,
):
self
.
runtime
=
runtime
self
.
runtime
=
runtime
self
.
component
=
component
self
.
component
=
component
self
.
engine_client
=
engine
self
.
engine_client
=
engine
...
@@ -76,6 +94,7 @@ class BaseWorkerHandler(ABC):
...
@@ -76,6 +94,7 @@ class BaseWorkerHandler(ABC):
self
.
engine_monitor
=
VllmEngineMonitor
(
runtime
,
engine
)
self
.
engine_monitor
=
VllmEngineMonitor
(
runtime
,
engine
)
self
.
image_loader
=
ImageLoader
()
self
.
image_loader
=
ImageLoader
()
self
.
temp_dirs
:
list
[
tempfile
.
TemporaryDirectory
]
=
[]
self
.
temp_dirs
:
list
[
tempfile
.
TemporaryDirectory
]
=
[]
self
.
model_max_len
=
model_max_len
@
abstractmethod
@
abstractmethod
async
def
generate
(
self
,
request
,
context
)
->
AsyncGenerator
[
dict
,
None
]:
async
def
generate
(
self
,
request
,
context
)
->
AsyncGenerator
[
dict
,
None
]:
...
@@ -251,8 +270,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -251,8 +270,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
component
,
component
,
engine
,
engine
,
default_sampling_params
,
default_sampling_params
,
model_max_len
:
int
|
None
=
None
,
):
):
super
().
__init__
(
runtime
,
component
,
engine
,
default_sampling_params
)
super
().
__init__
(
runtime
,
component
,
engine
,
default_sampling_params
,
model_max_len
)
async
def
generate
(
self
,
request
,
context
):
async
def
generate
(
self
,
request
,
context
):
# Use context ID for request tracking and correlation
# Use context ID for request tracking and correlation
...
@@ -267,7 +289,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -267,7 +289,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
)
)
# Build sampling params from request
# Build sampling params from request
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
)
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
,
self
.
model_max_len
)
prefill_result
=
request
.
get
(
"prefill_result"
)
prefill_result
=
request
.
get
(
"prefill_result"
)
if
prefill_result
and
isinstance
(
prefill_result
,
dict
):
if
prefill_result
and
isinstance
(
prefill_result
,
dict
):
...
@@ -308,8 +332,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -308,8 +332,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
class
PrefillWorkerHandler
(
BaseWorkerHandler
):
class
PrefillWorkerHandler
(
BaseWorkerHandler
):
def
__init__
(
self
,
runtime
,
component
,
engine
,
default_sampling_params
):
def
__init__
(
super
().
__init__
(
runtime
,
component
,
engine
,
default_sampling_params
)
self
,
runtime
,
component
,
engine
,
default_sampling_params
,
model_max_len
:
int
|
None
=
None
,
):
super
().
__init__
(
runtime
,
component
,
engine
,
default_sampling_params
,
model_max_len
)
async
def
generate
(
self
,
request
,
context
):
async
def
generate
(
self
,
request
,
context
):
# Use context ID for request tracking and correlation with decode phase
# Use context ID for request tracking and correlation with decode phase
...
@@ -325,7 +358,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -325,7 +358,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
)
)
# Build sampling params from request using shared utility
# Build sampling params from request using shared utility
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
)
sampling_params
=
build_sampling_params
(
request
,
self
.
default_sampling_params
,
self
.
model_max_len
)
# Configure for prefill-only mode with remote decode
# Configure for prefill-only mode with remote decode
if
sampling_params
.
extra_args
is
None
:
if
sampling_params
.
extra_args
is
None
:
...
...
components/src/dynamo/vllm/main.py
View file @
58405177
...
@@ -339,7 +339,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
...
@@ -339,7 +339,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
)
=
setup_vllm_engine
(
config
)
)
=
setup_vllm_engine
(
config
)
handler
=
PrefillWorkerHandler
(
handler
=
PrefillWorkerHandler
(
runtime
,
component
,
engine_client
,
default_sampling_params
runtime
,
component
,
engine_client
,
default_sampling_params
,
getattr
(
getattr
(
vllm_config
,
"model_config"
,
None
),
"max_model_len"
,
None
),
)
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
...
@@ -450,6 +454,7 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -450,6 +454,7 @@ async def init(runtime: DistributedRuntime, config: Config):
component
,
component
,
engine_client
,
engine_client
,
default_sampling_params
,
default_sampling_params
,
getattr
(
getattr
(
vllm_config
,
"model_config"
,
None
),
"max_model_len"
,
None
),
)
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment