Commit 909abb58 authored by maxiao

adapt to sglang v0.5.2rc1 on dcu

# Minimal Makefile for Sphinx documentation
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SPHINXAUTOBUILD ?= sphinx-autobuild
SOURCEDIR = .
BUILDDIR = _build
PORT ?= 8003
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@echo ""
@echo "Additional targets:"
@echo " serve to build and serve documentation with auto-build and live reload"
# Compile Notebook files and record execution time
compile:
@set -e; \
echo "Starting Notebook compilation..."; \
mkdir -p logs; \
echo "Notebook execution timings:" > logs/timing.log; \
START_TOTAL=$$(date +%s); \
find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print0 | \
parallel -0 -j3 --halt soon,fail=1 ' \
NB_NAME=$$(basename {}); \
START_TIME=$$(date +%s); \
retry --delay=0 --times=2 -- \
jupyter nbconvert --to notebook --execute --inplace "{}" \
--ExecutePreprocessor.timeout=600 \
--ExecutePreprocessor.kernel_name=python3; \
RET_CODE=$$?; \
END_TIME=$$(date +%s); \
ELAPSED_TIME=$$((END_TIME - START_TIME)); \
echo "$${NB_NAME}: $${ELAPSED_TIME}s" >> logs/timing.log; \
exit $$RET_CODE' || exit 1; \
END_TOTAL=$$(date +%s); \
TOTAL_ELAPSED=$$((END_TOTAL - START_TOTAL)); \
echo "---------------------------------" >> logs/timing.log; \
echo "Total execution time: $${TOTAL_ELAPSED}s" >> logs/timing.log; \
echo "All Notebook execution timings:" && cat logs/timing.log
# Serve documentation with auto-build and live reload
serve:
@echo "Starting auto-build server at http://0.0.0.0:$(PORT)"
@$(SPHINXAUTOBUILD) "$(SOURCEDIR)" "$(BUILDDIR)/html" \
--host 0.0.0.0 \
--port $(PORT) \
--watch $(SOURCEDIR) \
--re-ignore ".*\.(ipynb_checkpoints|pyc|pyo|pyd|git)"
.PHONY: help Makefile compile clean serve
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
clean:
find . -name "*.ipynb" -exec nbstripout {} \;
rm -rf $(BUILDDIR)
rm -rf logs
# SGLang Documentation
We recommend that new contributors start by writing documentation, which helps you quickly understand the SGLang codebase.
Most documentation files are located under the `docs/` folder.
## Docs Workflow
### Install Dependency
```bash
apt-get update && apt-get install -y pandoc parallel retry
pip install -r requirements.txt
```
### Update Documentation
Update your Jupyter notebooks in the appropriate subdirectories under `docs/`. If you add new files, remember to update `index.rst` (or relevant `.rst` files) accordingly.
- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request.
```bash
# 1) Compile all Jupyter notebooks
make compile  # This step can take a long time (10+ mins). You may skip it if you are sure your added files are correct.
make html
# 2) Compile and Preview documentation locally with auto-build
# This will automatically rebuild docs when files change
# Open your browser at the displayed port to view the docs
bash serve.sh
# 2a) Alternative ways to serve documentation
# Directly use make serve
make serve
# With custom port
PORT=8080 make serve
# 3) Clean notebook outputs
# nbstripout removes notebook outputs so your PR stays clean
pip install nbstripout
find . -name '*.ipynb' -exec nbstripout {} \;
# 4) Pre-commit checks and create a PR
# After these checks pass, push your changes and open a PR on your branch
pre-commit run --all-files
```
---
## Documentation Style Guidelines
- For common functionalities, we prefer **Jupyter Notebooks** over Markdown so that all examples can be executed and validated by our docs CI pipeline. For complex features (e.g., distributed serving), Markdown is preferred.
- Keep the documentation execution time in mind when writing interactive Jupyter notebooks. Each interactive notebook is run and compiled against every commit to ensure it stays runnable, so it is important to apply a few tips to reduce the documentation compilation time:
- Use small models (e.g., `qwen/qwen2.5-0.5b-instruct`) for most cases to reduce server launch time.
- Reuse the launched server as much as possible to reduce server launch time.
- Do not use absolute links (e.g., `https://docs.sglang.ai/get_started/install.html`). Always prefer relative links (e.g., `../get_started/install.md`).
- Follow the existing examples to learn how to launch a server, send a query, and handle other common patterns; a minimal sketch of this pattern is shown after this list.
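For reference, most existing notebooks follow the same basic flow: launch a small model with `launch_server_cmd`, wait for the server, query it through the OpenAI-compatible API, and terminate the server at the end. Below is a minimal sketch of that pattern; the model and prompt are placeholders.

```python
# Minimal sketch of the common notebook pattern (launch, query, shut down).
# The model and prompt are illustrative; see the existing notebooks for details.
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
from openai import OpenAI

# Launch a small model to keep compilation time low
server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")

# Send a query through the OpenAI-compatible API
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id
response = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
)
print_highlight(response.choices[0].message.content)

# Always terminate the server at the end of the notebook
terminate_process(server_process)
```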
.output_area {
color: #615656;
}
table.autosummary td {
width: 50%
}
img.align-center {
display: block;
margin-left: auto;
margin-right: auto;
}
.output_area.stderr {
color: #d3d3d3 !important;
}
.output_area.stdout {
color: #d3d3d3 !important;
}
div.output_area.stderr {
color: #d3d3d3 !important;
}
div.output_area.stdout {
color: #d3d3d3 !important;
}
# Attention Backend
SGLang supports multiple attention backends. Each of them has different pros and cons.
You can test them according to your needs.
## Support matrix for different attention backends
| **Backend** | **Page Size > 1** | **Spec Decoding** | **MLA** | **Sliding Window** | **MultiModal** |
|--------------------------|-------------------|-------------------|---------|--------------------|----------------|
| **FlashInfer** | ❌ | ✅ | ✅ | ✅ | ✅ |
| **FA3** | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
| **Torch Native** | ❌ | ❌ | ✅ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ |
| **Ascend** | ✅ | ❌ | ✅ | ❌ | ❌ |
| **Wave** | ✅ | ❌ | ❌ | ❌ | ❌ |
**Notes:**
- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to the FlashInfer MLA backend.
- Every kernel backend can work with a page size > 1 by specifying an argument such as `--page-size 16`, because a page size of 16 can be converted to a page size of 1 inside the kernel backend. The "❌" and "✅" symbols under "Page Size > 1" in the table above indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1.
## User guide
### Launch commands for different attention backends
- FlashInfer (Default for Non-Hopper Machines, e.g., A100, A40)
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend flashinfer
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend flashinfer --trust-remote-code
```
- FlashAttention 3 (Default for Hopper Machines, e.g., H100, H200, H20)
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend fa3
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --trust-remote-code --attention-backend fa3
```
- Triton
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend triton
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend triton --trust-remote-code
```
- Torch Native
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend torch_native
```
- FlashMLA
```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)
```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
```
- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)
```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```
- Wave
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend wave
```
## Steps to add a new attention backend
To add a new attention backend, you can learn from the existing backends
(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`)
and follow the steps below. An illustrative skeleton of the required functions is sketched after the list.
1. Run without cuda graph. Support the following functions
- forward_extend
- Will be used for prefill, prefill with KV cache, and target verification
- It will be called once per layer
- forward_decode
- Will be used for normal decode, and draft decode
- It will be called once per layer
- init_forward_metadata
- Initialize the class and common metadata shared by all layers
- Call the plan function for optimizations like split_kv
- It will be called once per forward
2. Run with cuda graph. It has two phases (capture and replay) and you need to implement three functions
- init_cuda_graph_state
  - It will be called once during the lifetime of the server
- Create all common shared buffers
- init_forward_metadata_capture_cuda_graph
- It will be called before capturing a cuda graph
  - It is similar to init_forward_metadata but writes the metadata to some pre-defined buffers
- init_forward_metadata_replay_cuda_graph
- It will be called before replaying a cuda graph
- This function is in the critical path and needs to be fast
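The sketch below illustrates the shape of such a backend. The method names follow the steps above, but the signatures are simplified placeholders rather than the actual base-class API; consult the existing backends listed above for the real interface.

```python
# Illustrative skeleton of a new attention backend.
# Method names mirror the steps above; the signatures are simplified placeholders,
# not the actual base-class API -- see triton_backend.py / flashattention_backend.py.
class MyAttentionBackend:
    def init_forward_metadata(self, forward_batch):
        # Called once per forward pass: build metadata shared by all layers
        # (e.g., call a plan function for optimizations like split_kv).
        ...

    def forward_extend(self, q, k, v, layer, forward_batch):
        # Called once per layer for prefill, prefill with KV cache,
        # and target verification.
        ...

    def forward_decode(self, q, k, v, layer, forward_batch):
        # Called once per layer for normal decode and draft decode.
        ...

    # CUDA graph support (capture and replay phases)
    def init_cuda_graph_state(self, max_bs):
        # Called once during the lifetime of the server: allocate shared buffers.
        ...

    def init_forward_metadata_capture_cuda_graph(self, forward_batch):
        # Called before capturing a CUDA graph; like init_forward_metadata,
        # but writes the metadata into pre-defined buffers.
        ...

    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
        # Called before replaying a CUDA graph; on the critical path, keep it fast.
        ...
```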
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tool and Function Calling\n",
"\n",
"This guide demonstrates how to use SGLang’s [Function calling](https://platform.openai.com/docs/guides/function-calling) functionality."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI Compatible API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launching the Server"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"from openai import OpenAI\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n",
")\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
"\n",
"- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n",
"- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n",
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html).\n",
"- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n",
"- gpt-oss: GPT-OSS (e.g., openai/gpt-oss-120b, openai/gpt-oss-20b, lmsys/gpt-oss-120b-bf16, lmsys/gpt-oss-20b-bf16). Note: The gpt-oss tool parser filters out analysis channel events and only preserves normal text. This can cause the content to be empty when explanations are in the analysis channel. To work around this, complete the tool round by returning tool results as role=\"tool\" messages, which enables the model to generate the final content.\n",
"- kimi_k2: moonshotai/Kimi-K2-Instruct"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Tools for Function Call\n",
"Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"state\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
" \" in, e.g. 'CA' which would mean 'California'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"state\", \"unit\"],\n",
" },\n",
" },\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.\",\n",
" }\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the Client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n",
"model_name = client.models.list().data[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Non-Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Non-streaming mode test\n",
"response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
" stream=False, # Non-streaming\n",
" tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(response_non_stream)\n",
"print_highlight(\"==== content ====\")\n",
"print(response_non_stream.choices[0].message.content)\n",
"print_highlight(\"==== tool_calls ====\")\n",
"print(response_non_stream.choices[0].message.tool_calls)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Handle Tools\n",
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
"arguments_non_stream = (\n",
" response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
")\n",
"\n",
"print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n",
"print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Streaming mode test\n",
"print_highlight(\"Streaming response:\")\n",
"response_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0,\n",
" top_p=0.95,\n",
" max_tokens=1024,\n",
" stream=True, # Enable streaming\n",
" tools=tools,\n",
")\n",
"\n",
"texts = \"\"\n",
"tool_calls = []\n",
"name = \"\"\n",
"arguments = \"\"\n",
"for chunk in response_stream:\n",
" if chunk.choices[0].delta.content:\n",
" texts += chunk.choices[0].delta.content\n",
" if chunk.choices[0].delta.tool_calls:\n",
" tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n",
"print_highlight(\"==== Text ====\")\n",
"print(texts)\n",
"\n",
"print_highlight(\"==== Tool Call ====\")\n",
"for tool_call in tool_calls:\n",
" print(tool_call)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Handle Tools\n",
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parse and combine function call arguments\n",
"arguments = []\n",
"for tool_call in tool_calls:\n",
" if tool_call.function.name:\n",
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
"\n",
" if tool_call.function.arguments:\n",
" arguments.append(tool_call.function.arguments)\n",
"\n",
"# Combine all fragments into a single JSON string\n",
"full_arguments = \"\".join(arguments)\n",
"print_highlight(f\"streamed function call arguments: {full_arguments}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a Tool Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is a demonstration, define real function according to your usage.\n",
"def get_current_weather(city: str, state: str, unit: \"str\"):\n",
" return (\n",
" f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
" \"partly cloudly, with highs in the 90's.\"\n",
" )\n",
"\n",
"\n",
"available_tools = {\"get_current_weather\": get_current_weather}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"### Execute the Tool"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"messages.append(response_non_stream.choices[0].message)\n",
"\n",
"# Call the corresponding tool function\n",
"tool_call = messages[-1].tool_calls[0]\n",
"tool_name = tool_call.function.name\n",
"tool_to_call = available_tools[tool_name]\n",
"result = tool_to_call(**(json.loads(tool_call.function.arguments)))\n",
"print_highlight(f\"Function call result: {result}\")\n",
"# messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
"messages.append(\n",
" {\n",
" \"role\": \"tool\",\n",
" \"tool_call_id\": tool_call.id,\n",
" \"content\": str(result),\n",
" \"name\": tool_name,\n",
" }\n",
")\n",
"\n",
"print_highlight(f\"Updated message history: {messages}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Send Results Back to Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final_response = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0,\n",
" top_p=0.95,\n",
" stream=False,\n",
" tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(final_response)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print(final_response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Native API and SGLang Runtime (SRT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"import requests\n",
"\n",
"# generate an answer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n",
"\n",
"messages = get_messages()\n",
"\n",
"input = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" tools=tools,\n",
")\n",
"\n",
"gen_url = f\"http://localhost:{port}/generate\"\n",
"gen_data = {\n",
" \"text\": input,\n",
" \"sampling_params\": {\n",
" \"skip_special_tokens\": False,\n",
" \"max_new_tokens\": 1024,\n",
" \"temperature\": 0,\n",
" \"top_p\": 0.95,\n",
" },\n",
"}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"print_highlight(\"==== Response ====\")\n",
"print(gen_response)\n",
"\n",
"# parse the response\n",
"parse_url = f\"http://localhost:{port}/parse_function_call\"\n",
"\n",
"function_call_input = {\n",
" \"text\": gen_response,\n",
" \"tool_call_parser\": \"qwen25\",\n",
" \"tools\": tools,\n",
"}\n",
"\n",
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
"function_call_response_json = function_call_response.json()\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print(function_call_response_json[\"normal_text\"])\n",
"print_highlight(\"==== Calls ====\")\n",
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.function_call.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
"tokenizer = llm.tokenizer_manager.tokenizer\n",
"input_ids = tokenizer.apply_chat_template(\n",
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
")\n",
"\n",
"# Note that for gpt-oss tool parser, adding \"no_stop_trim\": True\n",
"# to make sure the tool call token <call> is not trimmed.\n",
"\n",
"sampling_params = {\n",
" \"max_new_tokens\": 1024,\n",
" \"temperature\": 0,\n",
" \"top_p\": 0.95,\n",
" \"skip_special_tokens\": False,\n",
"}\n",
"\n",
"# 1) Offline generation\n",
"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
"generated_text = result[\"text\"] # Assume there is only one prompt\n",
"\n",
"print(\"=== Offline Engine Output Text ===\")\n",
"print(generated_text)\n",
"\n",
"\n",
"# 2) Parse using FunctionCallParser\n",
"def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
" function_dict = tool_dict.get(\"function\", {})\n",
" return Tool(\n",
" type=tool_dict.get(\"type\", \"function\"),\n",
" function=Function(\n",
" name=function_dict.get(\"name\"),\n",
" description=function_dict.get(\"description\"),\n",
" parameters=function_dict.get(\"parameters\"),\n",
" ),\n",
" )\n",
"\n",
"\n",
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
"\n",
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n",
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
"\n",
"print(\"=== Parsing Result ===\")\n",
"print(\"Normal text portion:\", normal_text)\n",
"print(\"Function call portion:\")\n",
"for call in calls:\n",
" # call: ToolCallItem\n",
" print(f\" - tool name: {call.name}\")\n",
" print(f\" parameters: {call.parameters}\")\n",
"\n",
"# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Choice Mode\n",
"\n",
"SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n",
"\n",
"### Supported Tool Choice Options\n",
"\n",
"- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n",
"- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n",
"\n",
"### Backend Compatibility\n",
"\n",
"Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n",
"\n",
"### Example: Required Tool Choice"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"\n",
"# Start a new server session for tool choice examples\n",
"server_process_tool_choice, port_tool_choice = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n",
")\n",
"wait_for_server(f\"http://localhost:{port_tool_choice}\")\n",
"\n",
"# Initialize client for tool choice examples\n",
"client_tool_choice = OpenAI(\n",
" api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n",
")\n",
"model_name_tool_choice = client_tool_choice.models.list().data[0].id\n",
"\n",
"# Example with tool_choice=\"required\" - forces the model to call a tool\n",
"messages_required = [\n",
" {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n",
"]\n",
"\n",
"# Define tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"unit\"],\n",
" },\n",
" },\n",
" }\n",
"]\n",
"\n",
"response_required = client_tool_choice.chat.completions.create(\n",
" model=model_name_tool_choice,\n",
" messages=messages_required,\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" tools=tools,\n",
" tool_choice=\"required\", # Force the model to call a tool\n",
")\n",
"\n",
"print_highlight(\"Response with tool_choice='required':\")\n",
"print(\"Content:\", response_required.choices[0].message.content)\n",
"print(\"Tool calls:\", response_required.choices[0].message.tool_calls)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example: Specific Function Choice\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example with specific function choice - forces the model to call a specific function\n",
"messages_specific = [\n",
" {\"role\": \"user\", \"content\": \"What are the most attactive places in France?\"}\n",
"]\n",
"\n",
"response_specific = client_tool_choice.chat.completions.create(\n",
" model=model_name_tool_choice,\n",
" messages=messages_specific,\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" tools=tools,\n",
" tool_choice={\n",
" \"type\": \"function\",\n",
" \"function\": {\"name\": \"get_current_weather\"},\n",
" }, # Force the model to call the specific get_current_weather function\n",
")\n",
"\n",
"print_highlight(\"Response with specific function choice:\")\n",
"print(\"Content:\", response_specific.choices[0].message.content)\n",
"print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n",
"\n",
"if response_specific.choices[0].message.tool_calls:\n",
" tool_call = response_specific.choices[0].message.tool_calls[0]\n",
" print(f\"Called function: {tool_call.function.name}\")\n",
" print(f\"Arguments: {tool_call.function.arguments}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process_tool_choice)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pythonic Tool Call Format (Llama-3.2 / Llama-3.3 / Llama-4)\n",
"\n",
"Some Llama models (such as Llama-3.2-1B, Llama-3.2-3B, Llama-3.3-70B, and Llama-4) support a \"pythonic\" tool call format, where the model outputs function calls as Python code, e.g.:\n",
"\n",
"```python\n",
"[get_current_weather(city=\"San Francisco\", state=\"CA\", unit=\"celsius\")]\n",
"```\n",
"\n",
"- The output is a Python list of function calls, with arguments as Python literals (not JSON).\n",
"- Multiple tool calls can be returned in the same list:\n",
"```python\n",
"[get_current_weather(city=\"San Francisco\", state=\"CA\", unit=\"celsius\"),\n",
" get_current_weather(city=\"New York\", state=\"NY\", unit=\"fahrenheit\")]\n",
"```\n",
"\n",
"For more information, refer to Meta’s documentation on [Zero shot function calling](https://github.com/meta-llama/llama-models/blob/main/models/llama4/prompt_format.md#zero-shot-function-calling---system-message).\n",
"\n",
"Note that this feature is still under development on Blackwell.\n",
"\n",
"### How to enable\n",
"- Launch the server with `--tool-call-parser pythonic`\n",
"- You may also specify --chat-template with the improved template for the model (e.g., `--chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja`).\n",
"This is recommended because the model expects a special prompt format to reliably produce valid pythonic tool call outputs. The template ensures that the prompt structure (e.g., special tokens, message boundaries like `<|eom|>`, and function call delimiters) matches what the model was trained or fine-tuned on. If you do not use the correct chat template, tool calling may fail or produce inconsistent results.\n",
"\n",
"#### Forcing Pythonic Tool Call Output Without a Chat Template\n",
"If you don't want to specify a chat template, you must give the model extremely explicit instructions in your messages to enforce pythonic output. For example, for `Llama-3.2-1B-Instruct`, you need:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \" python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1\" # llama-3.2-1b-instruct\n",
")\n",
"wait_for_server(f\"http://localhost:{port}\")\n",
"\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Get the current weather for a given location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The name of the city or location.\",\n",
" }\n",
" },\n",
" \"required\": [\"location\"],\n",
" },\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_tourist_attractions\",\n",
" \"description\": \"Get a list of top tourist attractions for a given city.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The name of the city to find attractions for.\",\n",
" }\n",
" },\n",
" \"required\": [\"city\"],\n",
" },\n",
" },\n",
" },\n",
"]\n",
"\n",
"\n",
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": (\n",
" \"You are a travel assistant. \"\n",
" \"When asked to call functions, ALWAYS respond ONLY with a python list of function calls, \"\n",
" \"using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. \"\n",
" \"Do NOT use JSON, do NOT use variables, do NOT use any other format. \"\n",
" \"Here is an example:\\n\"\n",
" '[get_weather(location=\"Paris\"), get_tourist_attractions(city=\"Paris\")]'\n",
" ),\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": (\n",
" \"I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? \"\n",
" \"Propose parallel tool calls at once, using the python list of function calls format as shown above.\"\n",
" ),\n",
" },\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()\n",
"\n",
"client = openai.Client(base_url=f\"http://localhost:{port}/v1\", api_key=\"xxxxxx\")\n",
"model_name = client.models.list().data[0].id\n",
"\n",
"\n",
"response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0,\n",
" top_p=0.9,\n",
" stream=False, # Non-streaming\n",
" tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(response_non_stream)\n",
"\n",
"response_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0,\n",
" top_p=0.9,\n",
" stream=True,\n",
" tools=tools,\n",
")\n",
"texts = \"\"\n",
"tool_calls = []\n",
"name = \"\"\n",
"arguments = \"\"\n",
"\n",
"for chunk in response_stream:\n",
" if chunk.choices[0].delta.content:\n",
" texts += chunk.choices[0].delta.content\n",
" if chunk.choices[0].delta.tool_calls:\n",
" tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n",
"\n",
"print_highlight(\"Streaming Response:\")\n",
"print_highlight(\"==== Text ====\")\n",
"print(texts)\n",
"\n",
"print_highlight(\"==== Tool Call ====\")\n",
"for tool_call in tool_calls:\n",
" print(tool_call)\n",
"\n",
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note:** \n",
"> The model may still default to JSON if it was heavily finetuned on that format. Prompt engineering (including examples) is the only way to increase the chance of pythonic output if you are not using a chat template."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to support a new model?\n",
"1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n",
"```\n",
"\tTOOLS_TAG_LIST = [\n",
"\t “<|plugin|>“,\n",
"\t “<function=“,\n",
"\t “<tool_call>“,\n",
"\t “<|python_tag|>“,\n",
"\t “[TOOL_CALLS]”\n",
"\t]\n",
"```\n",
"2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n",
"```\n",
" class NewModelDetector(BaseFormatDetector):\n",
"```\n",
"3. Add the new detector to the MultiFormatParser class that manages all the format detectors."
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# Hyperparameter Tuning
## Achieving high throughput for offline batch inference
Running at a large batch size is the most important factor for achieving high throughput in offline batch inference.
When the server is running at full load in a steady state, look for the following in the log:
```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, cuda graph: True, gen throughput (token/s): 4594.01, #queue-req: 317```
### Adjust the request submission speed to control `#queue-req`
`#queue-req` indicates the number of requests in the queue.
If you frequently see `#queue-req: 0`, it suggests that your client code is submitting requests too slowly.
A healthy range for `#queue-req` is `100 - 2000`.
However, avoid making `#queue-req` too large, as this will increase the scheduling overhead on the server.
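One simple way to keep `#queue-req` in that range is to cap the number of in-flight requests on the client side rather than submitting them strictly one by one. Below is a minimal sketch; the URL, prompts, and concurrency value are illustrative assumptions that you should tune for your workload.

```python
# Sketch: keep enough requests in flight so that #queue-req stays healthy.
# The URL, prompts, and concurrency value below are illustrative assumptions.
from concurrent.futures import ThreadPoolExecutor
import requests

URL = "http://localhost:30000/generate"
prompts = [f"Write a short story about topic {i}." for i in range(1000)]


def send(prompt):
    payload = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": 256, "temperature": 0},
    }
    return requests.post(URL, json=payload).json()


# A few hundred concurrent requests usually keeps #queue-req in the 100-2000 range;
# tune max_workers for your workload and hardware.
with ThreadPoolExecutor(max_workers=256) as pool:
    results = list(pool.map(send, prompts))
print(f"Completed {len(results)} requests")
```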
### Achieve a high `token usage`
`token usage` indicates the KV cache memory utilization of the server. `token usage > 0.9` means good utilization.
If you frequently see `token usage < 0.9` and `#queue-req > 0`, it means the server is too conservative about taking in new requests. You can decrease `--schedule-conservativeness` to a value like 0.3.
The case of a server being too conservative can happen when users send many requests with a large `max_new_tokens` but the requests stop very early due to EOS or stop strings.
On the other hand, if you see `token usage` very high and you frequently see warnings like
`KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_token_ratio: 0.9998 -> 1.0000`, you can increase `--schedule-conservativeness` to a value like 1.3.
If you see `KV cache pool is full. Retract requests.` occasionally but not frequently, it is okay.
### Tune `--mem-fraction-static` to increase KV cache pool capacity
SGLang allocates memory as follows:
Total memory usage = model weights + KV cache pool + CUDA graph buffers + activations
The `--mem-fraction-static` parameter determines how much memory is allocated to the first two components:
mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity
To support higher concurrency, you should maximize the KV cache pool capacity by setting `--mem-fraction-static` as high as possible while still reserving enough memory for activations and CUDA graph buffers.
SGLang uses simple heuristics to set the default value of `--mem-fraction-static`, but you can optimize it for your use cases.
As a rule of thumb, reserving 5–8 GB of memory for activations is typically sufficient. You can check this by inspecting the logs just before the server is ready.
Look for log entries like this:
```
[2025-08-11 17:17:03] max_total_num_tokens=665690, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=65536, available_gpu_mem=13.50 GB
```
Check the `available_gpu_mem` value.
- If it is between 5–8 GB, the setting is good.
- If it is too high (e.g., 10 - 20 GB), increase `--mem-fraction-static` to allocate more memory to the KV cache.
- If it is too low, you risk out-of-memory (OOM) errors later, so decrease `--mem-fraction-static`.
Another straightforward approach is to increase `--mem-fraction-static` in increments of 0.01 until you encounter OOM errors for your workloads.
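As a back-of-the-envelope check, you can estimate a starting value from the formula above. The numbers below are hypothetical (an 80 GB GPU, 16 GB of weights per GPU, and 6 GB reserved for activations and CUDA graph buffers); measure `available_gpu_mem` in the logs to refine your setting.

```python
# Back-of-the-envelope estimate of --mem-fraction-static; all numbers are hypothetical.
gpu_memory_gb = 80.0   # total GPU memory capacity
weights_gb = 16.0      # model weights per GPU (after any TP sharding)
reserved_gb = 6.0      # activations + CUDA graph buffers (5-8 GB rule of thumb)

kv_cache_pool_gb = gpu_memory_gb - weights_gb - reserved_gb
mem_fraction_static = (weights_gb + kv_cache_pool_gb) / gpu_memory_gb
print(f"KV cache pool: {kv_cache_pool_gb:.1f} GB")                     # 58.0 GB
print(f"Suggested --mem-fraction-static: {mem_fraction_static:.2f}")   # ~0.93 with these numbers
```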
### Avoid out-of-memory errors by tuning `--chunked-prefill-size`, `--mem-fraction-static`, and `--max-running-requests`
If you encounter out-of-memory (OOM) errors, you can adjust the following parameters (a combined example is sketched after this list):
- If OOM occurs during prefill, try reducing `--chunked-prefill-size` to `4096` or `2048`. This saves memory but slows down the prefill speed for long prompts.
- If OOM occurs during decoding, try lowering `--max-running-requests`.
- You can also reduce `--mem-fraction-static` to a smaller value, such as 0.8 or 0.7. This decreases the memory usage of the KV cache memory pool and helps prevent OOM errors during both prefill and decoding. However, it limits maximum concurrency and reduces peak throughput.
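For example, a more OOM-conservative launch combining these flags might look like the sketch below; the model path and the specific values are placeholders to adjust for your workload.

```python
# Sketch of a more OOM-conservative launch; the model path and values are
# placeholders, not recommendations for a specific GPU.
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server "
    "--model-path meta-llama/Meta-Llama-3.1-8B-Instruct "
    "--chunked-prefill-size 4096 "    # smaller prefill chunks to avoid prefill OOM
    "--max-running-requests 128 "     # cap decode concurrency
    "--mem-fraction-static 0.8"       # smaller KV cache pool
)
wait_for_server(f"http://localhost:{port}")
```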
### Tune `--cuda-graph-max-bs`
By default, CUDA graph is enabled only for small batch sizes (e.g., less than 160 or 256).
However, for some models, especially at large tensor parallelism sizes, CUDA graph can be useful for batch sizes up to 512 or 768.
Therefore, it may be beneficial to increase `--cuda-graph-max-bs` to a larger value.
Note that CUDA graph consumes more memory, so you may need to reduce `--mem-fraction-static` at the same time.
### Tune `--dp-size` and `--tp-size`
Data parallelism is better for throughput. When there is enough GPU memory, always favor data parallelism. Prefer the [sglang router](../advanced_features/router.md) for data parallelism rather than using the `dp_size` parameter.
### Try other options
- `torch.compile` accelerates small models on small batch sizes. You can enable it with `--enable-torch-compile`.
- Try quantization (e.g., FP8 quantization with `--quantization fp8`).
- Try other parallelism strategies (e.g., [expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/)) or DP attention for DeepSeek models (with `--enable-dp-attention --dp-size 8`).
- If the workload has many shared prefixes, try `--schedule-policy lpm`. Here, `lpm` stands for longest prefix match. It reorders requests to encourage more cache hits but introduces more scheduling overhead.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LoRA Serving"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SGLang enables the use of [LoRA adapters](https://arxiv.org/abs/2106.09685) with a base model. By incorporating techniques from [S-LoRA](https://arxiv.org/pdf/2311.03285) and [Punica](https://arxiv.org/pdf/2310.18547), SGLang can efficiently support multiple LoRA adapters for different sequences within a single batch of inputs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Arguments for LoRA Serving"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following server arguments are relevant for multi-LoRA serving:\n",
"\n",
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
"\n",
"* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n",
"\n",
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
"\n",
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
"\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n",
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"### Serving Single Adaptor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import requests\n",
"\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import wait_for_server, terminate_process"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" --max-loras-per-batch 1 --lora-backend triton \\\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses the base model\n",
" \"lora_path\": [\"lora0\", None],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output 0: {response.json()[0]['text']}\")\n",
"print(f\"Output 1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Serving Multiple Adaptors"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output 0: {response.json()[0]['text']}\")\n",
"print(f\"Output 1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dynamic LoRA loading"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
"\n",
"When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --max-lora-rank 256\n",
" --lora-target-modules all\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load adapter lora0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" \"lora_path\": lora0,\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load adapter lora1:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check inference output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unload lora0 and replace it with a different adapter:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")\n",
"\n",
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" \"lora_path\": lora0_new,\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check output again:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LoRA GPU Pinning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another advanced option is to specify adapters as `pinned` during loading. When an adapter is pinned, it is permanently assigned to one of the available GPU pool slots (as configured by `--max-loras-per-batch`) and will not be evicted from GPU memory during runtime. Instead, it remains resident until it is explicitly unloaded.\n",
"\n",
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
"\n",
"In the example below, we start a server with `lora1` loaded as pinned, `lora2` and `lora3` loaded as regular (unpinned) adapters. Please note that, we intentionally specify `lora2` and `lora3` in two different formats to demonstrate that both are supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 8 \\\n",
" --max-loras-per-batch 3 --lora-backend triton \\\n",
" --max-lora-rank 256 \\\n",
" --lora-target-modules all \\\n",
" --lora-paths \\\n",
" {\"lora_name\":\"lora0\",\"lora_path\":\"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\"pinned\":true} \\\n",
" {\"lora_name\":\"lora1\",\"lora_path\":\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"} \\\n",
" lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\n",
" \"\"\"\n",
")\n",
"\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also specify adapter as pinned during dynamic adapter loading. In the example below, we reload `lora2` as pinned adapter:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" },\n",
")\n",
"\n",
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n",
" \"pinned\": True, # Pin the adapter to GPU\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the results are expected:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\", \"lora2\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0 (pinned): \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")\n",
"print(f\"Output from lora2 (not pinned): \\n{response.json()[2]['text']}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Future Works\n",
"\n",
"The development roadmap for LoRA-related features can be found in this [issue](https://github.com/sgl-project/sglang/issues/2929). Other features, including Embedding Layer, Unified Paging, Cutlass backend are still under development."
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Observability
## Production Metrics
SGLang exposes the following metrics via Prometheus. You can enable them by adding `--enable-metrics` when launching the server.
You can query them by:
```
curl http://localhost:30000/metrics
```
See [Production Metrics](../references/production_metrics.md) for more details.
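You can also scrape the same endpoint from Python; the sketch below fetches the metrics text and prints a subset of lines (the port and the substring filter are illustrative assumptions):

```python
# Sketch: scrape the metrics endpoint of a server launched with --enable-metrics.
# The port and the "token" substring filter are illustrative assumptions.
import requests

metrics_text = requests.get("http://localhost:30000/metrics").text
for line in metrics_text.splitlines():
    if not line.startswith("#") and "token" in line:
        print(line)
```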
## Logging
By default, SGLang does not log any request contents. You can log them by using `--log-requests`.
You can control the verbosity by using `--log-request-level`.
See [Logging](server_arguments.md#logging) for more details.
## Request Dump and Replay
You can dump all requests and replay them later for benchmarking or other purposes.
To start dumping, use the following command to send a request to a server:
```
python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000 --dump-requests-folder /tmp/sglang_request_dump --dump-requests-threshold 100
```
The server will dump the requests into a pickle file every 100 requests (the threshold set by `--dump-requests-threshold`).
To replay the request dump, use `scripts/playground/replay_request_dump.py`.
## Crash Dump and Replay
Sometimes the server might crash, and you may want to debug the cause of the crash.
SGLang supports crash dumping, which will dump all requests from the 5 minutes before the crash, allowing you to replay the requests and debug the reason later.
To enable crash dumping, use `--crash-dump-folder /tmp/crash_dump`.
To replay the crash dump, use `scripts/playground/replay_request_dump.py`.
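For example, a server launch with crash dumping enabled might look like the following sketch (the model path is illustrative):
```
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --crash-dump-folder /tmp/crash_dump
```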
# PD Disaggregation
## Why and What is PD Disaggregation?
Large Language Model (LLM) inference comprises two distinct phases: **Prefill** and **Decode**. The Prefill phase is computation-intensive, processing the entire input sequence, while the Decode phase is memory-intensive, managing the Key-Value (KV) cache for token generation. Traditionally, these phases are handled within a unified engine, where combined scheduling of prefill and decode batches introduces inefficiencies. To address these challenges, we introduce **Prefill and Decoding (PD) Disaggregation** in SGLang.
### Issues with Unified Scheduling
The conventional unified engine, which processes prefill and decode batches together, results in two significant problems:
1. **Prefill Interruption**: Incoming prefill batches frequently interrupt ongoing decode batches, causing substantial delays in token generation.
2. **DP Attention Imbalance**: In data-parallel (DP) attention, one DP worker may process a prefill batch while another handles a decode batch simultaneously, leading to increased decode latency.
PD Disaggregation resolves these by separating the two stages, enabling tailored optimizations for each.
For the design details, please refer to [link](https://docs.google.com/document/d/1rQXJwKd5b9b1aOzLh98mnyMhBMhlxXA5ATZTHoQrwvc/edit?tab=t.0).
Currently, Mooncake, NIXL, and Ascend are supported as transfer engines.
## Router Integration
For deploying PD disaggregation at scale with load balancing and fault tolerance, SGLang provides a router. The router can distribute requests between prefill and decode instances using various routing policies. For detailed information on setting up routing with PD disaggregation, including configuration options and deployment patterns, see the [SGLang Router documentation](router.md#mode-3-prefill-decode-disaggregation).
## Mooncake
### Requirements
```bash
uv pip install mooncake-transfer-engine
```
### Usage
### Llama Single Node
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
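Once the mini load balancer is up, you can send requests to it just like to a regular SGLang server. A minimal sketch (the prompt and sampling parameters are illustrative):
```bash
curl -X POST http://127.0.0.1:8000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "What is the capital of France?", "sampling_params": {"max_new_tokens": 32, "temperature": 0}}'
```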
### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
```
### Advanced Configuration
PD Disaggregation with Mooncake supports the following environment variables for fine-grained control over system behavior.
#### Prefill Server Configuration
| Variable | Description | Default |
|:--------:|:-----------:|:-------:|
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value computed as `int(0.75 * os.cpu_count()) // 8`, clamped to lie between 4 and 12 to ensure efficiency and prevent thread race conditions |
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances are sharded across these queues so that they share the threads and the transfer bandwidth. If set to `1`, requests are transferred one by one in first-come-first-served (FCFS) order | `4` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` |
If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition.
Please be aware that this setting will cause prefill instances to take a longer time to clean up the affected memory resources when a running decode node loses connection.
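These tuning variables are plain environment variables exported before launching the prefill server; the following sketch uses illustrative values together with the single-node launch command shown above:
```bash
export SGLANG_DISAGGREGATION_THREAD_POOL_SIZE=8
export SGLANG_DISAGGREGATION_QUEUE_SIZE=4
export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0
```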
#### Decode Server Configuration
| Variable | Description | Default |
|:--------:|:-----------:|:-------:|
| **`SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL`** | Interval (seconds) between health checks to prefill bootstrap servers | `5.0` |
| **`SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE`** | Consecutive heartbeat failures before marking prefill server offline | `2` |
| **`SGLANG_DISAGGREGATION_WAITING_TIMEOUT`** | Timeout (seconds) for receiving KV Cache after request initialization | `300` |
If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600` (10 minutes) to relax the timeout condition.
## NIXL
### Requirements
Install via pip.
```bash
pip install nixl
```
Or build from source - may be required if you already have UCX installed.
```bash
git clone https://github.com/ai-dynamo/nixl.git
cd nixl
pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```
### Usage
### Llama Single Node
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --moe-a2a-backend deepep --mem-fraction-static 0.8 --max-running-requests 128
```
## ASCEND
### Usage
Use the ascend backend with [mf_adapter (download link)](https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com:443/sglang/mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl?AccessKeyId=HPUAXT4YM0U8JNTERLST&Expires=1783151861&Signature=3j10QDUjqk70enaq8lostYV2bEA%3D) and set `ASCEND_MF_STORE_URL`:
```bash
pip install mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl --force-reinstall
export ASCEND_MF_STORE_URL="tcp://xxx.xx.xxx.xxx:xxxx"
```
To use the Mooncake backend instead, set the following environment variable (more details can be found in the Mooncake section above):
```bash
export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
```
### Llama Single Node
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend ascend
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend ascend
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend ascend --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 1 --node-rank 0 --tp-size 16
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-transfer-backend ascend --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 1 --node-rank 0 --tp-size 16
```
# Quantization
SGLang supports various quantization methods, including offline quantization and online dynamic quantization.
Offline quantization loads pre-quantized model weights directly during inference. This is required for quantization methods
such as GPTQ and AWQ, which collect and pre-compute various statistics from the original weights using the calibration dataset.
Online quantization dynamically computes scaling parameters—such as the maximum/minimum values of model weights—during runtime.
Like NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors
on-the-fly to convert high-precision weights into a lower-precision format.
**Note: For better performance, usability and convenience, offline quantization is recommended over online quantization.**
If you use a pre-quantized model, do not add `--quantization` to enable online quantization at the same time.
For popular pre-quantized models, see the [ModelCloud](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2)
and [NeuralMagic](https://huggingface.co/collections/neuralmagic) collections on Hugging Face, which host quality-validated quantized models.
Quantized models must be validated with benchmarks after quantization to guard against abnormal quantization loss regressions.
## Offline Quantization
To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline,
there's no need to add `--quantization` argument when starting the engine. The quantization method will be parsed from the
downloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.**
```bash
python3 -m sglang.launch_server \
--model-path hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 \
--port 30000 --host 0.0.0.0
```
Note that if your model is **per-channel quantized (INT8 or FP8) with per-token dynamic activation quantization**, you can include `--quantization w8a8_int8` or `--quantization w8a8_fp8` to invoke the corresponding CUTLASS int8 or fp8 kernel in sgl-kernel. This overrides the quantization settings in the Hugging Face config. For instance, with `neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic`, running with `--quantization w8a8_fp8` makes SGLang use its `W8A8Fp8Config` to invoke the sgl-kernel, rather than the `CompressedTensorsConfig` for vLLM kernels.
```bash
python3 -m sglang.launch_server \
--model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic \
--quantization w8a8_fp8 \
--port 30000 --host 0.0.0.0
```
### Examples of Offline Model Quantization
#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel)
```bash
# install
pip install gptqmodel --no-build-isolation -v
```
```py
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
calibration_dataset = load_dataset(
"allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz",
split="train"
).select(range(1024))["text"]
quant_config = QuantizeConfig(bits=4, group_size=128) # quantization config
model = GPTQModel.load(model_id, quant_config) # load model
model.quantize(calibration_dataset, batch_size=2) # quantize
model.save(quant_path) # save model
```
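After the quantized checkpoint is saved, it can be served with SGLang like any other offline-quantized model. A sketch assuming the `quant_path` from the script above:
```bash
python3 -m sglang.launch_server \
  --model-path $PWD/Llama-3.2-1B-Instruct-gptqmodel-4bit \
  --port 30000 --host 0.0.0.0
```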
#### Using [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
```bash
# install
pip install llmcompressor
```
Here, we quantize `meta-llama/Meta-Llama-3-8B-Instruct` to `FP8` as an example of how to perform offline quantization.
```python
from transformers import AutoTokenizer
from llmcompressor.transformers import SparseAutoModelForCausalLM
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# Step 1: Load the original model.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Step 2: Perform offline quantization.
# Step 2.1: Configure the simple PTQ quantization.
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
# Step 2.2: Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)
# Step 3: Save the model.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```
Then, you can use the quantized model directly with `SGLang` via the following command:
```bash
python3 -m sglang.launch_server \
--model-path $PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic \
--port 30000 --host 0.0.0.0
```
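To sanity-check the deployment, you can send a quick request to the `/generate` endpoint; this sketch assumes the server launched above is listening on port 30000:
```python
import requests

# Simple generation request to verify the quantized model is serving correctly.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List 3 countries and their capitals.",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
)
print(response.json()["text"])
```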
## Online Quantization
To enable online quantization, you can simply specify `--quantization` in the command line. For example, you can launch the server with the following command to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`:
```bash
python3 -m sglang.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--quantization fp8 \
--port 30000 --host 0.0.0.0
```
Our team is working on supporting more online quantization methods. SGLang will soon support methods including but not limited to `["awq", "gptq", "marlin", "gptq_marlin", "awq_marlin", "bitsandbytes", "gguf"]`.
SGLang also supports quantization methods based on [torchao](https://github.com/pytorch/ao). You can simply specify `--torchao-config` in the command line to support this feature. For example, if you want to enable `int4wo-128` for model `meta-llama/Meta-Llama-3.1-8B-Instruct`, you can launch the server with the following command:
```bash
python3 -m sglang.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--torchao-config int4wo-128 \
--port 30000 --host 0.0.0.0
```
SGLang supports the following quantization methods based on torchao `["int8dq", "int8wo", "fp8wo", "fp8dq-per_tensor", "fp8dq-per_row", "int4wo-32", "int4wo-64", "int4wo-128", "int4wo-256"]`.
Note: According to [this issue](https://github.com/sgl-project/sglang/issues/2219#issuecomment-2561890230), the `"int8dq"` method currently has some bugs when used together with CUDA graph capture. We therefore suggest disabling CUDA graph capture when using the `"int8dq"` method, i.e., use the following command:
```bash
python3 -m sglang.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--torchao-config int8dq \
--disable-cuda-graph \
--port 30000 --host 0.0.0.0
```
## Reference
- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
# SGLang Router
The SGLang Router is a high-performance request distribution system that routes inference requests across multiple SGLang runtime instances. It features cache-aware load balancing, fault tolerance, and support for advanced deployment patterns including data parallelism and prefill-decode disaggregation.
## Key Features
- **Cache-Aware Load Balancing**: Optimizes cache utilization while maintaining balanced load distribution
- **Multiple Routing Policies**: Choose from random, round-robin, cache-aware, or power-of-two policies
- **Fault Tolerance**: Automatic retry and circuit breaker mechanisms for resilient operation
- **Dynamic Scaling**: Add or remove workers at runtime without service interruption
- **Kubernetes Integration**: Native service discovery and pod management
- **Prefill-Decode Disaggregation**: Support for disaggregated serving load balancing
- **Prometheus Metrics**: Built-in observability and monitoring
## Installation
```bash
pip install sglang-router
```
## Quick Start
To see all available options:
```bash
python -m sglang_router.launch_server --help # Co-launch router and workers
python -m sglang_router.launch_router --help # Launch router only
```
## Deployment Modes
The router supports three primary deployment patterns:
1. **Co-launch Mode**: Router and workers launch together (simplest for single-node deployments)
2. **Separate Launch Mode**: Router and workers launch independently (best for multi-node setups)
3. **Prefill-Decode Disaggregation**: Specialized mode for disaggregated serving
### Mode 1: Co-launch Router and Workers
This mode launches both the router and multiple worker instances in a single command. It's the simplest deployment option and replaces the `--dp-size` argument of SGLang Runtime.
```bash
# Launch router with 4 workers
python -m sglang_router.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--dp-size 4 \
--host 0.0.0.0 \
--port 30000
```
#### Sending Requests
Once the server is ready, send requests to the router endpoint:
```python
import requests
# Using the /generate endpoint
url = "http://localhost:30000/generate"
data = {
"text": "What is the capital of France?",
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 100
}
}
response = requests.post(url, json=data)
print(response.json())
# OpenAI-compatible endpoint
url = "http://localhost:30000/v1/chat/completions"
data = {
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"messages": [{"role": "user", "content": "What is the capital of France?"}]
}
response = requests.post(url, json=data)
print(response.json())
```
### Mode 2: Separate Launch Mode
This mode is ideal for multi-node deployments where workers run on different machines.
#### Step 1: Launch Workers
On each worker node:
```bash
# Worker node 1
python -m sglang.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--host 0.0.0.0 \
--port 8000
# Worker node 2
python -m sglang.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--host 0.0.0.0 \
--port 8001
```
#### Step 2: Launch Router
On the router node:
```bash
python -m sglang_router.launch_router \
--worker-urls http://worker1:8000 http://worker2:8001 \
--host 0.0.0.0 \
--port 30000 \
--policy cache_aware # or random, round_robin, power_of_two
```
### Mode 3: Prefill-Decode Disaggregation
This advanced mode separates prefill and decode operations for optimized performance:
```bash
python -m sglang_router.launch_router \
--pd-disaggregation \
--prefill http://prefill1:8000 9000 \
--prefill http://prefill2:8001 9001 \
--decode http://decode1:8002 \
--decode http://decode2:8003 \
--prefill-policy cache_aware \
--decode-policy round_robin
```
#### Understanding --prefill Arguments
The `--prefill` flag accepts URLs with optional bootstrap ports:
- `--prefill http://server:8000` - No bootstrap port
- `--prefill http://server:8000 9000` - Bootstrap port 9000
- `--prefill http://server:8000 none` - Explicitly no bootstrap port
#### Policy Inheritance in PD Mode
The router intelligently handles policy configuration for prefill and decode nodes:
1. **Only `--policy` specified**: Both prefill and decode nodes use this policy
2. **`--policy` and `--prefill-policy` specified**: Prefill nodes use `--prefill-policy`, decode nodes use `--policy`
3. **`--policy` and `--decode-policy` specified**: Prefill nodes use `--policy`, decode nodes use `--decode-policy`
4. **All three specified**: Prefill nodes use `--prefill-policy`, decode nodes use `--decode-policy` (main `--policy` is ignored)
Example with mixed policies:
```bash
python -m sglang_router.launch_router \
--pd-disaggregation \
    --prefill http://prefill1:8000 \
    --prefill http://prefill2:8000 \
    --decode http://decode1:8001 \
    --decode http://decode2:8001 \
--policy round_robin \
--prefill-policy cache_aware # Prefill uses cache_aware and decode uses round_robin from --policy
```
#### PD Mode with Service Discovery
For Kubernetes deployments with separate prefill and decode server pools:
```bash
python -m sglang_router.launch_router \
--pd-disaggregation \
--service-discovery \
--prefill-selector app=prefill-server tier=gpu \
--decode-selector app=decode-server tier=cpu \
--service-discovery-namespace production \
--prefill-policy cache_aware \
--decode-policy round_robin
```
## Dynamic Scaling
The router supports runtime scaling through REST APIs:
### Adding Workers
```bash
# Launch a new worker
python -m sglang.launch_server \
--model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 30001
# Add it to the router
curl -X POST "http://localhost:30000/add_worker?url=http://127.0.0.1:30001"
```
### Removing Workers
```bash
curl -X POST "http://localhost:30000/remove_worker?url=http://127.0.0.1:30001"
```
**Note**: When using cache-aware routing, removed workers are cleanly evicted from the routing tree and request queues.
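The same endpoints can also be called programmatically; a minimal sketch using `requests` (URLs are illustrative):
```python
import requests

router_url = "http://localhost:30000"

# Register a newly launched worker with the router.
requests.post(f"{router_url}/add_worker", params={"url": "http://127.0.0.1:30001"})

# Later, remove it from the rotation.
requests.post(f"{router_url}/remove_worker", params={"url": "http://127.0.0.1:30001"})
```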
## Fault Tolerance
The router includes comprehensive fault tolerance mechanisms:
### Retry Configuration
```bash
python -m sglang_router.launch_router \
--worker-urls http://worker1:8000 http://worker2:8001 \
--retry-max-retries 3 \
--retry-initial-backoff-ms 100 \
--retry-max-backoff-ms 10000 \
--retry-backoff-multiplier 2.0 \
--retry-jitter-factor 0.1
```
### Circuit Breaker
Protects against cascading failures:
```bash
python -m sglang_router.launch_router \
--worker-urls http://worker1:8000 http://worker2:8001 \
--cb-failure-threshold 5 \
--cb-success-threshold 2 \
--cb-timeout-duration-secs 30 \
--cb-window-duration-secs 60
```
**Behavior**:
- Worker is marked unhealthy after `cb-failure-threshold` consecutive failures
- Returns to service after `cb-success-threshold` successful health checks
- Circuit breaker can be disabled with `--disable-circuit-breaker`
## Routing Policies
The router supports multiple routing strategies:
### 1. Random Routing
Distributes requests randomly across workers.
```bash
--policy random
```
### 2. Round-Robin Routing
Cycles through workers in order.
```bash
--policy round_robin
```
### 3. Power of Two Choices
Samples two workers and routes to the less loaded one.
```bash
--policy power_of_two
```
### 4. Cache-Aware Load Balancing (Default)
The most sophisticated policy that combines cache optimization with load balancing:
```bash
--policy cache_aware \
--cache-threshold 0.5 \
--balance-abs-threshold 32 \
--balance-rel-threshold 1.0001
```
#### How It Works
1. **Load Assessment**: Checks if the system is balanced
   - Imbalanced if: `(max_load - min_load) > balance_abs_threshold` AND `max_load > balance_rel_threshold * min_load` (see the sketch after this list)
2. **Routing Decision**:
- **Balanced System**: Uses cache-aware routing
- Routes to worker with highest prefix match if match > `cache_threshold`
- Otherwise routes to worker with most available cache capacity
- **Imbalanced System**: Uses shortest queue routing to the least busy worker
3. **Cache Management**:
- Maintains approximate radix trees per worker
- Periodically evicts LRU entries based on `--eviction-interval` and `--max-tree-size`
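The load-balance check from step 1 can be summarized with the small sketch below (variable names and default thresholds mirror the table in the configuration reference; this is an illustration, not the router's actual implementation):
```python
def is_imbalanced(
    max_load: int,
    min_load: int,
    balance_abs_threshold: int = 32,
    balance_rel_threshold: float = 1.0001,
) -> bool:
    # Imbalanced only if BOTH the absolute and the relative load gaps exceed their thresholds.
    abs_gap_exceeded = (max_load - min_load) > balance_abs_threshold
    rel_gap_exceeded = max_load > balance_rel_threshold * min_load
    return abs_gap_exceeded and rel_gap_exceeded


print(is_imbalanced(100, 10))  # True  -> shortest-queue routing is used
print(is_imbalanced(40, 30))   # False -> cache-aware routing is used
```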
### Data Parallelism Aware Routing
Enables fine-grained control over data parallel replicas:
```bash
--dp-aware \
--api-key your_api_key # Required for worker authentication
```
This mode coordinates with SGLang's DP controller for optimized request distribution across data parallel ranks.
## Configuration Reference
### Core Settings
| Parameter | Type | Default | Description |
|-----------------------------|------|-------------|-----------------------------------------------------------------|
| `--host` | str | 127.0.0.1 | Router server host address |
| `--port` | int | 30000 | Router server port |
| `--worker-urls` | list | [] | Worker URLs for separate launch mode |
| `--policy` | str | cache_aware | Routing policy (random, round_robin, cache_aware, power_of_two) |
| `--max-concurrent-requests` | int | 64 | Maximum concurrent requests (rate limiting) |
| `--request-timeout-secs` | int | 600 | Request timeout in seconds |
| `--max-payload-size` | int | 256MB | Maximum request payload size |
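The core settings are plain CLI flags on `launch_router`; for example, a sketch with illustrative values:
```bash
python -m sglang_router.launch_router \
    --worker-urls http://worker1:8000 http://worker2:8001 \
    --policy cache_aware \
    --max-concurrent-requests 128 \
    --request-timeout-secs 600
```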
### Cache-Aware Routing Parameters
| Parameter | Type | Default | Description |
|---------------------------|-------|----------|--------------------------------------------------------|
| `--cache-threshold` | float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
| `--balance-abs-threshold` | int | 32 | Absolute load difference threshold |
| `--balance-rel-threshold` | float | 1.0001 | Relative load ratio threshold |
| `--eviction-interval` | int | 60 | Seconds between cache eviction cycles |
| `--max-tree-size` | int | 16777216 | Maximum nodes in routing tree |
### Fault Tolerance Parameters
| Parameter | Type | Default | Description |
|------------------------------|-------|---------|---------------------------------------|
| `--retry-max-retries` | int | 3 | Maximum retry attempts per request |
| `--retry-initial-backoff-ms` | int | 100 | Initial retry backoff in milliseconds |
| `--retry-max-backoff-ms` | int | 10000 | Maximum retry backoff in milliseconds |
| `--retry-backoff-multiplier` | float | 2.0 | Backoff multiplier between retries |
| `--retry-jitter-factor` | float | 0.1 | Random jitter factor for retries |
| `--disable-retries` | flag | False | Disable retry mechanism |
| `--cb-failure-threshold` | int | 5 | Failures before circuit opens |
| `--cb-success-threshold` | int | 2 | Successes to close circuit |
| `--cb-timeout-duration-secs` | int | 30 | Circuit breaker timeout duration |
| `--cb-window-duration-secs` | int | 60 | Circuit breaker window duration |
| `--disable-circuit-breaker` | flag | False | Disable circuit breaker |
### Prefill-Decode Disaggregation Parameters
| Parameter | Type | Default | Description |
|-----------------------------------|------|---------|-------------------------------------------------------|
| `--pd-disaggregation` | flag | False | Enable PD disaggregated mode |
| `--prefill` | list | [] | Prefill server URLs with optional bootstrap ports |
| `--decode` | list | [] | Decode server URLs |
| `--prefill-policy` | str | None | Routing policy for prefill nodes (overrides --policy) |
| `--decode-policy` | str | None | Routing policy for decode nodes (overrides --policy) |
| `--worker-startup-timeout-secs` | int | 300 | Timeout for worker startup |
| `--worker-startup-check-interval` | int | 10 | Interval between startup checks |
### Kubernetes Integration
| Parameter | Type | Default | Description |
|---------------------------------|------|--------------------------|------------------------------------------------------|
| `--service-discovery` | flag | False | Enable Kubernetes service discovery |
| `--selector` | list | [] | Label selector for workers (key1=value1 key2=value2) |
| `--prefill-selector` | list | [] | Label selector for prefill servers in PD mode |
| `--decode-selector` | list | [] | Label selector for decode servers in PD mode |
| `--service-discovery-port` | int | 80 | Port for discovered pods |
| `--service-discovery-namespace` | str | None | Kubernetes namespace to watch |
| `--bootstrap-port-annotation` | str | sglang.ai/bootstrap-port | Annotation for bootstrap ports |
### Observability
| Parameter | Type | Default | Description |
|------------------------|------|-----------|-------------------------------------------------------|
| `--prometheus-port` | int | 29000 | Prometheus metrics port |
| `--prometheus-host` | str | 127.0.0.1 | Prometheus metrics host |
| `--log-dir` | str | None | Directory for log files |
| `--log-level` | str | info | Logging level (debug, info, warning, error, critical) |
| `--request-id-headers` | list | None | Custom headers for request tracing |
### CORS Configuration
| Parameter | Type | Default | Description |
|--------------------------|------|---------|----------------------|
| `--cors-allowed-origins` | list | [] | Allowed CORS origins |
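If the router is accessed from a browser-based frontend, allowed origins can be passed on the command line; a sketch with an illustrative origin:
```bash
python -m sglang_router.launch_router \
    --worker-urls http://worker1:8000 \
    --cors-allowed-origins http://localhost:3000
```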
## Advanced Features
### Kubernetes Service Discovery
Automatically discover and manage workers in Kubernetes:
#### Standard Mode
```bash
python -m sglang_router.launch_router \
--service-discovery \
--selector app=sglang-worker env=prod \
--service-discovery-namespace production \
--service-discovery-port 8000
```
#### Prefill-Decode Disaggregation Mode
```bash
python -m sglang_router.launch_router \
--pd-disaggregation \
--service-discovery \
--prefill-selector app=prefill-server env=prod \
--decode-selector app=decode-server env=prod \
--service-discovery-namespace production
```
**Note**: The `--bootstrap-port-annotation` (default: `sglang.ai/bootstrap-port`) is used to discover bootstrap ports for prefill servers in PD mode. Prefill pods should have this annotation set to their bootstrap port value.
### Prometheus Metrics
Expose metrics for monitoring:
```bash
python -m sglang_router.launch_router \
--worker-urls http://worker1:8000 http://worker2:8001 \
--prometheus-port 29000 \
--prometheus-host 0.0.0.0
```
Metrics are available at `http://localhost:29000/metrics`.
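You can then inspect or scrape the endpoint directly, for example:
```bash
curl http://localhost:29000/metrics
```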
### Request Tracing
Enable request ID tracking:
```bash
python -m sglang_router.launch_router \
--worker-urls http://worker1:8000 http://worker2:8001 \
--request-id-headers x-request-id x-trace-id
```
## Troubleshooting
### Common Issues
1. **Workers not connecting**: Ensure workers are fully initialized before starting the router. Use `--worker-startup-timeout-secs` to increase wait time.
2. **High latency**: Check if cache-aware routing is causing imbalance. Try adjusting `--balance-abs-threshold` and `--balance-rel-threshold`.
3. **Memory growth**: Reduce `--max-tree-size` or decrease `--eviction-interval` for more aggressive cache cleanup.
4. **Circuit breaker triggering frequently**: Increase `--cb-failure-threshold` or extend `--cb-window-duration-secs`.
### Debug Mode
Enable detailed logging:
```bash
python -m sglang_router.launch_router \
--worker-urls http://worker1:8000 http://worker2:8001 \
--log-level debug \
--log-dir ./router_logs
```
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reasoning Parser\n",
"\n",
"SGLang supports parsing reasoning content out from \"normal\" content for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).\n",
"\n",
"## Supported Models & Parsers\n",
"\n",
"| Model | Reasoning tags | Parser | Notes |\n",
"|---------|-----------------------------|------------------|-------|\n",
"| [DeepSeek‑R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `<think>` … `</think>` | `deepseek-r1` | Supports all variants (R1, R1-0528, R1-Distill) |\n",
"| [DeepSeek‑V3.1](https://huggingface.co/deepseek-ai/DeepSeek-V3.1) | `<think>` … `</think>` | `deepseek-v3` | Supports `thinking` parameter |\n",
"| [Standard Qwen3 models](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `<think>` … `</think>` | `qwen3` | Supports `enable_thinking` parameter |\n",
"| [Qwen3-Thinking models](https://huggingface.co/Qwen/Qwen3-235B-A22B-Thinking-2507) | `<think>` … `</think>` | `qwen3` or `qwen3-thinking` | Always generates thinking content |\n",
"| [Kimi models](https://huggingface.co/moonshotai/models) | `◁think▷` … `◁/think▷` | `kimi` | Uses special thinking delimiters |\n",
"| [GPT OSS](https://huggingface.co/openai/gpt-oss-120b) | `<\\|channel\\|>analysis<\\|message\\|>` … `<\\|end\\|>` | `gpt-oss` | N/A |\n",
"### Model-Specific Behaviors\n",
"\n",
"**DeepSeek-R1 Family:**\n",
"- DeepSeek-R1: No `<think>` start tag, jumps directly to thinking content\n",
"- DeepSeek-R1-0528: Generates both `<think>` start and `</think>` end tags\n",
"- Both are handled by the same `deepseek-r1` parser\n",
"\n",
"**DeepSeek-V3 Family:**\n",
"- DeepSeek-V3.1: Hybrid model supporting both thinking and non-thinking modes, use the `deepseek-v3` parser and `thinking` parameter (NOTE: not `enable_thinking`)\n",
"\n",
"**Qwen3 Family:**\n",
"- Standard Qwen3 (e.g., Qwen3-2507): Use `qwen3` parser, supports `enable_thinking` in chat templates\n",
"- Qwen3-Thinking (e.g., Qwen3-235B-A22B-Thinking-2507): Use `qwen3` or `qwen3-thinking` parser, always thinks\n",
"\n",
"**Kimi:**\n",
"- Kimi: Uses special `◁think▷` and `◁/think▷` tags\n",
"\n",
"**GPT OSS:**\n",
"- GPT OSS: Uses special `<|channel|>analysis<|message|>` and `<|end|>` tags"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"### Launching the Server"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Specify the `--reasoning-parser` option."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from openai import OpenAI\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `--reasoning-parser` defines the parser used to interpret responses."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### OpenAI Compatible API\n",
"\n",
"Using the OpenAI compatible API, the contract follows the [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model) established with the release of DeepSeek-R1:\n",
"\n",
"- `reasoning_content`: The content of the CoT.\n",
"- `content`: The content of the final answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n",
"model_name = client.models.list().data[0].id\n",
"\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"What is 1+3?\",\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Non-Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.6,\n",
" top_p=0.95,\n",
" stream=False, # Non-streaming\n",
" extra_body={\"separate_reasoning\": True},\n",
")\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(response_non_stream.choices[0].message.reasoning_content)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(response_non_stream.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.6,\n",
" top_p=0.95,\n",
" stream=True, # Non-streaming\n",
" extra_body={\"separate_reasoning\": True},\n",
")\n",
"\n",
"reasoning_content = \"\"\n",
"content = \"\"\n",
"for chunk in response_stream:\n",
" if chunk.choices[0].delta.content:\n",
" content += chunk.choices[0].delta.content\n",
" if chunk.choices[0].delta.reasoning_content:\n",
" reasoning_content += chunk.choices[0].delta.reasoning_content\n",
"\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(reasoning_content)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, you can buffer the reasoning content to the last reasoning chunk (or the first chunk after the reasoning content)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.6,\n",
" top_p=0.95,\n",
" stream=True, # Non-streaming\n",
" extra_body={\"separate_reasoning\": True, \"stream_reasoning\": False},\n",
")\n",
"\n",
"reasoning_content = \"\"\n",
"content = \"\"\n",
"for chunk in response_stream:\n",
" if chunk.choices[0].delta.content:\n",
" content += chunk.choices[0].delta.content\n",
" if chunk.choices[0].delta.reasoning_content:\n",
" reasoning_content += chunk.choices[0].delta.reasoning_content\n",
"\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(reasoning_content)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reasoning separation is enable by default when specify . \n",
"**To disable it, set the `separate_reasoning` option to `False` in request.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.6,\n",
" top_p=0.95,\n",
" stream=False, # Non-streaming\n",
" extra_body={\"separate_reasoning\": False},\n",
")\n",
"\n",
"print_highlight(\"==== Original Output ====\")\n",
"print_highlight(response_non_stream.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SGLang Native API "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"input = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"gen_url = f\"http://localhost:{port}/generate\"\n",
"gen_data = {\n",
" \"text\": input,\n",
" \"sampling_params\": {\n",
" \"skip_special_tokens\": False,\n",
" \"max_new_tokens\": 1024,\n",
" \"temperature\": 0.6,\n",
" \"top_p\": 0.95,\n",
" },\n",
"}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"\n",
"print_highlight(\"==== Original Output ====\")\n",
"print_highlight(gen_response)\n",
"\n",
"parse_url = f\"http://localhost:{port}/separate_reasoning\"\n",
"separate_reasoning_data = {\n",
" \"text\": gen_response,\n",
" \"reasoning_parser\": \"deepseek-r1\",\n",
"}\n",
"separate_reasoning_response_json = requests.post(\n",
" parse_url, json=separate_reasoning_data\n",
").json()\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(separate_reasoning_response_json[\"reasoning_text\"])\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(separate_reasoning_response_json[\"text\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.parser.reasoning_parser import ReasoningParser\n",
"from sglang.utils import print_highlight\n",
"\n",
"llm = sgl.Engine(model_path=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"input = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"sampling_params = {\n",
" \"max_new_tokens\": 1024,\n",
" \"skip_special_tokens\": False,\n",
" \"temperature\": 0.6,\n",
" \"top_p\": 0.95,\n",
"}\n",
"result = llm.generate(prompt=input, sampling_params=sampling_params)\n",
"\n",
"generated_text = result[\"text\"] # Assume there is only one prompt\n",
"\n",
"print_highlight(\"==== Original Output ====\")\n",
"print_highlight(generated_text)\n",
"\n",
"parser = ReasoningParser(\"deepseek-r1\")\n",
"reasoning_text, text = parser.parse_non_stream(generated_text)\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(reasoning_text)\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Supporting New Reasoning Model Schemas\n",
"\n",
"For future reasoning models, you can implement the reasoning parser as a subclass of `BaseReasoningFormatDetector` in `python/sglang/srt/reasoning_parser.py` and specify the reasoning parser for new reasoning model schemas accordingly."
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# Server Arguments
This page provides a list of server arguments used in the command line to configure the behavior
and performance of the language model server during deployment. These arguments enable users to
customize key aspects of the server, including model selection, parallelism policies,
memory management, and optimization techniques.
You can list all arguments with `python3 -m sglang.launch_server --help`.
## Common launch commands
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2
```
- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend [SGLang Router](../advanced_features/router.md) for data parallelism.
```bash
python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2
```
- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7
```
- See [hyperparameter tuning](hyperparameter_tuning.md) on tuning hyperparameters for better performance.
- For docker and Kubernetes runs, you need to set up shared memory which is used for communication between processes. See `--shm-size` for docker and `/dev/shm` size update for Kubernetes manifests.
- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size.
```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable `torch.compile` acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. By default, the cache path is located at `/tmp/torchinductor_root`, you can customize it using environment variable `TORCHINDUCTOR_CACHE_DIR`. For more details, please refer to [PyTorch official documentation](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) and [Enabling cache for torch.compile](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile).
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports other [quantization strategies (INT8/FP8)](https://github.com/sgl-project/sglang/blob/v0.3.6/python/sglang/srt/server_args.py#L671) as well.
- To enable fp8 weight quantization, add `--quantization fp8` on an fp16 checkpoint or directly load an fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
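For example, a sketch that combines fp8 weight quantization and fp8 KV cache quantization on an fp16 checkpoint (the model path is illustrative):
```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --quantization fp8 --kv-cache-dtype fp8_e5m2
```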
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](../references/custom_chat_template.md).
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; you can then use the following commands. If you encounter a deadlock, try adding `--disable-cuda-graph`.
```bash
# Node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0
# Node 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1
```
Please consult the documentation below and [server_args.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py) to learn more about the arguments you may provide when launching a server.
## Model and tokenizer
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--model-path` | The path of the model weights. This can be a local folder or a Hugging Face repo ID. | None |
| `--tokenizer-path` | The path of the tokenizer. | None |
| `--tokenizer-mode` | Tokenizer mode. 'auto' will use the fast tokenizer if available, and 'slow' will always use the slow tokenizer. | auto |
| `--skip-tokenizer-init` | If set, skip init tokenizer and pass input_ids in generate request. | False |
| `--load-format` | The format of the model weights to load. 'auto' will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. 'pt' will load the weights in the pytorch bin format. 'safetensors' will load the weights in the safetensors format. 'npcache' will load the weights in pytorch format and store a numpy cache to speed up the loading. 'dummy' will initialize the weights with random values, which is mainly for profiling. 'gguf' will load the weights in the gguf format. 'bitsandbytes' will load the weights using bitsandbytes quantization. 'layered' loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller. | auto |
| `--trust-remote-code` | Whether or not to allow for custom models defined on the Hub in their own modeling files. | False |
| `--context-length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). | None |
| `--is-embedding` | Whether to use a CausalLM as an embedding model. | False |
| `--enable-multimodal` | Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen. | None |
| `--revision` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. | None |
| `--model-impl` | Which implementation of the model to use. 'auto' will try to use the SGLang implementation if it exists and fall back to the Transformers implementation if no SGLang implementation is available. 'sglang' will use the SGLang model implementation. 'transformers' will use the Transformers model implementation. | auto |
## HTTP server
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--host` | The host address for the server. | 127.0.0.1 |
| `--port` | The port number for the server. | 30000 |
| `--skip-server-warmup` | If set, skip the server warmup process. | False |
| `--warmups` | Warmup configurations. | None |
| `--nccl-port` | The port for NCCL initialization. | None |
## Quantization and data type
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto |
| `--quantization` | The quantization method. | None |
| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use the model data type. 'fp8_e5m2' and 'fp8_e4m3' are supported for CUDA 11.8+. | auto |
## Memory and scheduling
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--mem-fraction-static` | The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors. | None |
| `--max-running-requests` | The maximum number of running requests. | None |
| `--max-total-tokens` | The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes. | None |
| `--chunked-prefill-size` | The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill. | None |
| `--max-prefill-tokens` | The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length. | 16384 |
| `--schedule-policy` | The scheduling policy of the requests. | fcfs |
| `--schedule-conservativeness` | How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently. | 1.0 |
| `--cpu-offload-gb` | How many GBs of RAM to reserve for CPU offloading. | 0 |
| `--page-size` | The number of tokens in a page. | 1 |
## Runtime options
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
| `--tp-size` | The tensor parallelism size. | 1 |
| `--pp-size` | The pipeline parallelism size. | 1 |
| `--max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | None |
| `--stream-interval` | The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher. | 1 |
| `--stream-output` | Whether to output as a sequence of disjoint segments. | False |
| `--random-seed` | The random seed. | None |
| `--constrained-json-whitespace-pattern` | Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model to generate consecutive whitespaces, set the pattern to `[\n\t ]*`. | None |
| `--watchdog-timeout` | Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging. | 300 |
| `--dist-timeout` | Set timeout for torch.distributed initialization. | None |
| `--download-dir` | Model download directory for huggingface. | None |
| `--base-gpu-id` | The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine. | 0 |
| `--gpu-id-step` | The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,.... | 1 |
| `--sleep-on-idle` | Reduce CPU usage when sglang is idle. | False |
## Logging
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--log-level` | The logging level of all loggers. | info |
| `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
| `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
| `--show-time-cost` | Show time cost of custom marks. | False |
| `--enable-metrics` | Enable log prometheus metrics. | False |
| `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
| `--bucket-inter-token-latency` | The buckets of inter-token latency, specified as a list of floats. | None |
| `--bucket-e2e-request-latency` | The buckets of end-to-end request latency, specified as a list of floats. | None |
| `--collect-tokens-histogram` | Collect prompt/generation tokens histogram. | False |
| `--kv-events-config` | Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used. | None |
| `--decode-log-interval` | The log interval of decode batch. | 40 |
| `--enable-request-time-stats-logging` | Enable per request time stats logging. | False |
## API related
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--api-key` | Set API key of the server. It is also used in the OpenAI API compatible server. | None |
| `--served-model-name` | Override the model name returned by the v1/models endpoint in OpenAI API server. | None |
| `--chat-template` | The built-in chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server. | None |
| `--completion-template` | The built-in completion template name or the path of the completion template file. This is only used for the OpenAI-compatible API server, and currently only for code completion. | None |
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models. Supported parsers include `deepseek-r1`, `deepseek-v3`, `qwen3`, `qwen3-thinking`, `kimi`, and `gpt-oss` (see the Reasoning Parser documentation). | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'. | None |
## Data parallelism
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dp-size` | The data parallelism size. | 1 |
| `--load-balance-method` | The load balancing strategy for data parallelism. Options include: 'round_robin', 'minimum_tokens'. The Minimum Token algorithm can only be used when DP attention is applied. This algorithm performs load balancing based on the real-time token load of the DP workers. | round_robin |
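
For instance, two data-parallel replicas with the default round-robin load balancing could be requested as follows (values are illustrative):

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --dp-size 2 --load-balance-method round_robin
```
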
## Multi-node distributed serving
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dist-init-addr` | The host address for initializing distributed backend (e.g., `192.168.0.2:25000`). | None |
| `--nnodes` | The number of nodes. | 1 |
| `--node-rank` | The node rank. | 0 |
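
A minimal two-node sketch passes the same `--dist-init-addr` to every node and distinguishes the nodes by `--node-rank` (the address reuses the example from the table):

```bash
# On node 0
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --dist-init-addr 192.168.0.2:25000 --nnodes 2 --node-rank 0
# On node 1
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --dist-init-addr 192.168.0.2:25000 --nnodes 2 --node-rank 1
```
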
## Model override args in JSON
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--json-model-override-args` | A dictionary in JSON string format used to override default model configurations. | {} |
| `--preferred-sampling-params` | json-formatted sampling settings that will be returned in /get_model_info. | None |
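
For example, to override a single field of the Hugging Face model config at launch (the key shown is purely illustrative; any valid config field can be overridden this way):

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --json-model-override-args '{"max_position_embeddings": 32768}'
```
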
## LoRA
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>`, `<NAME>=<PATH>`, or a JSON object with schema `{"lora_name":str,"lora_path":str,"pinned":bool}`. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
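
A minimal LoRA-serving sketch, assuming a hypothetical adapter checkpoint at `/path/to/my_adapter`:

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-lora --lora-paths my_adapter=/path/to/my_adapter \
    --max-loras-per-batch 4 --lora-backend triton
```
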
## Kernel backend
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--attention-backend` | Choose the kernels for attention layers. | None |
| `--prefill-attention-backend` | (Experimental) This argument specifies the backend for prefill attention computation. Note that this argument has priority over `attention_backend`. | None |
| `--decode-attention-backend` | (Experimental) This argument specifies the backend for decode attention computation. Note that this argument has priority over `attention_backend`. | None |
| `--sampling-backend` | Choose the kernels for sampling layers. | None |
| `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
| `--mm-attention-backend` | Set multimodal attention backend. | None |
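
For example, to pin the attention and grammar backends explicitly (backend names must be ones supported by your installation; `triton` and `xgrammar` are referenced elsewhere in these docs):

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend triton --grammar-backend xgrammar
```
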
## Speculative decoding
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--speculative-algorithm` | Speculative algorithm. | None |
| `--speculative-draft-model-path` | The path of the draft model weights. This can be a local folder or a Hugging Face repo ID. | None |
| `--speculative-num-steps` | The number of steps sampled from draft model in Speculative Decoding. | None |
| `--speculative-eagle-topk` | The number of tokens sampled from the draft model in eagle2 each step. | None |
| `--speculative-num-draft-tokens` | The number of tokens sampled from the draft model in Speculative Decoding. | None |
| `--speculative-accept-threshold-single` | Accept a draft token if its probability in the target model is greater than this threshold. | 1.0 |
| `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | 1.0 |
| `--speculative-token-map` | The path of the draft model's small vocab table. | None |
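
These flags are demonstrated in the Speculative Decoding tutorial; a representative EAGLE launch looks like this:

```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
    --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B \
    --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16
```
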
## Expert parallelism
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
| `--deepep-mode` | Select the mode when DeepEP MoE is enabled; it can be `normal`, `low_latency`, or `auto`. The default `auto` uses `low_latency` for decode batches and `normal` for prefill batches. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
| `--init-expert-location` | Initial location of EP experts. | trivial |
| `--enable-eplb` | Enable EPLB algorithm. | False |
| `--eplb-algorithm` | Chosen EPLB algorithm. | auto |
| `--eplb-rebalance-num-iterations` | Number of iterations to automatically trigger an EPLB re-balance. | 1000 |
| `--eplb-rebalance-layers-per-chunk` | Number of layers to rebalance per forward pass. | None |
| `--expert-distribution-recorder-mode` | Mode of expert distribution recorder. | None |
| `--expert-distribution-recorder-buffer-size` | Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer. | None |
| `--enable-expert-distribution-metrics` | Enable logging metrics for expert balancedness. | False |
| `--deepep-config` | Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path. | None |
| `--moe-dense-tp-size` | TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports. | None |
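
A rough expert-parallel sketch, assuming a MoE model and a cluster where the DeepEP all-to-all backend is available (the model path and sizes are illustrative):

```bash
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --trust-remote-code \
    --ep-size 8 --moe-a2a-backend deepep --deepep-mode auto --enable-eplb
```
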
## Hierarchical cache
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-hierarchical-cache` | Enable hierarchical cache. | False |
| `--hicache-ratio` | The ratio of the size of host KV cache memory pool to the size of device pool. | 2.0 |
| `--hicache-size` | The size of the hierarchical cache. | 0 |
| `--hicache-write-policy` | The write policy for hierarchical cache. | write_through |
| `--hicache-io-backend` | The IO backend for hierarchical cache. | |
| `--hicache-storage-backend` | The storage backend for hierarchical cache. | None |
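
For example, enabling the hierarchical KV cache with the default sizing and write policy spelled out explicitly:

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-hierarchical-cache --hicache-ratio 2.0 --hicache-write-policy write_through
```
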
## Optimization/debug options
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--disable-radix-cache` | Disable RadixAttention for prefix caching. | False |
| `--cuda-graph-max-bs` | Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value. | None |
| `--cuda-graph-bs` | Set the list of batch sizes for cuda graph. | None |
| `--disable-cuda-graph` | Disable cuda graph. | False |
| `--disable-cuda-graph-padding` | Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed. | False |
| `--enable-profile-cuda-graph` | Enable profiling of cuda graph capture. | False |
| `--enable-nccl-nvls` | Enable NCCL NVLS for prefill heavy requests when available. | False |
| `--enable-symm-mem` | Enable NCCL symmetric memory for fast collectives. | False |
| `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. | False |
| `--disable-outlines-disk-cache` | Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency. | False |
| `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False |
| `--enable-mscclpp` | Enable using mscclpp for small-message all-reduce kernels; fall back to NCCL otherwise. | False |
| `--disable-overlap-schedule` | Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker. | False |
| `--enable-mixed-chunk` | Enable mixing prefill and decode in a batch when using chunked prefill. | False |
| `--enable-dp-attention` | Enable data parallelism for attention and tensor parallelism for FFN. The DP size should be equal to the TP size. Currently, DeepSeek-V2 and Qwen 2/3 MoE models are supported. | False |
| `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
| `--enable-two-batch-overlap` | Enable two micro-batches to overlap. | False |
| `--tbo-token-distribution-threshold` | The token-distribution threshold between the two micro-batches; it determines whether to use two-batch overlap or two-chunk overlap. Set to 0 to disable two-chunk overlap. | 0.48 |
| `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
| `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
| `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
| `--enable-nan-detection` | Enable the NaN detection for debugging purposes. | False |
| `--enable-p2p-check` | Enable P2P check for GPU access; otherwise, P2P access is allowed by default. | False |
| `--triton-attention-reduce-in-fp32` | Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. This only affects Triton attention kernels. | False |
| `--triton-attention-num-kv-splits` | The number of KV splits in the flash-decoding Triton kernel. A larger value works better for longer-context scenarios. | 8 |
| `--num-continuous-decode-steps` | Run multiple continuous decoding steps to reduce scheduling overhead. This can potentially increase throughput but may also increase time-to-first-token latency. The default of 1 runs only one decoding step at a time. | 1 |
| `--delete-ckpt-after-loading` | Delete the model checkpoint after loading the model. | False |
| `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | False |
| `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | False |
| `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | False |
| `--flashinfer-mla-disable-ragged` | Disable ragged processing in Flashinfer MLA. | False |
| `--disable-shared-experts-fusion` | Disable shared experts fusion. | False |
| `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False |
| `--disable-fast-image-processor` | Disable fast image processor. | False |
| `--enable-return-hidden-states` | Enable returning hidden states. | False |
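
As one illustrative combination, a throughput-oriented launch might enable `torch.compile` and bound the CUDA graph batch size (tune these values for your own workload):

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-torch-compile --torch-compile-max-bs 16 --cuda-graph-max-bs 8
```
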
## Debug tensor dumps
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None |
| `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None |
| `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False |
| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False |
## PD disaggregation
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--disaggregation-mode` | PD disaggregation mode: "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only). | null |
| `--disaggregation-transfer-backend` | The transfer backend for PD disaggregation. | mooncake |
| `--disaggregation-bootstrap-port` | The bootstrap port for PD disaggregation. | 8998 |
| `--disaggregation-decode-tp` | The decode TP for PD disaggregation. | None |
| `--disaggregation-decode-dp` | The decode DP for PD disaggregation. | None |
| `--disaggregation-prefill-pp` | The prefill PP for PD disaggregation. | 1 |
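
A minimal sketch of a prefill/decode pair, assuming the default mooncake transfer backend and the default bootstrap port on both nodes (a complete deployment also needs something in front of the two roles to route requests):

```bash
# Prefill node
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --disaggregation-mode prefill --disaggregation-bootstrap-port 8998
# Decode node
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --disaggregation-mode decode --disaggregation-bootstrap-port 8998
```
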
## Model weight update
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--custom-weight-loader` | Custom weight loader paths. | None |
| `--weight-loader-disable-mmap` | Disable mmap for weight loader. | False |
## PD-Multiplexing
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-pdmux` | Enable PD-Multiplexing. | False |
| `--sm-group-num` | Number of SM groups for PD-Multiplexing. | 3 |
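
For example (values are illustrative):

```bash
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --enable-pdmux --sm-group-num 3
```
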
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Speculative Decoding\n",
"\n",
"SGLang now provides an EAGLE-based (EAGLE-2/EAGLE-3) speculative decoding option. Our implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n",
"\n",
"### Performance Highlights\n",
"\n",
"Please see below for the huge improvements on throughput for LLaMA-Instruct 3.1 8B tested on MT bench that can be achieved via EAGLE3 decoding.\n",
"For further details please see the [EAGLE3 paper](https://arxiv.org/pdf/2503.01840).\n",
"\n",
"| Method | Throughput (tokens/s) |\n",
"|--------|----------------|\n",
"| SGLang (w/o speculative, 1x H100) | 158.34 tokens/s |\n",
"| SGLang + EAGLE-2 (1x H100) | 244.10 tokens/s |\n",
"| SGLang + EAGLE-3 (1x H100) | 373.25 tokens/s |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EAGLE Decoding\n",
"\n",
"To enable EAGLE speculative decoding the following parameters are relevant:\n",
"* `speculative_draft_model_path`: Specifies draft model. This parameter is required.\n",
"* `speculative_num_steps`: Depth of autoregressive drafting. Increases speculation range but risks rejection cascades. Default is 5.\n",
"* `speculative_eagle_topk`: Branching factor per step. Improves candidate diversity, will lead to higher acceptance rate, but more lead to higher memory/compute consumption. Default is 4.\n",
"* `speculative_num_draft_tokens`: Maximum parallel verification capacity. Allows deeper tree evaluation but will lead to higher GPU memory usage. Default is 8.\n",
"\n",
"These parameters are the same for EAGLE-2 and EAGLE-3.\n",
"\n",
"You can find the best combinations of these parameters with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py).\n",
"\n",
"In the documentation below, we set `--cuda-graph-max-bs` to be a small value for faster engine startup. For your own workloads, please tune the above parameters together with `--cuda-graph-max-bs`, `--max-running-requests`, `--mem-fraction-static` for the best performance. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE-2 decoding\n",
"\n",
"You can enable EAGLE-2 decoding by setting `--speculative-algorithm EAGLE` and choosing an appropriate model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"import openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \\\n",
" --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE-2 Decoding with `torch.compile`\n",
"\n",
"You can also enable `torch.compile` for further optimizations and optionally set `--torch-compile-max-bs`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
" --enable-torch-compile --torch-compile-max-bs 2\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling\n",
"\n",
"By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, checkout [the paper](https://arxiv.org/pdf/arXiv:2502.14856).\n",
"\n",
"In our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency token in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency token by directly downloading these token from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n",
"\n",
"Thanks for the contribution from [Weilin Zhao](https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE-3 Decoding\n",
"\n",
"You can enable EAGLE-3 decoding by setting `--speculative-algorithm EAGLE3` and choosing an appropriate model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algorithm EAGLE3 \\\n",
" --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \\\n",
" --cuda-graph-max-bs 2 --dtype float16\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi Token Prediction\n",
"\n",
"We support [MTP(Multi-Token Prediction)](https://arxiv.org/pdf/2404.19737) in SGLang by using speculative decoding. We use Xiaomi/MiMo-7B-RL model as example here (deepseek mtp usage refer to [deepseek doc](../basic_usage/deepseek.md#multi-token-prediction))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \\\n",
" --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \\\n",
" --mem-fraction 0.5\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"\n",
"url = f\"http://localhost:{port}/v1/chat/completions\"\n",
"\n",
"data = {\n",
" \"model\": \"XiaomiMiMo/MiMo-7B-RL\",\n",
" \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n",
"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
"EAGLE process is as follows:\n",
"\n",
"- Within EAGLE the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence $(f_1, ..., f_k)$ and the token sequence $(t_2, ..., t_{k+1})$. \n",
"- The next token is then sampled from $p_{k+2}=\\text{LMHead}(f_{k+1})$. Afterwards, the two sequences are extended in a tree style—branching out multiple potential continuations, with the branching factor per step controlled by the `speculative_eagle_topk` parameter—to ensure a more coherent connection of context, and are given as input again.\n",
"- EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top `speculative_num_draft_tokens` final nodes as draft tokens.\n",
"- EAGLE-3 removes the feature prediction objective, incorporates low and mid-layer features, and is trained in an on-policy manner.\n",
"\n",
"This enhances drafting accuracy by operating on the features instead of tokens for more regular inputs and passing the tokens from the next timestep additionally to minimize randomness effects from sampling. Furthermore the dynamic adjustment of the draft tree and selection of reranked final nodes increases acceptance rate of draft tokens further. For more details see [EAGLE-2](https://arxiv.org/abs/2406.16858) and [EAGLE-3](https://arxiv.org/abs/2503.01840) paper.\n",
"\n",
"\n",
"For guidance how to train your own EAGLE model please see the [EAGLE repo](https://github.com/SafeAILab/EAGLE/tree/main?tab=readme-ov-file#train)."
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Structured Outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can specify a JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.\n",
"\n",
"SGLang supports three grammar backends:\n",
"\n",
"- [XGrammar](https://github.com/mlc-ai/xgrammar)(default): Supports JSON schema, regular expression, and EBNF constraints.\n",
"- [Outlines](https://github.com/dottxt-ai/outlines): Supports JSON schema and regular expression constraints.\n",
"- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\n",
"\n",
"We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
"\n",
"To use Outlines, simply add `--grammar-backend outlines` when launching the server.\n",
"To use llguidance, add `--grammar-backend llguidance` when launching the server.\n",
"If no backend is specified, XGrammar will be used as the default.\n",
"\n",
"For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI Compatible API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import os\n",
"\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")\n",
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JSON\n",
"\n",
"you can directly define a JSON schema or use [Pydantic](https://docs.pydantic.dev/latest/) to define and validate the response."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Pydantic**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"# Define the schema using Pydantic\n",
"class CapitalInfo(BaseModel):\n",
" name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n",
" population: int = Field(..., description=\"Population of the capital city\")\n",
"\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=128,\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\n",
" \"name\": \"foo\",\n",
" # convert the pydantic model to json schema\n",
" \"schema\": CapitalInfo.model_json_schema(),\n",
" },\n",
" },\n",
")\n",
"\n",
"response_content = response.choices[0].message.content\n",
"# validate the JSON response by the pydantic model\n",
"capital_info = CapitalInfo.model_validate_json(response_content)\n",
"print_highlight(f\"Validated response: {capital_info.model_dump_json()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**JSON Schema Directly**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"json_schema = json.dumps(\n",
" {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
" \"population\": {\"type\": \"integer\"},\n",
" },\n",
" \"required\": [\"name\", \"population\"],\n",
" }\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=128,\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n",
" },\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EBNF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ebnf_grammar = \"\"\"\n",
"root ::= city | description\n",
"city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\n",
"description ::= city \" is \" status\n",
"status ::= \"the capital of \" country\n",
"country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"\n",
"\"\"\"\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Give me the information of the capital of France.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=32,\n",
" extra_body={\"ebnf\": ebnf_grammar},\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=128,\n",
" extra_body={\"regex\": \"(Paris|London)\"},\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tool_get_current_weather = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"state\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
" \" in, e.g. 'CA' which would mean 'California'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"state\", \"unit\"],\n",
" },\n",
" },\n",
"}\n",
"\n",
"tool_get_current_date = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_date\",\n",
" \"description\": \"Get the current date and time for a given timezone\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"timezone\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The timezone to fetch the current date and time for, e.g. 'America/New_York'\",\n",
" }\n",
" },\n",
" \"required\": [\"timezone\"],\n",
" },\n",
" },\n",
"}\n",
"\n",
"schema_get_current_weather = tool_get_current_weather[\"function\"][\"parameters\"]\n",
"schema_get_current_date = tool_get_current_date[\"function\"][\"parameters\"]\n",
"\n",
"\n",
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": f\"\"\"\n",
"# Tool Instructions\n",
"- Always execute python code in messages that you share.\n",
"- When looking for real time information use relevant functions if available else fallback to brave_search\n",
"You have access to the following functions:\n",
"Use the function 'get_current_weather' to: Get the current weather in a given location\n",
"{tool_get_current_weather[\"function\"]}\n",
"Use the function 'get_current_date' to: Get the current date and time for a given timezone\n",
"{tool_get_current_date[\"function\"]}\n",
"If a you choose to call a function ONLY reply in the following format:\n",
"<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}\n",
"where\n",
"start_tag => `<function`\n",
"parameters => a JSON dict with the function argument name as key and function argument value as value.\n",
"end_tag => `</function>`\n",
"Here is an example,\n",
"<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>\n",
"Reminder:\n",
"- Function calls MUST follow the specified format\n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"- Always add your sources when using search results to answer the user query\n",
"You are a helpful assistant.\"\"\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
" },\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=messages,\n",
" response_format={\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" },\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Native API and SGLang Runtime (SRT)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JSON"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Pydantic**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"import json\n",
"from pydantic import BaseModel, Field\n",
"\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"\n",
"\n",
"# Define the schema using Pydantic\n",
"class CapitalInfo(BaseModel):\n",
" name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n",
" population: int = Field(..., description=\"Population of the capital city\")\n",
"\n",
"\n",
"# Make API request\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" }\n",
"]\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 64,\n",
" \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n",
" },\n",
" },\n",
")\n",
"print_highlight(response.json())\n",
"\n",
"\n",
"response_data = json.loads(response.json()[\"text\"])\n",
"# validate the response by the pydantic model\n",
"capital_info = CapitalInfo.model_validate(response_data)\n",
"print_highlight(f\"Validated response: {capital_info.model_dump_json()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**JSON Schema Directly**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"json_schema = json.dumps(\n",
" {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
" \"population\": {\"type\": \"integer\"},\n",
" },\n",
" \"required\": [\"name\", \"population\"],\n",
" }\n",
")\n",
"\n",
"# JSON\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 64,\n",
" \"json_schema\": json_schema,\n",
" },\n",
" },\n",
")\n",
"\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EBNF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Give me the information of the capital of France.\",\n",
" }\n",
"]\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"max_new_tokens\": 128,\n",
" \"temperature\": 0,\n",
" \"n\": 3,\n",
" \"ebnf\": (\n",
" \"root ::= city | description\\n\"\n",
" 'city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\\n'\n",
" 'description ::= city \" is \" status\\n'\n",
" 'status ::= \"the capital of \" country\\n'\n",
" 'country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"'\n",
" ),\n",
" },\n",
" \"stream\": False,\n",
" \"return_logprob\": False,\n",
" },\n",
")\n",
"\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Paris is the capital of\",\n",
" }\n",
"]\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 64,\n",
" \"regex\": \"(France|England)\",\n",
" },\n",
" },\n",
")\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"# generate an answer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"payload = {\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" }\n",
" )\n",
" },\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"\n",
"llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", grammar_backend=\"xgrammar\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JSON"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Pydantic**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"prompts = [\n",
" \"Give me the information of the capital of China in the JSON format.\",\n",
" \"Give me the information of the capital of France in the JSON format.\",\n",
" \"Give me the information of the capital of Ireland in the JSON format.\",\n",
"]\n",
"\n",
"\n",
"# Define the schema using Pydantic\n",
"class CapitalInfo(BaseModel):\n",
" name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n",
" population: int = Field(..., description=\"Population of the capital city\")\n",
"\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.1,\n",
" \"top_p\": 0.95,\n",
" \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n",
"}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\") # validate the output by the pydantic model\n",
" capital_info = CapitalInfo.model_validate_json(output[\"text\"])\n",
" print_highlight(f\"Validated output: {capital_info.model_dump_json()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**JSON Schema Directly**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Give me the information of the capital of China in the JSON format.\",\n",
" \"Give me the information of the capital of France in the JSON format.\",\n",
" \"Give me the information of the capital of Ireland in the JSON format.\",\n",
"]\n",
"\n",
"json_schema = json.dumps(\n",
" {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
" \"population\": {\"type\": \"integer\"},\n",
" },\n",
" \"required\": [\"name\", \"population\"],\n",
" }\n",
")\n",
"\n",
"sampling_params = {\"temperature\": 0.1, \"top_p\": 0.95, \"json_schema\": json_schema}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EBNF\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Give me the information of the capital of France.\",\n",
" \"Give me the information of the capital of Germany.\",\n",
" \"Give me the information of the capital of Italy.\",\n",
"]\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"ebnf\": (\n",
" \"root ::= city | description\\n\"\n",
" 'city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\\n'\n",
" 'description ::= city \" is \" status\\n'\n",
" 'status ::= \"the capital of \" country\\n'\n",
" 'country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"'\n",
" ),\n",
"}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Please provide information about London as a major global city:\",\n",
" \"Please provide information about Paris as a major global city:\",\n",
"]\n",
"\n",
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"regex\": \"(France|England)\"}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"prompts = [text]\n",
"\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" }\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Structured Outputs For Reasoning Models\n",
"\n",
"When working with reasoning models that use special tokens like `<think>...</think>` to denote reasoning sections, you might want to allow free-form text within these sections while still enforcing grammar constraints on the rest of the output.\n",
"\n",
"SGLang provides a feature to disable grammar restrictions within reasoning sections. This is particularly useful for models that need to perform complex reasoning steps before providing a structured output.\n",
"\n",
"To enable this feature, use the `--reasoning-parser` flag which decide the think_end_token, such as `</think>`, when launching the server. You can also specify the reasoning parser using the `--reasoning-parser` flag.\n",
"\n",
"## Supported Models\n",
"\n",
"Currently, SGLang supports the following reasoning models:\n",
"- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags.\n",
"- [QwQ](https://huggingface.co/Qwen/QwQ-32B): The reasoning content is wrapped with `<think>` and `</think>` tags.\n",
"\n",
"\n",
"## Usage\n",
"\n",
"## OpenAI Compatible API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Specify the `--grammar-backend`, `--reasoning-parser` option."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import os\n",
"\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")\n",
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JSON\n",
"\n",
"you can directly define a JSON schema or use [Pydantic](https://docs.pydantic.dev/latest/) to define and validate the response."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Pydantic**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"# Define the schema using Pydantic\n",
"class CapitalInfo(BaseModel):\n",
" name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n",
" population: int = Field(..., description=\"Population of the capital city\")\n",
"\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=2048,\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\n",
" \"name\": \"foo\",\n",
" # convert the pydantic model to json schema\n",
" \"schema\": CapitalInfo.model_json_schema(),\n",
" },\n",
" },\n",
")\n",
"\n",
"print_highlight(\n",
" f\"reasoing_content: {response.choices[0].message.reasoning_content}\\n\\ncontent: {response.choices[0].message.content}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**JSON Schema Directly**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"json_schema = json.dumps(\n",
" {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
" \"population\": {\"type\": \"integer\"},\n",
" },\n",
" \"required\": [\"name\", \"population\"],\n",
" }\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=2048,\n",
" response_format={\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n",
" },\n",
")\n",
"\n",
"print_highlight(\n",
" f\"reasoing_content: {response.choices[0].message.reasoning_content}\\n\\ncontent: {response.choices[0].message.content}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EBNF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ebnf_grammar = \"\"\"\n",
"root ::= city | description\n",
"city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\n",
"description ::= city \" is \" status\n",
"status ::= \"the capital of \" country\n",
"country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"\n",
"\"\"\"\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
" max_tokens=2048,\n",
" extra_body={\"ebnf\": ebnf_grammar},\n",
")\n",
"\n",
"print_highlight(\n",
" f\"reasoing_content: {response.choices[0].message.reasoning_content}\\n\\ncontent: {response.choices[0].message.content}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = client.chat.completions.create(\n",
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=[\n",
" {\"role\": \"assistant\", \"content\": \"What is the capital of France?\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=2048,\n",
" extra_body={\"regex\": \"(Paris|London)\"},\n",
")\n",
"\n",
"print_highlight(\n",
" f\"reasoing_content: {response.choices[0].message.reasoning_content}\\n\\ncontent: {response.choices[0].message.content}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tool_get_current_weather = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"state\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
" \" in, e.g. 'CA' which would mean 'California'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"state\", \"unit\"],\n",
" },\n",
" },\n",
"}\n",
"\n",
"tool_get_current_date = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_date\",\n",
" \"description\": \"Get the current date and time for a given timezone\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"timezone\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The timezone to fetch the current date and time for, e.g. 'America/New_York'\",\n",
" }\n",
" },\n",
" \"required\": [\"timezone\"],\n",
" },\n",
" },\n",
"}\n",
"\n",
"schema_get_current_weather = tool_get_current_weather[\"function\"][\"parameters\"]\n",
"schema_get_current_date = tool_get_current_date[\"function\"][\"parameters\"]\n",
"\n",
"\n",
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": f\"\"\"\n",
"# Tool Instructions\n",
"- Always execute python code in messages that you share.\n",
"- When looking for real time information use relevant functions if available else fallback to brave_search\n",
"You have access to the following functions:\n",
"Use the function 'get_current_weather' to: Get the current weather in a given location\n",
"{tool_get_current_weather[\"function\"]}\n",
"Use the function 'get_current_date' to: Get the current date and time for a given timezone\n",
"{tool_get_current_date[\"function\"]}\n",
"If a you choose to call a function ONLY reply in the following format:\n",
"<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}\n",
"where\n",
"start_tag => `<function`\n",
"parameters => a JSON dict with the function argument name as key and function argument value as value.\n",
"end_tag => `</function>`\n",
"Here is an example,\n",
"<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>\n",
"Reminder:\n",
"- Function calls MUST follow the specified format\n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"- Always add your sources when using search results to answer the user query\n",
"You are a helpful assistant.\"\"\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"You are in New York. Please get the current date and time, and the weather.\",\n",
" },\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" messages=messages,\n",
" response_format={\n",
" \"type\": \"structural_tag\",\n",
" \"max_new_tokens\": 2048,\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" },\n",
")\n",
"\n",
"print_highlight(\n",
" f\"reasoing_content: {response.choices[0].message.reasoning_content}\\n\\ncontent: {response.choices[0].message.content}\"\n",
")"
]
},
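  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The constrained output above wraps each tool call in `<function=...>...</function>` tags. The next cell is a small, illustrative sketch (it is not part of the SGLang API) that extracts these calls from `response.choices[0].message.content` with a regular expression and decodes their JSON arguments with the standard library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "# Minimal sketch: pull out the <function=NAME>{...}</function> spans emitted under the\n",
    "# structural tag constraint and parse their JSON arguments.\n",
    "content = response.choices[0].message.content or \"\"\n",
    "for name, args in re.findall(r\"<function=(\\w+)>(.*?)</function>\", content, re.DOTALL):\n",
    "    print_highlight(f\"{name}: {json.loads(args)}\")"
   ]
  },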
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Native API and SGLang Runtime (SRT)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JSON"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Pydantic**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from pydantic import BaseModel, Field\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"\n",
"\n",
"# Define the schema using Pydantic\n",
"class CapitalInfo(BaseModel):\n",
" name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n",
" population: int = Field(..., description=\"Population of the capital city\")\n",
"\n",
"\n",
"messages = [\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
" },\n",
"]\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"# Make API request\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 2048,\n",
" \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n",
" },\n",
" },\n",
")\n",
"print(response.json())\n",
"\n",
"\n",
"reasoing_content = response.json()[\"text\"].split(\"</think>\")[0]\n",
"content = response.json()[\"text\"].split(\"</think>\")[1]\n",
"print_highlight(f\"reasoing_content: {reasoing_content}\\n\\ncontent: {content}\")"
]
},
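  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, purely illustrative sanity check (not required by the API), the constrained JSON after `</think>` can be validated back into the same Pydantic model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: round-trip the generated JSON through the CapitalInfo schema\n",
    "# defined above; this raises a ValidationError if the output ever drifts from it.\n",
    "capital = CapitalInfo.model_validate_json(content.strip())\n",
    "print_highlight(f\"Validated: name={capital.name}, population={capital.population}\")"
   ]
  },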
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**JSON Schema Directly**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"json_schema = json.dumps(\n",
" {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
" \"population\": {\"type\": \"integer\"},\n",
" },\n",
" \"required\": [\"name\", \"population\"],\n",
" }\n",
")\n",
"\n",
"# JSON\n",
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 2048,\n",
" \"json_schema\": json_schema,\n",
" },\n",
" },\n",
")\n",
"\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EBNF"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"Give me the information of the capital of France.\",\n",
" \"sampling_params\": {\n",
" \"max_new_tokens\": 2048,\n",
" \"temperature\": 0,\n",
" \"n\": 3,\n",
" \"ebnf\": (\n",
" \"root ::= city | description\\n\"\n",
" 'city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\\n'\n",
" 'description ::= city \" is \" status\\n'\n",
" 'status ::= \"the capital of \" country\\n'\n",
" 'country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"'\n",
" ),\n",
" },\n",
" \"stream\": False,\n",
" \"return_logprob\": False,\n",
" },\n",
")\n",
"\n",
"print(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"text\": \"Paris is the capital of\",\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 2048,\n",
" \"regex\": \"(France|England)\",\n",
" },\n",
" },\n",
")\n",
"print(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Structural Tag"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"payload = {\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"max_new_tokens\": 2048,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" }\n",
" ),\n",
" },\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"\n",
"llm = sgl.Engine(\n",
" model_path=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" reasoning_parser=\"deepseek-r1\",\n",
" grammar_backend=\"xgrammar\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JSON"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Pydantic**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"prompts = [\n",
" \"Give me the information of the capital of China in the JSON format.\",\n",
" \"Give me the information of the capital of France in the JSON format.\",\n",
" \"Give me the information of the capital of Ireland in the JSON format.\",\n",
"]\n",
"\n",
"\n",
"# Define the schema using Pydantic\n",
"class CapitalInfo(BaseModel):\n",
" name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n",
" population: int = Field(..., description=\"Population of the capital city\")\n",
"\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0,\n",
" \"top_p\": 0.95,\n",
" \"max_new_tokens\": 2048,\n",
" \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n",
"}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**JSON Schema Directly**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Give me the information of the capital of China in the JSON format.\",\n",
" \"Give me the information of the capital of France in the JSON format.\",\n",
" \"Give me the information of the capital of Ireland in the JSON format.\",\n",
"]\n",
"\n",
"json_schema = json.dumps(\n",
" {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
" \"population\": {\"type\": \"integer\"},\n",
" },\n",
" \"required\": [\"name\", \"population\"],\n",
" }\n",
")\n",
"\n",
"sampling_params = {\"temperature\": 0, \"max_new_tokens\": 2048, \"json_schema\": json_schema}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EBNF\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Give me the information of the capital of France.\",\n",
" \"Give me the information of the capital of Germany.\",\n",
" \"Give me the information of the capital of Italy.\",\n",
"]\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"ebnf\": (\n",
" \"root ::= city | description\\n\"\n",
" 'city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\\n'\n",
" 'description ::= city \" is \" status\\n'\n",
" 'status ::= \"the capital of \" country\\n'\n",
" 'country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"'\n",
" ),\n",
"}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Please provide information about London as a major global city:\",\n",
" \"Please provide information about Paris as a major global city:\",\n",
"]\n",
"\n",
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"regex\": \"(France|England)\"}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Structural Tag"
   ]
  },
  {
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = tokenizer.apply_chat_template(\n",
" messages, tokenize=False, add_generation_prompt=True\n",
")\n",
"prompts = [text]\n",
"\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"max_new_tokens\": 2048,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"structures\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"schema\": schema_get_current_weather,\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"schema\": schema_get_current_date,\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"triggers\": [\"<function=\"],\n",
" }\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Query Vision Language Model"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"## Querying Qwen-VL"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply() # Run this first.\n",
"\n",
"model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
"chat_template = \"qwen2-vl\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"# Lets create a prompt.\n",
"\n",
"from io import BytesIO\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.parser.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
" BytesIO(\n",
" requests.get(\n",
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" ).content\n",
" )\n",
")\n",
"\n",
"conv = chat_templates[chat_template].copy()\n",
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
"conv.append_message(conv.roles[1], \"\")\n",
"conv.image_data = [image]\n",
"\n",
"print(conv.get_prompt())\n",
"image"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"### Query via the offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"from sglang import Engine\n",
"\n",
"llm = Engine(\n",
" model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
"print(out[\"text\"])"
]
},
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"### Query via the offline Engine API, but send precomputed embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"# Compute the image embeddings using Huggingface.\n",
"\n",
"from transformers import AutoProcessor\n",
"from transformers import Qwen2_5_VLForConditionalGeneration\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
"vision = (\n",
" Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"processed_prompt = processor(\n",
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
")\n",
"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"precomputed_embeddings = vision(\n",
" processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
")\n",
"\n",
"mm_item = dict(\n",
" modality=\"IMAGE\",\n",
" image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
" precomputed_embeddings=precomputed_embeddings,\n",
")\n",
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
"print(out[\"text\"])"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"## Querying Llama 4 (Vision)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply() # Run this first.\n",
"\n",
"model_path = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
"chat_template = \"llama-4\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"# Lets create a prompt.\n",
"\n",
"from io import BytesIO\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.parser.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
" BytesIO(\n",
" requests.get(\n",
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" ).content\n",
" )\n",
")\n",
"\n",
"conv = chat_templates[chat_template].copy()\n",
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
"conv.append_message(conv.roles[1], \"\")\n",
"conv.image_data = [image]\n",
"\n",
"print(conv.get_prompt())\n",
"print(f\"Image size: {image.size}\")\n",
"\n",
"image"
]
},
{
"cell_type": "markdown",
"id": "13",
"metadata": {},
"source": [
"### Query via the offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if not is_in_ci():\n",
" from sglang import Engine\n",
"\n",
" llm = Engine(\n",
" model_path=model_path,\n",
" trust_remote_code=True,\n",
" enable_multimodal=True,\n",
" mem_fraction_static=0.8,\n",
" tp_size=4,\n",
" attention_backend=\"fa3\",\n",
" context_length=65536,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
" out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
" print(out[\"text\"])"
]
},
{
"cell_type": "markdown",
"id": "16",
"metadata": {},
"source": [
"### Query via the offline Engine API, but send precomputed embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
" # Compute the image embeddings using Huggingface.\n",
"\n",
" from transformers import AutoProcessor\n",
" from transformers import Llama4ForConditionalGeneration\n",
"\n",
" processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
" model = Llama4ForConditionalGeneration.from_pretrained(\n",
" model_path, torch_dtype=\"auto\"\n",
" ).eval()\n",
" vision = model.vision_model.cuda()\n",
" multi_modal_projector = model.multi_modal_projector.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
" processed_prompt = processor(\n",
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
" )\n",
" print(f'{processed_prompt[\"pixel_values\"].shape=}')\n",
" input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"\n",
" image_outputs = vision(\n",
" processed_prompt[\"pixel_values\"].to(\"cuda\"), output_hidden_states=False\n",
" )\n",
" image_features = image_outputs.last_hidden_state\n",
" vision_flat = image_features.view(-1, image_features.size(-1))\n",
" precomputed_embeddings = multi_modal_projector(vision_flat)\n",
"\n",
" mm_item = dict(modality=\"IMAGE\", precomputed_embeddings=precomputed_embeddings)\n",
" out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
" print(out[\"text\"])"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"custom_cell_magics": "kql",
"encoding": "# -*- coding: utf-8 -*-"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}