Commit 3e4b480b (unverified), authored by Vladislav Nosivskoy, committed by GitHub
Browse files

fix: guided decoding params handling in vLLM (#4770)


Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: Karen Chung <karenc@nvidia.com>
parent 197e0227
......@@ -12,7 +12,7 @@ from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.llm import (
......@@ -82,8 +82,22 @@ def build_sampling_params(
sampling_params = SamplingParams(**default_sampling_params)
sampling_params.detokenize = False
# Apply sampling_options
# Handle guided_decoding - convert to StructuredOutputsParams
guided_decoding = request["sampling_options"].get("guided_decoding")
if guided_decoding is not None and isinstance(guided_decoding, dict):
sampling_params.structured_outputs = StructuredOutputsParams(
json=guided_decoding.get("json"),
regex=guided_decoding.get("regex"),
choice=guided_decoding.get("choice"),
grammar=guided_decoding.get("grammar"),
whitespace_pattern=guided_decoding.get("whitespace_pattern"),
)
# Apply remaining sampling_options
for key, value in request["sampling_options"].items():
# Skip guided_decoding - already handled above
if key == "guided_decoding":
continue
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
......
......@@ -454,6 +454,66 @@ vllm_configs = {
completion_payload_default(),
],
),
"guided_decoding_json": VLLMConfig(
name="guided_decoding_json",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload(
"Generate a person with name and age",
repeat_count=1,
expected_response=['"name"', '"age"'],
temperature=0.0,
max_tokens=100,
extra_body={
"guided_json": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
"required": ["name", "age"],
}
},
)
],
),
"guided_decoding_regex": VLLMConfig(
name="guided_decoding_regex",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload(
"Generate a color name (red, blue, or green)",
repeat_count=1,
expected_response=["red", "blue", "green"],
temperature=0.0,
max_tokens=20,
extra_body={"guided_regex": r"(red|blue|green)"},
)
],
),
"guided_decoding_choice": VLLMConfig(
name="guided_decoding_choice",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload(
"Generate a color name (red, blue, or green)",
repeat_count=1,
expected_response=["red", "blue", "green"],
temperature=0.0,
max_tokens=20,
extra_body={"guided_choice": ["red", "blue", "green"]},
)
],
),
}
......
......@@ -134,6 +134,7 @@ def chat_payload(
max_tokens: int = 300,
temperature: Optional[float] = None,
stream: bool = False,
extra_body: Optional[Dict[str, Any]] = None,
) -> ChatPayload:
body: Dict[str, Any] = {
"messages": [
......@@ -148,6 +149,9 @@ def chat_payload(
if temperature is not None:
body["temperature"] = temperature
if extra_body:
body.update(extra_body)
return ChatPayload(
body=body,
repeat_count=repeat_count,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment