Unverified commit 2f666b73, authored by William Zhang, committed by GitHub
Browse files

fix: Properly forward sampling params from the request (#5797)

parent 408d7868
......@@ -14,7 +14,7 @@
# limitations under the License.
import asyncio
import copy
import dataclasses
import logging
import os
from contextlib import asynccontextmanager
......@@ -615,13 +615,9 @@ class HandlerBase:
num_output_tokens_so_far = 0
sampling_params = copy.deepcopy(self.default_sampling_params)
for key, value in request["sampling_options"].items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
sampling_params = self._override_sampling_params(
self.default_sampling_params, request
)
# Additional sampling params in output options
output_options = request.get("output_options", {})
......@@ -818,3 +814,16 @@ class HandlerBase:
# Initiate graceful shutdown
await self._initiate_shutdown(e)
@staticmethod
def _override_sampling_params(sampling_params, request: dict) -> SamplingParams:
overrides = {
key: value
for key, value in request["sampling_options"].items()
if value is not None
}
# NOTE: using `dataclasses.replace` has several benefits over a `setattr` based approach:
# 1. it catches unsupported fields / attributes.
# 2. it executes the class's `__post_init__`, which may contain helpful validation logic.
return dataclasses.replace(sampling_params, **overrides)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from unittest import mock
import pytest
from dynamo.trtllm.request_handlers.handler_base import HandlerBase
# Module-level pytest marks: pytest applies every mark in this list to all tests in this file.
pytestmark = [
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.pre_merge,
]
@dataclass
class MockSamplingParams:
"""Mock sampling params object for testing."""
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
repetition_penalty: float = 1.0
seed: int | None = None
ignore_eos: bool = False
def __post_init__(self):
"""Called after dataclass initialization (including via replace())."""
pass
class TestOverrideSamplingParams:
    """Tests for `HandlerBase._override_sampling_params`.

    The key bug fix being tested: request values are skipped only when they are
    `None` (`if value is not None`) rather than whenever they are falsy
    (`if not value: continue`), so values like 0 and False are correctly applied.
    """

    def test_falsy_values_are_applied(self):
        """Test that falsy values (0, False) are correctly set.

        This is the main regression test for the bug fix. Previously, using
        `if not value` would skip setting values like 0 or False.
        """
        sampling_params = MockSamplingParams()
        request = {
            "sampling_options": {
                "temperature": 0,  # Falsy but valid - should be set
                "top_k": 0,  # Falsy but valid - should be set
                "ignore_eos": False,  # Falsy but valid - should be set
            }
        }

        result = HandlerBase._override_sampling_params(sampling_params, request)

        assert result.temperature == 0
        assert result.top_k == 0
        assert result.ignore_eos is False

    def test_none_values_are_skipped(self):
        """Test that None values do not override existing params."""
        sampling_params = MockSamplingParams()
        original_temperature = sampling_params.temperature
        original_top_p = sampling_params.top_p
        request = {
            "sampling_options": {
                "temperature": None,
                "top_p": None,
            }
        }

        result = HandlerBase._override_sampling_params(sampling_params, request)

        assert result.temperature == original_temperature
        assert result.top_p == original_top_p

    def test_truthy_values_are_applied(self):
        """Test that normal truthy values are correctly set."""
        sampling_params = MockSamplingParams()
        request = {
            "sampling_options": {
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 40,
                "seed": 42,
            }
        }

        result = HandlerBase._override_sampling_params(sampling_params, request)

        assert result.temperature == 0.7
        assert result.top_p == 0.9
        assert result.top_k == 40
        assert result.seed == 42

    def test_unknown_attributes_raise_error(self):
        """Test that unknown attributes raise a TypeError.

        dataclasses.replace() does not accept unknown field names; the message
        check ensures the failure comes from the unsupported keyword itself.
        """
        sampling_params = MockSamplingParams()
        request = {
            "sampling_options": {
                "nonexistent_param": 123,
            }
        }

        with pytest.raises(TypeError, match="unexpected keyword argument"):
            HandlerBase._override_sampling_params(sampling_params, request)

    def test_mixed_values(self):
        """Test a mix of None, falsy, and truthy values."""
        sampling_params = MockSamplingParams()
        original_top_p = sampling_params.top_p
        request = {
            "sampling_options": {
                "temperature": 0,  # Falsy - should be set
                "top_p": None,  # None - should be skipped
                "top_k": 100,  # Truthy - should be set
                "seed": 0,  # Falsy - should be set
            }
        }

        result = HandlerBase._override_sampling_params(sampling_params, request)

        assert result.temperature == 0
        assert result.top_p == original_top_p  # Unchanged
        assert result.top_k == 100
        assert result.seed == 0

    def test_original_params_are_not_mutated(self):
        """Test that the params passed in are left untouched and a new object is returned."""
        sampling_params = MockSamplingParams()
        request = {"sampling_options": {"temperature": 0.5, "top_k": 0}}

        result = HandlerBase._override_sampling_params(sampling_params, request)

        assert result is not sampling_params
        assert sampling_params == MockSamplingParams()

    def test_post_init_called_when_overriding(self):
        """Test that `__post_init__` runs when request overrides are applied."""
        # This allows us to check that potential validation logic in `__post_init__` is run when
        # overriding the sampling params with what comes from the requests.
        sampling_params = MockSamplingParams()
        request = {"sampling_options": {"temperature": 0.5}}

        with mock.patch.object(MockSamplingParams, "__post_init__") as mock_post_init:
            HandlerBase._override_sampling_params(sampling_params, request)

        mock_post_init.assert_called_once()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment