Unverified Commit 72d5951d authored by Almog Tavor's avatar Almog Tavor Committed by GitHub
Browse files

[Bugfix] Treat generation_config max_tokens as default not ceiling (#34063)


Signed-off-by: default avataralmogtavor <almogtavor@gmail.com>
parent a3205bef
......@@ -526,6 +526,7 @@ class MockModelConfig:
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
override_generation_config: dict[str, Any] = field(default_factory=dict)
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
......@@ -651,12 +652,10 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Setting server's max_tokens in the generation_config.json
# lower than context_window - prompt_tokens
# Model author's generation_config.json sets max_tokens (auto, no override)
# — should act as fallback only, not ceiling
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 10 # Setting server-side max_tokens limit
}
mock_model_config.diff_sampling_param = {"max_tokens": 10}
# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM)
......@@ -680,13 +679,14 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Test Case 2: Request's max_tokens set higher than server accepts
# Test Case 2: Request's max_tokens set higher than generation_config
# default so request-provided max_tokens takes precedence
req.max_tokens = 15
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
assert mock_engine.generate.call_args.args[1].max_tokens == 15
# Test Case 3: Request's max_tokens set lower than server accepts
req.max_tokens = 5
......@@ -696,12 +696,52 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 5
# User explicitly sets max_tokens via --override-generation-config
# — should act as a ceiling
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {"max_tokens": 10}
mock_model_config.override_generation_config = {"max_new_tokens": 10}
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
# Test Case 3.1: No max_tokens — uses override as default
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 1+1?"}],
)
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Test Case 3.2: Request max_tokens higher — capped by user ceiling from override
req.max_tokens = 15
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Test Case 3.3: Request max_tokens lower — respected
req.max_tokens = 5
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
# Setting server's max_tokens in the generation_config.json
# higher than context_window - prompt_tokens
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 200 # Setting server-side max_tokens limit
}
mock_model_config.diff_sampling_param = {"max_tokens": 200}
# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.utils import sanitize_message
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
def test_sanitize_message():
......@@ -8,3 +9,74 @@ def test_sanitize_message():
sanitize_message("<_io.BytesIO object at 0x7a95e299e750>")
== "<_io.BytesIO object>"
)
class TestGetMaxTokens:
"""Tests for get_max_tokens() to ensure generation_config's max_tokens
acts as a default when from model author, and as a ceiling when
explicitly set by the user."""
def test_default_sampling_params_used_when_no_request_max_tokens(self):
"""When user doesn't specify max_tokens, generation_config default
should apply."""
result = get_max_tokens(
max_model_len=24000,
max_tokens=None,
input_length=100,
default_sampling_params={"max_tokens": 2048},
)
assert result == 2048
def test_request_max_tokens_not_capped_by_default_sampling_params(self):
"""When user specifies max_tokens in request, model author's
generation_config max_tokens must NOT cap it (fixes #34005)."""
result = get_max_tokens(
max_model_len=24000,
max_tokens=5000,
input_length=100,
default_sampling_params={"max_tokens": 2048},
)
assert result == 5000
def test_override_max_tokens_caps_request(self):
"""When user explicitly sets max_tokens, it acts as a ceiling."""
result = get_max_tokens(
max_model_len=24000,
max_tokens=5000,
input_length=100,
default_sampling_params={"max_tokens": 2048},
override_max_tokens=2048,
)
assert result == 2048
def test_override_max_tokens_used_as_default(self):
"""When no request max_tokens, override still applies as default."""
result = get_max_tokens(
max_model_len=24000,
max_tokens=None,
input_length=100,
default_sampling_params={"max_tokens": 2048},
override_max_tokens=2048,
)
assert result == 2048
def test_max_model_len_still_caps_output(self):
"""max_model_len - input_length is always the hard ceiling."""
result = get_max_tokens(
max_model_len=3000,
max_tokens=5000,
input_length=100,
default_sampling_params={"max_tokens": 2048},
)
assert result == 2900 # 3000 - 100
def test_request_max_tokens_smaller_than_default(self):
"""When user explicitly requests fewer tokens than gen_config default,
that should be respected."""
result = get_max_tokens(
max_model_len=24000,
max_tokens=512,
input_length=100,
default_sampling_params={"max_tokens": 2048},
)
assert result == 512
......@@ -145,6 +145,12 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param()
mc = self.model_config
self.override_max_tokens = (
self.default_sampling_params.get("max_tokens")
if mc.generation_config not in ("auto", "vllm")
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
)
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params:
......@@ -389,6 +395,7 @@ class OpenAIServingChat(OpenAIServing):
else request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
......
......@@ -70,6 +70,12 @@ class OpenAIServingCompletion(OpenAIServing):
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param()
mc = self.model_config
self.override_max_tokens = (
self.default_sampling_params.get("max_tokens")
if mc.generation_config not in ("auto", "vllm")
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
)
async def render_completion_request(
self,
......@@ -164,6 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
......
......@@ -1174,6 +1174,7 @@ class OpenAIServing:
context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore
)
# OPTIMIZATION
......
......@@ -229,6 +229,12 @@ class OpenAIServingResponses(OpenAIServing):
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param()
mc = self.model_config
self.override_max_tokens = (
self.default_sampling_params.get("max_tokens")
if mc.generation_config not in ("auto", "vllm")
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
)
# If False (default), the "store" option is (silently) ignored and the
# response is not stored. If True, the response is stored in memory.
......@@ -446,6 +452,7 @@ class OpenAIServingResponses(OpenAIServing):
request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params = request.to_sampling_params(
......
......@@ -177,17 +177,23 @@ def get_max_tokens(
max_tokens: int | None,
input_length: int,
default_sampling_params: dict,
override_max_tokens: int | None = None,
) -> int:
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
model_max_tokens = max_model_len - input_length
platform_max_tokens = current_platform.get_max_output_tokens(input_length)
fallback_max_tokens = (
max_tokens
if max_tokens is not None
else default_sampling_params.get("max_tokens")
)
return min(
val
for val in (
default_max_tokens,
max_tokens,
max_output_tokens,
default_sampling_params.get("max_tokens"),
model_max_tokens,
fallback_max_tokens,
override_max_tokens,
platform_max_tokens,
)
if val is not None
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment