"docs/features/vscode:/vscode.git/clone" did not exist on "07742cc27f043e42ed31a5825f4346d88290367c"
Unverified Commit 3e4b480b authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

fix: guided decoding params handling in vLLM (#4770)


Signed-off-by: default avatarVladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: default avatarKaren Chung <karenc@nvidia.com>
parent 197e0227
...@@ -12,7 +12,7 @@ from typing import Any, AsyncGenerator, Dict, Final ...@@ -12,7 +12,7 @@ from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.llm import ( from dynamo.llm import (
...@@ -82,8 +82,22 @@ def build_sampling_params( ...@@ -82,8 +82,22 @@ def build_sampling_params(
sampling_params = SamplingParams(**default_sampling_params) sampling_params = SamplingParams(**default_sampling_params)
sampling_params.detokenize = False sampling_params.detokenize = False
# Apply sampling_options # Handle guided_decoding - convert to StructuredOutputsParams
guided_decoding = request["sampling_options"].get("guided_decoding")
if guided_decoding is not None and isinstance(guided_decoding, dict):
sampling_params.structured_outputs = StructuredOutputsParams(
json=guided_decoding.get("json"),
regex=guided_decoding.get("regex"),
choice=guided_decoding.get("choice"),
grammar=guided_decoding.get("grammar"),
whitespace_pattern=guided_decoding.get("whitespace_pattern"),
)
# Apply remaining sampling_options
for key, value in request["sampling_options"].items(): for key, value in request["sampling_options"].items():
# Skip guided_decoding - already handled above
if key == "guided_decoding":
continue
if value is not None and hasattr(sampling_params, key): if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value) setattr(sampling_params, key, value)
......
...@@ -454,6 +454,66 @@ vllm_configs = { ...@@ -454,6 +454,66 @@ vllm_configs = {
completion_payload_default(), completion_payload_default(),
], ],
), ),
"guided_decoding_json": VLLMConfig(
name="guided_decoding_json",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload(
"Generate a person with name and age",
repeat_count=1,
expected_response=['"name"', '"age"'],
temperature=0.0,
max_tokens=100,
extra_body={
"guided_json": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
"required": ["name", "age"],
}
},
)
],
),
"guided_decoding_regex": VLLMConfig(
name="guided_decoding_regex",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload(
"Generate a color name (red, blue, or green)",
repeat_count=1,
expected_response=["red", "blue", "green"],
temperature=0.0,
max_tokens=20,
extra_body={"guided_regex": r"(red|blue|green)"},
)
],
),
"guided_decoding_choice": VLLMConfig(
name="guided_decoding_choice",
directory=vllm_dir,
script_name="agg.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen3-0.6B",
request_payloads=[
chat_payload(
"Generate a color name (red, blue, or green)",
repeat_count=1,
expected_response=["red", "blue", "green"],
temperature=0.0,
max_tokens=20,
extra_body={"guided_choice": ["red", "blue", "green"]},
)
],
),
} }
......
...@@ -134,6 +134,7 @@ def chat_payload( ...@@ -134,6 +134,7 @@ def chat_payload(
max_tokens: int = 300, max_tokens: int = 300,
temperature: Optional[float] = None, temperature: Optional[float] = None,
stream: bool = False, stream: bool = False,
extra_body: Optional[Dict[str, Any]] = None,
) -> ChatPayload: ) -> ChatPayload:
body: Dict[str, Any] = { body: Dict[str, Any] = {
"messages": [ "messages": [
...@@ -148,6 +149,9 @@ def chat_payload( ...@@ -148,6 +149,9 @@ def chat_payload(
if temperature is not None: if temperature is not None:
body["temperature"] = temperature body["temperature"] = temperature
if extra_body:
body.update(extra_body)
return ChatPayload( return ChatPayload(
body=body, body=body,
repeat_count=repeat_count, repeat_count=repeat_count,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment