Unverified Commit 48b01fd4 authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[Structured Output] Make the output of structured output example more complete (#22481)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
parent 993d3d12
...@@ -15,6 +15,8 @@ from pydantic import BaseModel ...@@ -15,6 +15,8 @@ from pydantic import BaseModel
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
MAX_TOKENS = 50
# Guided decoding by Choice (list of possible options) # Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
...@@ -23,7 +25,9 @@ prompt_choice = "Classify this sentiment: vLLM is wonderful!" ...@@ -23,7 +25,9 @@ prompt_choice = "Classify this sentiment: vLLM is wonderful!"
# Guided decoding by Regex # Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams( sampling_params_regex = SamplingParams(
guided_decoding=guided_decoding_params_regex, stop=["\n"] guided_decoding=guided_decoding_params_regex,
stop=["\n"],
max_tokens=MAX_TOKENS,
) )
prompt_regex = ( prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma." "Generate an email address for Alan Turing, who works in Enigma."
...@@ -48,7 +52,10 @@ class CarDescription(BaseModel): ...@@ -48,7 +52,10 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema) guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json) sampling_params_json = SamplingParams(
guided_decoding=guided_decoding_params_json,
max_tokens=MAX_TOKENS,
)
prompt_json = ( prompt_json = (
"Generate a JSON with the brand, model and car_type of" "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's" "the most iconic car from the 90's"
...@@ -64,7 +71,10 @@ condition ::= column "= " number ...@@ -64,7 +71,10 @@ condition ::= column "= " number
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar) sampling_params_grammar = SamplingParams(
guided_decoding=guided_decoding_params_grammar,
max_tokens=MAX_TOKENS,
)
prompt_grammar = ( prompt_grammar = (
"Generate an SQL query to show the 'username' and 'email'from the 'users' table." "Generate an SQL query to show the 'username' and 'email'from the 'users' table."
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment