Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2f666b73
Unverified
Commit
2f666b73
authored
Jan 30, 2026
by
William Zhang
Committed by
GitHub
Jan 30, 2026
Browse files
fix: Properly forward sampling params from the request (#5797)
parent
408d7868
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
168 additions
and
8 deletions
+168
-8
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+17
-8
components/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
...nents/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
+151
-0
No files found.
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
2f666b73
...
...
@@ -14,7 +14,7 @@
# limitations under the License.
import
asyncio
import
copy
import
dataclasses
import
logging
import
os
from
contextlib
import
asynccontextmanager
...
...
@@ -615,13 +615,9 @@ class HandlerBase:
num_output_tokens_so_far
=
0
sampling_params
=
copy
.
deepcopy
(
self
.
default_sampling_params
)
for
key
,
value
in
request
[
"sampling_options"
].
items
():
if
not
value
:
continue
if
hasattr
(
sampling_params
,
key
):
setattr
(
sampling_params
,
key
,
value
)
sampling_params
=
self
.
_override_sampling_params
(
self
.
default_sampling_params
,
request
)
# Additional sampling params in output options
output_options
=
request
.
get
(
"output_options"
,
{})
...
...
@@ -818,3 +814,16 @@ class HandlerBase:
# Initiate graceful shutdown
await
self
.
_initiate_shutdown
(
e
)
@
staticmethod
def
_override_sampling_params
(
sampling_params
,
request
:
dict
)
->
SamplingParams
:
overrides
=
{
key
:
value
for
key
,
value
in
request
[
"sampling_options"
].
items
()
if
value
is
not
None
}
# NOTE: using `dataclasses.replace` has several benefits over a `setattr` based approach:
# 1. it catches unsupported fields / attributes.
# 2. it executes the class's `__post_init__`, which may contain helpful validation logic.
return
dataclasses
.
replace
(
sampling_params
,
**
overrides
)
components/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
0 → 100644
View file @
2f666b73
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
unittest
import
mock
import
pytest
from
dynamo.trtllm.request_handlers.handler_base
import
HandlerBase
pytestmark
=
[
pytest
.
mark
.
unit
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
pre_merge
,
]
@
dataclass
class
MockSamplingParams
:
"""Mock sampling params object for testing."""
temperature
:
float
=
1.0
top_p
:
float
=
1.0
top_k
:
int
=
50
repetition_penalty
:
float
=
1.0
seed
:
int
|
None
=
None
ignore_eos
:
bool
=
False
def
__post_init__
(
self
):
"""Called after dataclass initialization (including via replace())."""
pass
class
TestOverrideSamplingParams
:
"""Tests for _override_sampling_params method.
The key bug fix being tested: using `if value is None` instead of `if not value`
ensures that falsy values like 0, False, and "" are correctly applied.
"""
def
test_falsy_values_are_applied
(
self
):
"""Test that falsy values (0, False) are correctly set.
This is the main regression test for the bug fix. Previously, using
`if not value` would skip setting values like 0 or False.
"""
sampling_params
=
MockSamplingParams
()
request
=
{
"sampling_options"
:
{
"temperature"
:
0
,
# Falsy but valid - should be set
"top_k"
:
0
,
# Falsy but valid - should be set
"ignore_eos"
:
False
,
# Falsy but valid - should be set
}
}
result
=
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
assert
result
.
temperature
==
0
assert
result
.
top_k
==
0
assert
result
.
ignore_eos
is
False
def
test_none_values_are_skipped
(
self
):
"""Test that None values do not override existing params."""
sampling_params
=
MockSamplingParams
()
original_temperature
=
sampling_params
.
temperature
original_top_p
=
sampling_params
.
top_p
request
=
{
"sampling_options"
:
{
"temperature"
:
None
,
"top_p"
:
None
,
}
}
result
=
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
assert
result
.
temperature
==
original_temperature
assert
result
.
top_p
==
original_top_p
def
test_truthy_values_are_applied
(
self
):
"""Test that normal truthy values are correctly set."""
sampling_params
=
MockSamplingParams
()
request
=
{
"sampling_options"
:
{
"temperature"
:
0.7
,
"top_p"
:
0.9
,
"top_k"
:
40
,
"seed"
:
42
,
}
}
result
=
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
assert
result
.
temperature
==
0.7
assert
result
.
top_p
==
0.9
assert
result
.
top_k
==
40
assert
result
.
seed
==
42
def
test_unknown_attributes_raise_error
(
self
):
"""Test that unknown attributes raise a TypeError.
dataclasses.replace() does not accept unknown field names.
"""
sampling_params
=
MockSamplingParams
()
request
=
{
"sampling_options"
:
{
"nonexistent_param"
:
123
,
}
}
with
pytest
.
raises
(
TypeError
):
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
def
test_mixed_values
(
self
):
"""Test a mix of None, falsy, and truthy values."""
sampling_params
=
MockSamplingParams
()
original_top_p
=
sampling_params
.
top_p
request
=
{
"sampling_options"
:
{
"temperature"
:
0
,
# Falsy - should be set
"top_p"
:
None
,
# None - should be skipped
"top_k"
:
100
,
# Truthy - should be set
"seed"
:
0
,
# Falsy - should be set
}
}
result
=
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
assert
result
.
temperature
==
0
assert
result
.
top_p
==
original_top_p
# Unchanged
assert
result
.
top_k
==
100
assert
result
.
seed
==
0
def
test_unsupported_fields_raise
(
self
):
sampling_params
=
MockSamplingParams
()
request
=
{
"sampling_options"
:
{
"non_existent_param"
:
123
}}
with
pytest
.
raises
(
TypeError
,
match
=
"unexpected keyword argument"
):
_
=
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
def
test_post_init_called_when_overriding
(
self
):
# This allows us to check that potential validation logic in `__post_init__` is run when
# overriding the sampling params with what comes from the requests.
sampling_params
=
MockSamplingParams
()
request
=
{
"sampling_options"
:
{
"temperature"
:
0.5
}}
with
mock
.
patch
.
object
(
MockSamplingParams
,
"__post_init__"
)
as
mock_post_init
:
HandlerBase
.
_override_sampling_params
(
sampling_params
,
request
)
mock_post_init
.
assert_called_once
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment