Unverified Commit 1bff42c4 authored by Reid's avatar Reid Committed by GitHub
Browse files

[Misc] refactor Structured Outputs example (#16322)


Signed-off-by: default avatarreidliu41 <reid201711@gmail.com>
Co-authored-by: default avatarreidliu41 <reid201711@gmail.com>
parent cb391d85
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of guided decoding
to generate structured outputs using vLLM. It shows how to apply
different guided decoding techniques such as Choice, Regex, JSON schema,
and Grammar to produce structured and formatted results
based on specific prompts.
"""
from enum import Enum from enum import Enum
...@@ -7,26 +14,21 @@ from pydantic import BaseModel ...@@ -7,26 +14,21 @@ from pydantic import BaseModel
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
# Guided decoding by Choice (list of possible options) # Guided decoding by Choice (list of possible options)
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) guided_decoding_params_choice = GuidedDecodingParams(
sampling_params = SamplingParams(guided_decoding=guided_decoding_params) choice=["Positive", "Negative"])
outputs = llm.generate( sampling_params_choice = SamplingParams(
prompts="Classify this sentiment: vLLM is wonderful!", guided_decoding=guided_decoding_params_choice)
sampling_params=sampling_params, prompt_choice = "Classify this sentiment: vLLM is wonderful!"
)
print(outputs[0].outputs[0].text)
# Guided decoding by Regex # Guided decoding by Regex
guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params = SamplingParams(guided_decoding=guided_decoding_params, sampling_params_regex = SamplingParams(
stop=["\n"]) guided_decoding=guided_decoding_params_regex, stop=["\n"])
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt_regex = (
"End in .com and new line. Example result:" "Generate an email address for Alan Turing, who works in Enigma."
"alan.turing@enigma.com\n") "End in .com and new line. Example result:"
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) "alan.turing@enigma.com\n")
print(outputs[0].outputs[0].text)
# Guided decoding by JSON using Pydantic schema # Guided decoding by JSON using Pydantic schema
...@@ -44,16 +46,11 @@ class CarDescription(BaseModel): ...@@ -44,16 +46,11 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
guided_decoding_params = GuidedDecodingParams(json=json_schema) sampling_params_json = SamplingParams(
sampling_params = SamplingParams(guided_decoding=guided_decoding_params) guided_decoding=guided_decoding_params_json)
prompt = ("Generate a JSON with the brand, model and car_type of" prompt_json = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's") "the most iconic car from the 90's")
outputs = llm.generate(
prompts=prompt,
sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
# Guided decoding by Grammar # Guided decoding by Grammar
simplified_sql_grammar = """ simplified_sql_grammar = """
...@@ -64,12 +61,39 @@ table ::= "table_1 " | "table_2 " ...@@ -64,12 +61,39 @@ table ::= "table_1 " | "table_2 "
condition ::= column "= " number condition ::= column "= " number
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar) guided_decoding_params_grammar = GuidedDecodingParams(
sampling_params = SamplingParams(guided_decoding=guided_decoding_params) grammar=simplified_sql_grammar)
prompt = ("Generate an SQL query to show the 'username' and 'email'" sampling_params_grammar = SamplingParams(
"from the 'users' table.") guided_decoding=guided_decoding_params_grammar)
outputs = llm.generate( prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'"
prompts=prompt, "from the 'users' table.")
sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text) def format_output(title: str, output: str):
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
return outputs[0].outputs[0].text
def main():
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
format_output("Guided decoding by Choice", choice_output)
regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
format_output("Guided decoding by Regex", regex_output)
json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output)
grammar_output = generate_output(prompt_grammar, sampling_params_grammar,
llm)
format_output("Guided decoding by Grammar", grammar_output)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment