Unverified commit c84b519c, authored by Isotr0py, committed by GitHub
Browse files

[Bugfix] Fix negative max_tokens when input prompt is too long (#36789)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 741ecf06
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.entrypoints.utils import get_max_tokens, sanitize_message
...@@ -80,3 +82,15 @@ class TestGetMaxTokens: ...@@ -80,3 +82,15 @@ class TestGetMaxTokens:
default_sampling_params={"max_tokens": 2048}, default_sampling_params={"max_tokens": 2048},
) )
assert result == 512 assert result == 512
def test_input_length_exceeds_max_model_len(self):
    """A prompt longer than the model context must raise ``ValueError``."""
    # Regex that ``pytest.raises`` searches for in the raised error's text.
    expected_error = "Input length .* exceeds model's maximum context length .*"
    # input_length (150) is deliberately larger than max_model_len (100).
    oversized_request = {
        "max_model_len": 100,
        "max_tokens": 50,
        "input_length": 150,
        "default_sampling_params": {"max_tokens": 2048},
    }
    with pytest.raises(ValueError, match=expected_error):
        get_max_tokens(**oversized_request)
...@@ -178,6 +178,11 @@ def get_max_tokens( ...@@ -178,6 +178,11 @@ def get_max_tokens(
default_sampling_params: dict, default_sampling_params: dict,
override_max_tokens: int | None = None, override_max_tokens: int | None = None,
) -> int: ) -> int:
if max_model_len < input_length:
raise ValueError(
f"Input length ({input_length}) exceeds model's maximum "
f"context length ({max_model_len})."
)
model_max_tokens = max_model_len - input_length model_max_tokens = max_model_len - input_length
platform_max_tokens = current_platform.get_max_output_tokens(input_length) platform_max_tokens = current_platform.get_max_output_tokens(input_length)
fallback_max_tokens = ( fallback_max_tokens = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment