Unverified Commit 452db508 authored by mlmz's avatar mlmz Committed by GitHub
Browse files

Constraint Decoding: Set xgrammar as the default grammar backend (#4386)

parent d1112d85
...@@ -15,15 +15,15 @@ ...@@ -15,15 +15,15 @@
"\n", "\n",
"SGLang supports two grammar backends:\n", "SGLang supports two grammar backends:\n",
"\n", "\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", "- [Outlines](https://github.com/dottxt-ai/outlines): Supports JSON schema and regular expression constraints.\n",
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n", "- [XGrammar](https://github.com/mlc-ai/xgrammar) (default): Supports JSON schema, regular expression, and EBNF constraints.\n",
"- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\n", "- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\n",
"\n", "\n",
"We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
"\n", "\n",
"To use Xgrammar, simply add `--grammar-backend xgrammar` when launching the server.\n", "To use Outlines, simply add `--grammar-backend outlines` when launching the server.\n",
"To use llguidance, add `--grammar-backend llguidance` when launching the server.\n", "To use llguidance, add `--grammar-backend llguidance` when launching the server.\n",
"If no backend is specified, Outlines will be used as the default.\n", "If no backend is specified, XGrammar will be used as the default.\n",
"\n", "\n",
"For better output quality, **it's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n" "For better output quality, **it's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n"
] ]
...@@ -56,7 +56,7 @@ ...@@ -56,7 +56,7 @@
"\n", "\n",
"\n", "\n",
"server_process, port = launch_server_cmd(\n", "server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --grammar-backend xgrammar\"\n", " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(f\"http://localhost:{port}\")\n", "wait_for_server(f\"http://localhost:{port}\")\n",
...@@ -229,6 +229,131 @@ ...@@ -229,6 +229,131 @@
"print_highlight(response.choices[0].message.content)" "print_highlight(response.choices[0].message.content)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tool_get_current_weather = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"state\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The two-letter abbreviation for the state that the city is\"\n",
" \" in, e.g. 'CA' which would mean 'California'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"state\", \"unit\"],\n",
" },\n",
" },\n",
"}\n",
"\n",
"tool_get_current_date = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_date\",\n",
" \"description\": \"Get the current date and time for a given timezone\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"timezone\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The timezone to fetch the current date and time for, e.g. 'America/New_York'\",\n",
" }\n",
" },\n",
" \"required\": [\"timezone\"],\n",
" },\n",
" },\n",
"}\n",
"\n",
"schema_get_current_weather = tool_get_current_weather[\"function\"][\"parameters\"]\n",
"schema_get_current_date = tool_get_current_date[\"function\"][\"parameters\"]\n",
"\n",
"\n",
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": f\"\"\"\n",
"# Tool Instructions\n",
"- Always execute python code in messages that you share.\n",
"- When looking for real time information use relevant functions if available else fallback to brave_search\n",
"You have access to the following functions:\n",
"Use the function 'get_current_weather' to: Get the current weather in a given location\n",
"{tool_get_current_weather[\"function\"]}\n",
"Use the function 'get_current_date' to: Get the current date and time for a given timezone\n",
"{tool_get_current_date[\"function\"]}\n",
"If you choose to call a function ONLY reply in the following format:\n",
"<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}\n",
"where\n",
"start_tag => `<function`\n",
"parameters => a JSON dict with the function argument name as key and function argument value as value.\n",
"end_tag => `</function>`\n",
"Here is an example,\n",
"<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>\n",
"Reminder:\n",
"- Function calls MUST follow the specified format\n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"- Always add your sources when using search results to answer the user query\n",
"You are a helpful assistant.\"\"\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
" },\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=messages,\n",
" response_format={\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" },\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -341,8 +466,6 @@ ...@@ -341,8 +466,6 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import requests\n",
"\n",
"response = requests.post(\n", "response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
...@@ -394,6 +517,57 @@ ...@@ -394,6 +517,57 @@
"print_highlight(response.json())" "print_highlight(response.json())"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"# generate an answer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"payload = {\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" }\n",
" )\n",
" },\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
"print_highlight(response.json())"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -575,6 +749,56 @@ ...@@ -575,6 +749,56 @@
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"prompts = [text]\n",
"\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" }\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
......
...@@ -38,7 +38,7 @@ runtime_common = [ ...@@ -38,7 +38,7 @@ runtime_common = [
"transformers==4.48.3", "transformers==4.48.3",
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.15", "xgrammar==0.1.16",
] ]
srt = [ srt = [
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, root_validator
from typing_extensions import Literal from typing_extensions import Literal
...@@ -323,6 +323,15 @@ class ChatCompletionRequest(BaseModel): ...@@ -323,6 +323,15 @@ class ChatCompletionRequest(BaseModel):
default="auto", examples=["none"] default="auto", examples=["none"]
) # noqa ) # noqa
# Pydantic v1-style pre-validator: runs on the raw request payload before
# field parsing. If the client omitted `tool_choice`, default it based on
# whether `tools` were supplied — "auto" when tools are present, "none"
# otherwise. NOTE(review): this appears to mirror the OpenAI Chat
# Completions default for `tool_choice` — confirm against the API spec.
@root_validator(pre=True)
def set_tool_choice_default(cls, values):
if values.get("tool_choice") is None:
if values.get("tools") is None:
values["tool_choice"] = "none"
else:
values["tool_choice"] = "auto"
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models. # Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1 top_k: int = -1
min_p: float = 0.0 min_p: float = 0.0
......
...@@ -125,7 +125,7 @@ class ServerArgs: ...@@ -125,7 +125,7 @@ class ServerArgs:
# Kernel backend # Kernel backend
attention_backend: Optional[str] = None attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = "outlines" grammar_backend: Optional[str] = "xgrammar"
# Speculative decoding # Speculative decoding
speculative_algorithm: Optional[str] = None speculative_algorithm: Optional[str] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment