Unverified Commit 98d01d3c authored by Guillaume Calmettes's avatar Guillaume Calmettes Committed by GitHub
Browse files

[Bugfix][Frontend] respect provided default guided decoding backend (#15476)


Signed-off-by: default avatarGuillaume Calmettes <gcalmettes@scaleway.com>
parent d55244df
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for the SamplingParams class. """Tests for the SamplingParams class.
""" """
import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
MODEL_NAME = "Qwen/Qwen1.5-7B"
def test_max_tokens_none(): def test_max_tokens_none():
...@@ -9,6 +16,74 @@ def test_max_tokens_none(): ...@@ -9,6 +16,74 @@ def test_max_tokens_none():
SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
if __name__ == "__main__": @pytest.fixture(scope="module")
import pytest def model_config():
pytest.main([__file__]) return ModelConfig(
MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)
@pytest.fixture(scope="module")
def default_max_tokens():
return 4096
def test_sampling_params_from_request_with_no_guided_decoding_backend(
model_config, default_max_tokens):
# guided_decoding_backend is not present at request level
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# we do not expect any backend to be present and the default
# guided_decoding_backend at engine level will be used.
assert sampling_params.guided_decoding.backend is None
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
[("xgrammar", "xgrammar"),
("lm-format-enforcer", "lm-format-enforcer"),
("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend(
request_level_guided_decoding_backend: str, expected: str,
model_config, default_max_tokens):
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
MODEL_NAME,
'response_format': {
'type': 'json_object',
},
'guided_decoding_backend':
request_level_guided_decoding_backend,
})
sampling_params = request.to_sampling_params(
default_max_tokens,
model_config.logits_processor_pattern,
)
# backend correctly identified in resulting sampling_params
assert sampling_params.guided_decoding.backend == expected
...@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -476,8 +476,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema = self.response_format.json_schema json_schema = self.response_format.json_schema
assert json_schema is not None assert json_schema is not None
self.guided_json = json_schema.json_schema self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "xgrammar"
guided_decoding = GuidedDecodingParams.from_optional( guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json, json=self._get_guided_json_from_tool() or self.guided_json,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment