Unverified commit 7443197a, authored by Shi Shuai and committed by GitHub

[CI] Improve Docs CI Efficiency (#3587)


Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
parent 862dd76c
@@ -36,6 +36,9 @@ jobs:
       run: |
         bash scripts/ci_install_dependency.sh
         pip install -r docs/requirements.txt
+        apt-get update
+        apt-get install -y pandoc
+        apt-get update && apt-get install -y parallel
     - name: Setup Jupyter Kernel
       run: |
......
@@ -8,6 +8,7 @@ on:
       - "python/sglang/**"
       - "test/**"
       - "docs/**"
+      - "scripts/**"
   pull_request:
     branches: [ main ]
     paths:
@@ -15,6 +16,7 @@ on:
       - "python/sglang/**"
       - "test/**"
       - "docs/**"
+      - "scripts/**"
   workflow_dispatch:
     inputs:
       version:
@@ -45,6 +47,8 @@ jobs:
           filters: |
             docs:
               - 'docs/**'
+            scripts:
+              - 'scripts/**'
             sglang:
               - 'python/sglang/**'
             test:
......
@@ -32,6 +32,7 @@ jobs:
         pip install -r docs/requirements.txt
         apt-get update
         apt-get install -y pandoc
+        apt-get update && apt-get install -y parallel
     - name: Setup Jupyter Kernel
       run: |
......
-# Minimal makefile for Sphinx documentation
-#
-# You can set these variables from the terminal, and also
-# from the environment for the first two.
+# Minimal Makefile for Sphinx documentation
 SPHINXOPTS    ?=
 SPHINXBUILD   ?= sphinx-build
 SOURCEDIR     = .
 BUILDDIR      = _build
 
-# Put it first so that "make" without argument is like "make help".
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
-# New target to compile Markdown and Jupyter Notebook files
+# Compile Notebook files and record execution time
 compile:
-	find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print | while read nb; do \
-		if [ -f "$$nb" ]; then \
-			echo "Executing $$nb"; \
-			jupyter nbconvert --to notebook --execute --inplace "$$nb" \
-				--ExecutePreprocessor.timeout=600 \
-				--ExecutePreprocessor.kernel_name=python3 || exit 1; \
-		fi; \
-	done
+	@set -e; \
+	echo "Starting Notebook compilation..."; \
+	mkdir -p logs; \
+	echo "Notebook execution timings:" > logs/timing.log; \
+	START_TOTAL=$$(date +%s); \
+	find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print0 | \
+	parallel -0 -j3 --halt soon,fail=1 ' \
+		NB_NAME=$$(basename {}); \
+		START_TIME=$$(date +%s); \
+		jupyter nbconvert --to notebook --execute --inplace "{}" \
+			--ExecutePreprocessor.timeout=600 \
+			--ExecutePreprocessor.kernel_name=python3; \
+		RET_CODE=$$?; \
+		END_TIME=$$(date +%s); \
+		ELAPSED_TIME=$$((END_TIME - START_TIME)); \
+		echo "$${NB_NAME}: $${ELAPSED_TIME}s" >> logs/timing.log; \
+		exit $$RET_CODE' || exit 1; \
+	END_TOTAL=$$(date +%s); \
+	TOTAL_ELAPSED=$$((END_TOTAL - START_TOTAL)); \
+	echo "---------------------------------" >> logs/timing.log; \
+	echo "Total execution time: $${TOTAL_ELAPSED}s" >> logs/timing.log; \
+	echo "All Notebook execution timings:" && cat logs/timing.log
 
-.PHONY: help Makefile compile
+.PHONY: help Makefile compile clean
 
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
 clean:
-	rm -rf $(BUILDDIR)/*
+	rm -rf $(BUILDDIR)/* logs/timing.log
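
For readers unfamiliar with GNU parallel, here is an illustrative Python equivalent of the new `compile` target (not part of the commit): run every notebook outside `_build`, three at a time, time each one, and surface the first failure, roughly what `parallel -j3 --halt soon,fail=1` does.

import pathlib
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor, as_completed


def run_notebook(nb: str) -> tuple[str, float]:
    start = time.time()
    subprocess.run(
        [
            "jupyter", "nbconvert", "--to", "notebook", "--execute", "--inplace", nb,
            "--ExecutePreprocessor.timeout=600",
            "--ExecutePreprocessor.kernel_name=python3",
        ],
        check=True,  # raise CalledProcessError if the notebook fails
    )
    return nb, time.time() - start


if __name__ == "__main__":
    notebooks = [
        str(p) for p in pathlib.Path(".").rglob("*.ipynb") if "_build" not in p.parts
    ]
    total_start = time.time()
    with ProcessPoolExecutor(max_workers=3) as pool:  # like `parallel -j3`
        futures = [pool.submit(run_notebook, nb) for nb in notebooks]
        for future in as_completed(futures):
            name, elapsed = future.result()  # re-raises the first notebook failure
            print(f"{name}: {elapsed:.0f}s")
    print(f"Total execution time: {time.time() - total_start:.0f}s")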
...@@ -31,17 +31,19 @@ ...@@ -31,17 +31,19 @@
"source": [ "source": [
"from openai import OpenAI\n", "from openai import OpenAI\n",
"import json\n", "import json\n",
"from sglang.utils import (\n", "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
" execute_shell_command,\n", "from sglang.test.test_utils import is_in_ci\n",
" wait_for_server,\n", "\n",
" terminate_process,\n", "if is_in_ci():\n",
" print_highlight,\n", " from patch import launch_server_cmd\n",
")\n", "else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n", " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\" # llama3\n",
")\n", ")\n",
"wait_for_server(\"http://localhost:30333\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -141,7 +143,7 @@ ...@@ -141,7 +143,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Initialize OpenAI-like client\n", "# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n", "client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n",
"model_name = client.models.list().data[0].id" "model_name = client.models.list().data[0].id"
] ]
}, },
...@@ -377,13 +379,13 @@ ...@@ -377,13 +379,13 @@
" tools=tools,\n", " tools=tools,\n",
")\n", ")\n",
"\n", "\n",
"gen_url = \"http://localhost:30333/generate\"\n", "gen_url = f\"http://localhost:{port}/generate\"\n",
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n", "gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"print(gen_response)\n", "print(gen_response)\n",
"\n", "\n",
"# parse the response\n", "# parse the response\n",
"parse_url = \"http://localhost:30333/function_call\"\n", "parse_url = f\"http://localhost:{port}/function_call\"\n",
"\n", "\n",
"function_call_input = {\n", "function_call_input = {\n",
" \"text\": gen_response,\n", " \"text\": gen_response,\n",
...@@ -403,7 +405,7 @@ ...@@ -403,7 +405,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(server_process)" "terminate_process(server_process, port)"
] ]
}, },
{ {
......
...@@ -34,22 +34,22 @@ ...@@ -34,22 +34,22 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sglang.utils import (\n",
" execute_shell_command,\n",
" wait_for_server,\n",
" terminate_process,\n",
" print_highlight,\n",
")\n",
"\n",
"import requests\n", "import requests\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "if is_in_ci():\n",
" \"\"\"\n", " from patch import launch_server_cmd\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n", "else:\n",
"\"\"\"\n", " from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30010\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -66,7 +66,7 @@ ...@@ -66,7 +66,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"url = \"http://localhost:30010/generate\"\n", "url = f\"http://localhost:{port}/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n", "data = {\"text\": \"What is the capital of France?\"}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
...@@ -92,7 +92,7 @@ ...@@ -92,7 +92,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"url = \"http://localhost:30010/get_model_info\"\n", "url = f\"http://localhost:{port}/get_model_info\"\n",
"\n", "\n",
"response = requests.get(url)\n", "response = requests.get(url)\n",
"response_json = response.json()\n", "response_json = response.json()\n",
...@@ -123,7 +123,7 @@ ...@@ -123,7 +123,7 @@
"source": [ "source": [
"# get_server_info\n", "# get_server_info\n",
"\n", "\n",
"url = \"http://localhost:30010/get_server_info\"\n", "url = f\"http://localhost:{port}/get_server_info\"\n",
"\n", "\n",
"response = requests.get(url)\n", "response = requests.get(url)\n",
"print_highlight(response.text)" "print_highlight(response.text)"
...@@ -144,7 +144,7 @@ ...@@ -144,7 +144,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"url = \"http://localhost:30010/health_generate\"\n", "url = f\"http://localhost:{port}/health_generate\"\n",
"\n", "\n",
"response = requests.get(url)\n", "response = requests.get(url)\n",
"print_highlight(response.text)" "print_highlight(response.text)"
...@@ -156,7 +156,7 @@ ...@@ -156,7 +156,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"url = \"http://localhost:30010/health\"\n", "url = f\"http://localhost:{port}/health\"\n",
"\n", "\n",
"response = requests.get(url)\n", "response = requests.get(url)\n",
"print_highlight(response.text)" "print_highlight(response.text)"
...@@ -179,7 +179,7 @@ ...@@ -179,7 +179,7 @@
"source": [ "source": [
"# flush cache\n", "# flush cache\n",
"\n", "\n",
"url = \"http://localhost:30010/flush_cache\"\n", "url = f\"http://localhost:{port}/flush_cache\"\n",
"\n", "\n",
"response = requests.post(url)\n", "response = requests.post(url)\n",
"print_highlight(response.text)" "print_highlight(response.text)"
...@@ -204,7 +204,7 @@ ...@@ -204,7 +204,7 @@
"source": [ "source": [
"# successful update with same architecture and size\n", "# successful update with same architecture and size\n",
"\n", "\n",
"url = \"http://localhost:30010/update_weights_from_disk\"\n", "url = f\"http://localhost:{port}/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
...@@ -222,7 +222,7 @@ ...@@ -222,7 +222,7 @@
"source": [ "source": [
"# failed update with different parameter size or wrong name\n", "# failed update with different parameter size or wrong name\n",
"\n", "\n",
"url = \"http://localhost:30010/update_weights_from_disk\"\n", "url = f\"http://localhost:{port}/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n", "data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
...@@ -252,16 +252,16 @@ ...@@ -252,16 +252,16 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(server_process)\n", "terminate_process(server_process, port)\n",
"\n", "\n",
"embedding_process = execute_shell_command(\n", "embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n", "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n",
" --port 30020 --host 0.0.0.0 --is-embedding\n", " --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30020\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -272,7 +272,7 @@ ...@@ -272,7 +272,7 @@
"source": [ "source": [
"# successful encode for embedding model\n", "# successful encode for embedding model\n",
"\n", "\n",
"url = \"http://localhost:30020/encode\"\n", "url = f\"http://localhost:{port}/encode\"\n",
"data = {\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"text\": \"Once upon a time\"}\n", "data = {\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"text\": \"Once upon a time\"}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
...@@ -280,6 +280,15 @@ ...@@ -280,6 +280,15 @@
"print_highlight(f\"Text embedding (first 10): {response_json['embedding'][:10]}\")" "print_highlight(f\"Text embedding (first 10): {response_json['embedding'][:10]}\")"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(embedding_process, port)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -295,18 +304,18 @@ ...@@ -295,18 +304,18 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(embedding_process)\n", "terminate_process(embedding_process, port)\n",
"\n", "\n",
"# Note that SGLang now treats embedding models and reward models as the same type of models.\n", "# Note that SGLang now treats embedding models and reward models as the same type of models.\n",
"# This will be updated in the future.\n", "# This will be updated in the future.\n",
"\n", "\n",
"reward_process = execute_shell_command(\n", "reward_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --port 30030 --host 0.0.0.0 --is-embedding\n", "python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30030\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -332,7 +341,7 @@ ...@@ -332,7 +341,7 @@
"tokenizer = AutoTokenizer.from_pretrained(\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\")\n",
"prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n", "prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
"\n", "\n",
"url = \"http://localhost:30030/classify\"\n", "url = f\"http://localhost:{port}/classify\"\n",
"data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n", "data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n",
"\n", "\n",
"responses = requests.post(url, json=data).json()\n", "responses = requests.post(url, json=data).json()\n",
...@@ -346,7 +355,7 @@ ...@@ -346,7 +355,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(reward_process)" "terminate_process(reward_process, port)"
] ]
}, },
{ {
...@@ -364,13 +373,13 @@ ...@@ -364,13 +373,13 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"tokenizer_free_server_process = execute_shell_command(\n", "tokenizer_free_server_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --skip-tokenizer-init\n", "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --skip-tokenizer-init\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30010\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -390,7 +399,7 @@ ...@@ -390,7 +399,7 @@
"print_highlight(f\"Tokenized Input: {input_tokens}\")\n", "print_highlight(f\"Tokenized Input: {input_tokens}\")\n",
"\n", "\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30010/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"input_ids\": input_tokens,\n", " \"input_ids\": input_tokens,\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -416,7 +425,7 @@ ...@@ -416,7 +425,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(tokenizer_free_server_process)" "terminate_process(tokenizer_free_server_process, port)"
] ]
} }
], ],
......
...@@ -40,6 +40,11 @@ ...@@ -40,6 +40,11 @@
"from sglang.utils import stream_and_merge, async_stream_and_merge\n", "from sglang.utils import stream_and_merge, async_stream_and_merge\n",
"import sglang as sgl\n", "import sglang as sgl\n",
"import asyncio\n", "import asyncio\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" import patch\n",
"\n",
"\n", "\n",
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")" "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")"
] ]
...@@ -201,8 +206,6 @@ ...@@ -201,8 +206,6 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import sglang as sgl\n",
"\n",
"llm = sgl.Engine(\n", "llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n", " model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
")" ")"
......
...@@ -33,18 +33,22 @@ ...@@ -33,18 +33,22 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sglang.utils import (\n", "from sglang.test.test_utils import is_in_ci\n",
" execute_shell_command,\n", "\n",
" wait_for_server,\n", "if is_in_ci():\n",
" terminate_process,\n", " from patch import launch_server_cmd\n",
" print_highlight,\n", "else:\n",
")\n", " from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30020 --host 0.0.0.0\"\n", " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30020\")" "wait_for_server(f\"http://localhost:{port}\")\n",
"print(f\"Server started on http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -68,7 +72,7 @@ ...@@ -68,7 +72,7 @@
"source": [ "source": [
"import openai\n", "import openai\n",
"\n", "\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
...@@ -245,7 +249,7 @@ ...@@ -245,7 +249,7 @@
"import time\n", "import time\n",
"from openai import OpenAI\n", "from openai import OpenAI\n",
"\n", "\n",
"client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "client = OpenAI(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"requests = [\n", "requests = [\n",
" {\n", " {\n",
...@@ -348,10 +352,10 @@ ...@@ -348,10 +352,10 @@
"import time\n", "import time\n",
"from openai import OpenAI\n", "from openai import OpenAI\n",
"\n", "\n",
"client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "client = OpenAI(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"requests = []\n", "requests = []\n",
"for i in range(100):\n", "for i in range(20):\n",
" requests.append(\n", " requests.append(\n",
" {\n", " {\n",
" \"custom_id\": f\"request-{i}\",\n", " \"custom_id\": f\"request-{i}\",\n",
...@@ -369,7 +373,7 @@ ...@@ -369,7 +373,7 @@
" \"content\": \"Write a detailed story about topic. Make it very long.\",\n", " \"content\": \"Write a detailed story about topic. Make it very long.\",\n",
" },\n", " },\n",
" ],\n", " ],\n",
" \"max_tokens\": 500,\n", " \"max_tokens\": 64,\n",
" },\n", " },\n",
" }\n", " }\n",
" )\n", " )\n",
...@@ -425,10 +429,10 @@ ...@@ -425,10 +429,10 @@
"from openai import OpenAI\n", "from openai import OpenAI\n",
"import os\n", "import os\n",
"\n", "\n",
"client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "client = OpenAI(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"requests = []\n", "requests = []\n",
"for i in range(500):\n", "for i in range(5000):\n",
" requests.append(\n", " requests.append(\n",
" {\n", " {\n",
" \"custom_id\": f\"request-{i}\",\n", " \"custom_id\": f\"request-{i}\",\n",
...@@ -446,7 +450,7 @@ ...@@ -446,7 +450,7 @@
" \"content\": \"Write a detailed story about topic. Make it very long.\",\n", " \"content\": \"Write a detailed story about topic. Make it very long.\",\n",
" },\n", " },\n",
" ],\n", " ],\n",
" \"max_tokens\": 500,\n", " \"max_tokens\": 128,\n",
" },\n", " },\n",
" }\n", " }\n",
" )\n", " )\n",
...@@ -508,7 +512,7 @@ ...@@ -508,7 +512,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(server_process)" "terminate_process(server_process, port)"
] ]
} }
], ],
......
...@@ -29,21 +29,23 @@ ...@@ -29,21 +29,23 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sglang.utils import (\n", "from sglang.test.test_utils import is_in_ci\n",
" execute_shell_command,\n", "\n",
" wait_for_server,\n", "if is_in_ci():\n",
" terminate_process,\n", " from patch import launch_server_cmd\n",
" print_highlight,\n", "else:\n",
")\n", " from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n", "\n",
"embedding_process = execute_shell_command(\n", "embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n", "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n",
" --port 30000 --host 0.0.0.0 --is-embedding\n", " --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30000\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -63,7 +65,7 @@ ...@@ -63,7 +65,7 @@
"\n", "\n",
"text = \"Once upon a time\"\n", "text = \"Once upon a time\"\n",
"\n", "\n",
"curl_text = f\"\"\"curl -s http://localhost:30000/v1/embeddings \\\n", "curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n", " -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
"\n", "\n",
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n", "text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
...@@ -91,7 +93,7 @@ ...@@ -91,7 +93,7 @@
"text = \"Once upon a time\"\n", "text = \"Once upon a time\"\n",
"\n", "\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/v1/embeddings\",\n", " f\"http://localhost:{port}/v1/embeddings\",\n",
" json={\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": text},\n", " json={\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": text},\n",
")\n", ")\n",
"\n", "\n",
...@@ -115,7 +117,7 @@ ...@@ -115,7 +117,7 @@
"source": [ "source": [
"import openai\n", "import openai\n",
"\n", "\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"# Text embedding example\n", "# Text embedding example\n",
"response = client.embeddings.create(\n", "response = client.embeddings.create(\n",
...@@ -151,7 +153,7 @@ ...@@ -151,7 +153,7 @@
"tokenizer = AutoTokenizer.from_pretrained(\"Alibaba-NLP/gte-Qwen2-7B-instruct\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"Alibaba-NLP/gte-Qwen2-7B-instruct\")\n",
"input_ids = tokenizer.encode(text)\n", "input_ids = tokenizer.encode(text)\n",
"\n", "\n",
"curl_ids = f\"\"\"curl -s http://localhost:30000/v1/embeddings \\\n", "curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n", " -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
"\n", "\n",
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n", "input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",
...@@ -167,7 +169,7 @@ ...@@ -167,7 +169,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(embedding_process)" "terminate_process(embedding_process, port)"
] ]
} }
], ],
......
...@@ -34,21 +34,23 @@ ...@@ -34,21 +34,23 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sglang.utils import (\n", "from sglang.test.test_utils import is_in_ci\n",
" execute_shell_command,\n", "\n",
" wait_for_server,\n", "if is_in_ci():\n",
" terminate_process,\n", " from patch import launch_server_cmd\n",
" print_highlight,\n", "else:\n",
")\n", " from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n", "\n",
"embedding_process = execute_shell_command(\n", "embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \\\n", "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \\\n",
" --port=30000 --chat-template=llama_3_vision\n", " --chat-template=llama_3_vision\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30000\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -68,32 +70,36 @@ ...@@ -68,32 +70,36 @@
"source": [ "source": [
"import subprocess\n", "import subprocess\n",
"\n", "\n",
"curl_command = \"\"\"\n", "curl_command = f\"\"\"\n",
"curl -s http://localhost:30000/v1/chat/completions \\\n", "curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
" -d '{\n", " -d '{{\n",
" \"model\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n", " \"model\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
" \"messages\": [\n", " \"messages\": [\n",
" {\n", " {{\n",
" \"role\": \"user\",\n", " \"role\": \"user\",\n",
" \"content\": [\n", " \"content\": [\n",
" {\n", " {{\n",
" \"type\": \"text\",\n", " \"type\": \"text\",\n",
" \"text\": \"What’s in this image?\"\n", " \"text\": \"What’s in this image?\"\n",
" },\n", " }},\n",
" {\n", " {{\n",
" \"type\": \"image_url\",\n", " \"type\": \"image_url\",\n",
" \"image_url\": {\n", " \"image_url\": {{\n",
" \"url\": \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n", " \"url\": \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" }\n", " }}\n",
" }\n", " }}\n",
" ]\n", " ]\n",
" }\n", " }}\n",
" ],\n", " ],\n",
" \"max_tokens\": 300\n", " \"max_tokens\": 300\n",
" }'\n", " }}'\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
"response = subprocess.check_output(curl_command, shell=True).decode()\n", "response = subprocess.check_output(curl_command, shell=True).decode()\n",
"print_highlight(response)\n",
"\n",
"\n",
"response = subprocess.check_output(curl_command, shell=True).decode()\n",
"print_highlight(response)" "print_highlight(response)"
] ]
}, },
...@@ -112,7 +118,7 @@ ...@@ -112,7 +118,7 @@
"source": [ "source": [
"import requests\n", "import requests\n",
"\n", "\n",
"url = \"http://localhost:30000/v1/chat/completions\"\n", "url = f\"http://localhost:{port}/v1/chat/completions\"\n",
"\n", "\n",
"data = {\n", "data = {\n",
" \"model\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n", " \"model\": \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
...@@ -152,7 +158,7 @@ ...@@ -152,7 +158,7 @@
"source": [ "source": [
"from openai import OpenAI\n", "from openai import OpenAI\n",
"\n", "\n",
"client = OpenAI(base_url=\"http://localhost:30000/v1\", api_key=\"None\")\n", "client = OpenAI(base_url=f\"http://localhost:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n", " model=\"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
...@@ -196,7 +202,7 @@ ...@@ -196,7 +202,7 @@
"source": [ "source": [
"from openai import OpenAI\n", "from openai import OpenAI\n",
"\n", "\n",
"client = OpenAI(base_url=\"http://localhost:30000/v1\", api_key=\"None\")\n", "client = OpenAI(base_url=f\"http://localhost:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n", " model=\"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n",
...@@ -236,7 +242,7 @@ ...@@ -236,7 +242,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(embedding_process)" "terminate_process(embedding_process, port)"
] ]
}, },
{ {
......
import os

from sglang.utils import execute_shell_command, reserve_port

DEFAULT_MAX_RUNNING_REQUESTS = 200
DEFAULT_MAX_TOTAL_TOKENS = 20480

import sglang.srt.server_args as server_args_mod

# Keep a reference to the original __post_init__ so the patched version can call it.
_original_post_init = server_args_mod.ServerArgs.__post_init__


def patched_post_init(self):
    _original_post_init(self)
    if self.max_running_requests is None:
        self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS
    if self.max_total_tokens is None:
        self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS
    self.disable_cuda_graph = True


# Monkey-patch ServerArgs so every server launched from the docs notebooks in CI
# runs with capped max-running-requests and max-total-tokens and with CUDA graph disabled.
server_args_mod.ServerArgs.__post_init__ = patched_post_init


def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
    if port is None:
        port = reserve_port()
    extra_flags = (
        f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} "
        f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} "
        f"--disable-cuda-graph"
    )
    full_command = f"{command} --port {port} {extra_flags}"
    process = execute_shell_command(full_command)
    return process, port
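
For orientation (not part of the diff): the notebooks in this PR choose between this patched launcher and the library one, as the notebook cells earlier in the commit show. A minimal sketch of that pattern, reusing the Llama-3.2-1B command that appears in the notebooks:

# Inside a docs notebook: use the CI-limited launcher when running under CI,
# otherwise the regular sglang.utils implementation.
from sglang.test.test_utils import is_in_ci
from sglang.utils import wait_for_server, print_highlight, terminate_process

if is_in_ci():
    from patch import launch_server_cmd  # the patched module defined above
else:
    from sglang.utils import launch_server_cmd

server_process, port = launch_server_cmd(
    "python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
print_highlight(f"Server is ready on port {port}")
terminate_process(server_process, port)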
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
"\n", "\n",
"```bash\n", "```bash\n",
"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
"--port 30000 --host 0.0.0.0\n", " --host 0.0.0.0\n",
"```\n", "```\n",
"\n", "\n",
"in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the [OpenAI-compatible APIs](https://platform.openai.com/docs/api-reference/chat)." "in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the [OpenAI-compatible APIs](https://platform.openai.com/docs/api-reference/chat)."
...@@ -34,21 +34,23 @@ ...@@ -34,21 +34,23 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sglang.utils import (\n", "from sglang.test.test_utils import is_in_ci\n",
" execute_shell_command,\n", "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
" wait_for_server,\n", "\n",
" terminate_process,\n", "if is_in_ci():\n",
" print_highlight,\n", " from patch import launch_server_cmd\n",
")\n", "else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
"--port 30000 --host 0.0.0.0\n", " --host 0.0.0.0\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30000\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -66,9 +68,10 @@ ...@@ -66,9 +68,10 @@
"source": [ "source": [
"import subprocess, json\n", "import subprocess, json\n",
"\n", "\n",
"curl_command = \"\"\"\n", "curl_command = f\"\"\"\n",
"curl -s http://localhost:30000/v1/chat/completions \\\n", "curl -s http://localhost:{port}/v1/chat/completions \\\n",
" -d '{\"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}]}'\n", " -H \"Content-Type: application/json\" \\\n",
" -d '{{\"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"messages\": [{{\"role\": \"user\", \"content\": \"What is the capital of France?\"}}]}}'\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
"response = json.loads(subprocess.check_output(curl_command, shell=True))\n", "response = json.loads(subprocess.check_output(curl_command, shell=True))\n",
...@@ -90,7 +93,7 @@ ...@@ -90,7 +93,7 @@
"source": [ "source": [
"import requests\n", "import requests\n",
"\n", "\n",
"url = \"http://localhost:30000/v1/chat/completions\"\n", "url = f\"http://localhost:{port}/v1/chat/completions\"\n",
"\n", "\n",
"data = {\n", "data = {\n",
" \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
...@@ -116,7 +119,7 @@ ...@@ -116,7 +119,7 @@
"source": [ "source": [
"import openai\n", "import openai\n",
"\n", "\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
...@@ -144,7 +147,7 @@ ...@@ -144,7 +147,7 @@
"source": [ "source": [
"import openai\n", "import openai\n",
"\n", "\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"# Use stream=True for streaming responses\n", "# Use stream=True for streaming responses\n",
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
...@@ -181,7 +184,7 @@ ...@@ -181,7 +184,7 @@
"import requests\n", "import requests\n",
"\n", "\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"text\": \"The capital of France is\",\n", " \"text\": \"The capital of France is\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -210,7 +213,7 @@ ...@@ -210,7 +213,7 @@
"import requests, json\n", "import requests, json\n",
"\n", "\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"text\": \"The capital of France is\",\n", " \"text\": \"The capital of France is\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -240,8 +243,15 @@ ...@@ -240,8 +243,15 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(server_process)" "terminate_process(server_process, port)"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {
......
...@@ -35,23 +35,24 @@ ...@@ -35,23 +35,24 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# EAGLE decoding\n", "from sglang.test.test_utils import is_in_ci\n",
"from sglang.utils import (\n", "\n",
" execute_shell_command,\n", "if is_in_ci():\n",
" wait_for_server,\n", " from patch import launch_server_cmd\n",
" terminate_process,\n", "else:\n",
" print_highlight,\n", " from sglang.utils import launch_server_cmd\n",
")\n", "\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020 --cuda-graph-max-bs 32\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30020\")" "wait_for_server(f\"http://localhost:{port}\")"
] ]
}, },
{ {
...@@ -62,7 +63,7 @@ ...@@ -62,7 +63,7 @@
"source": [ "source": [
"import openai\n", "import openai\n",
"\n", "\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n", "\n",
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
...@@ -100,25 +101,16 @@ ...@@ -100,25 +101,16 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"server_process = execute_shell_command(\n", "server_process, port = launch_server_cmd(\n",
" \"\"\"\n", " \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
" --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n", " --enable-torch-compile --cuda-graph-max-bs 2\n",
"\"\"\"\n", "\"\"\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30020\")" "wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark Script\n",
"\n",
"The following code example shows how to measure the decoding speed when generating tokens:\n"
] ]
}, },
{ {
...@@ -127,27 +119,20 @@ ...@@ -127,27 +119,20 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import time\n", "import openai\n",
"import requests\n", "\n",
"\n", "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"tic = time.time()\n", "\n",
"response = requests.post(\n", "response = client.chat.completions.create(\n",
" \"http://localhost:30020/generate\",\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" json={\n", " messages=[\n",
" \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. [/INST]\",\n", " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" \"sampling_params\": {\n", " ],\n",
" \"temperature\": 0,\n", " temperature=0,\n",
" \"max_new_tokens\": 256,\n", " max_tokens=64,\n",
" },\n",
" },\n",
")\n", ")\n",
"latency = time.time() - tic\n",
"ret = response.json()\n",
"completion_text = ret[\"text\"]\n",
"speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n",
"\n", "\n",
"print_highlight(completion_text)\n", "print_highlight(f\"Response: {response}\")"
"print_highlight(f\"speed: {speed:.2f} token/s\")"
] ]
}, },
{ {
......
...@@ -38,24 +38,26 @@ ...@@ -38,24 +38,26 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sglang.utils import (\n",
" execute_shell_command,\n",
" wait_for_server,\n",
" terminate_process,\n",
" print_highlight,\n",
")\n",
"import openai\n", "import openai\n",
"import os\n", "import os\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n", "\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n", "\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n", " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0 --grammar-backend xgrammar\"\n",
")\n", ")\n",
"\n", "\n",
"wait_for_server(\"http://localhost:30000\")\n", "wait_for_server(f\"http://localhost:{port}\")\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")" "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")"
] ]
}, },
{ {
...@@ -264,7 +266,7 @@ ...@@ -264,7 +266,7 @@
"\n", "\n",
"# Make API request\n", "# Make API request\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -309,7 +311,7 @@ ...@@ -309,7 +311,7 @@
"\n", "\n",
"# JSON\n", "# JSON\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -339,7 +341,7 @@ ...@@ -339,7 +341,7 @@
"import requests\n", "import requests\n",
"\n", "\n",
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"text\": \"Give me the information of the capital of France.\",\n", " \"text\": \"Give me the information of the capital of France.\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -376,7 +378,7 @@ ...@@ -376,7 +378,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"response = requests.post(\n", "response = requests.post(\n",
" \"http://localhost:30000/generate\",\n", " f\"http://localhost:{port}/generate\",\n",
" json={\n", " json={\n",
" \"text\": \"Paris is the capital of\",\n", " \"text\": \"Paris is the capital of\",\n",
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
...@@ -395,7 +397,7 @@ ...@@ -395,7 +397,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"terminate_process(server_process)" "terminate_process(server_process, port)"
] ]
}, },
{ {
......
...@@ -16,13 +16,12 @@ The core features include: ...@@ -16,13 +16,12 @@ The core features include:
:caption: Getting Started :caption: Getting Started
start/install.md start/install.md
start/send_request.ipynb
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Backend Tutorial :caption: Backend Tutorial
backend/send_request.ipynb
backend/openai_api_completions.ipynb backend/openai_api_completions.ipynb
backend/openai_api_vision.ipynb backend/openai_api_vision.ipynb
backend/openai_api_embeddings.ipynb backend/openai_api_embeddings.ipynb
...@@ -33,7 +32,6 @@ The core features include: ...@@ -33,7 +32,6 @@ The core features include:
backend/function_calling.ipynb backend/function_calling.ipynb
backend/server_arguments.md backend/server_arguments.md
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Frontend Tutorial :caption: Frontend Tutorial
......
@@ -306,22 +306,94 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+import fcntl
+
+LOCKFILE = "/tmp/port_lock"
+PORT_REGISTRY = "/tmp/port_registry.json"
+
+
+def print_highlight(html_content: str):
+    html_content = str(html_content).replace("\n", "<br>")
+    display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
+
+
+def init_port_registry():
+    """Initialize the port registry file if it doesn't exist."""
+    if not os.path.exists(PORT_REGISTRY):
+        with open(PORT_REGISTRY, "w") as f:
+            json.dump([], f)
+
+
+def reserve_port(start=30000, end=40000):
+    """
+    Reserve an available port using a file lock and a registry.
+    Returns the allocated port.
+    """
+    init_port_registry()
+    with open(LOCKFILE, "w") as lock:
+        fcntl.flock(lock, fcntl.LOCK_EX)
+        try:
+            with open(PORT_REGISTRY, "r") as f:
+                used = json.load(f)
+        except Exception:
+            used = []
+        for port in range(start, end):
+            if port not in used:
+                used.append(port)
+                with open(PORT_REGISTRY, "w") as f:
+                    json.dump(used, f)
+                return port
+        raise RuntimeError("No free port available")
+
+
+def release_port(port):
+    """Release the reserved port by removing it from the registry."""
+    with open(LOCKFILE, "w") as lock:
+        fcntl.flock(lock, fcntl.LOCK_EX)
+        try:
+            with open(PORT_REGISTRY, "r") as f:
+                used = json.load(f)
+        except Exception:
+            used = []
+        if port in used:
+            used.remove(port)
+            with open(PORT_REGISTRY, "w") as f:
+                json.dump(used, f)
+
+
 def execute_shell_command(command: str) -> subprocess.Popen:
     """
-    Execute a shell command and return the process handle
-
-    Args:
-        command: Shell command as a string (can include \\ line continuations)
-
-    Returns:
-        subprocess.Popen: Process handle
+    Execute a shell command and return its process handle.
     """
-    # Replace \ newline with space and split
+    # Replace newline continuations and split the command string.
     command = command.replace("\\\n", " ").replace("\\", " ")
     parts = command.split()
     return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)


+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
+    """
+    Launch the server using the given command.
+    If no port is specified, a free port is reserved.
+    """
+    if port is None:
+        port = reserve_port()
+    full_command = f"{command} --port {port}"
+    process = execute_shell_command(full_command)
+    return process, port
+
+
+def terminate_process(process, port=None):
+    """
+    Terminate the process and, if a port was reserved, release it.
+    """
+    from sglang.srt.utils import kill_process_tree
+
+    kill_process_tree(process.pid)
+    if port is not None:
+        release_port(port)
+
+
 def wait_for_server(base_url: str, timeout: int = None) -> None:
     """Wait for the server to be ready by polling the /v1/models endpoint.

@@ -343,6 +415,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
     NOTE: Typically, the server runs in a separate terminal.
     In this notebook, we run the server and notebook code together, so their outputs are combined.
     To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
+    We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
     """
             )
             break

@@ -353,17 +426,6 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
             time.sleep(1)


-def terminate_process(process):
-    from sglang.srt.utils import kill_process_tree
-
-    kill_process_tree(process.pid)
-
-
-def print_highlight(html_content: str):
-    html_content = str(html_content).replace("\n", "<br>")
-    display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
-
-
 class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
......
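
A short illustration (not from the diff) of the port bookkeeping added above: `reserve_port()` takes an exclusive `fcntl` lock on `/tmp/port_lock`, so notebooks executed concurrently by `parallel -j3` each receive a distinct port, and `terminate_process(process, port)` returns the port to the pool via `release_port`.

from sglang.utils import release_port, reserve_port

# Two concurrent notebook processes would each see their own port.
port_a = reserve_port()   # e.g. 30000, now recorded in /tmp/port_registry.json
port_b = reserve_port()   # e.g. 30001, because 30000 is still marked as used
assert port_a != port_b

# Once the corresponding servers are terminated, the ports return to the pool.
release_port(port_a)
release_port(port_b)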
...@@ -3,6 +3,7 @@ set -euxo pipefail ...@@ -3,6 +3,7 @@ set -euxo pipefail
# Install the dependency in CI. # Install the dependency in CI.
# Use repo from environment variable, passed from GitHub Actions # Use repo from environment variable, passed from GitHub Actions
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}" FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
......
...@@ -15,7 +15,7 @@ from safetensors.torch import save_file ...@@ -15,7 +15,7 @@ from safetensors.torch import save_file
from transformers import AutoConfig from transformers import AutoConfig
def get_nexn_layer_id(config): def get_nextn_layer_id(config):
if not hasattr(config, "num_hidden_layers"): if not hasattr(config, "num_hidden_layers"):
raise ValueError("'num_hidden_layers' not found in model config.") raise ValueError("'num_hidden_layers' not found in model config.")
return config.num_hidden_layers return config.num_hidden_layers
...@@ -25,7 +25,7 @@ def update_and_save_config(config, output_dir): ...@@ -25,7 +25,7 @@ def update_and_save_config(config, output_dir):
new_config = config.to_dict() new_config = config.to_dict()
new_config.update( new_config.update(
{ {
"num_hidden_layers": 0, "num_hidden_layers": 1,
"architectures": ["DeepseekV3ForCausalLMNextN"], "architectures": ["DeepseekV3ForCausalLMNextN"],
} }
) )
...@@ -42,8 +42,8 @@ def copy_non_safetensors_files(input_dir, output_dir): ...@@ -42,8 +42,8 @@ def copy_non_safetensors_files(input_dir, output_dir):
print(f"All non-safetensors files have been copied to {output_dir}") print(f"All non-safetensors files have been copied to {output_dir}")
def export_nextn_layer_parameters(input_dir, output_dir, nexn_layer_id): def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
prefix = f"model.layers.{nexn_layer_id}" prefix = f"model.layers.{nextn_layer_id}"
output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors") output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
params = {} params = {}
for filename in os.listdir(input_dir): for filename in os.listdir(input_dir):
...@@ -106,7 +106,7 @@ if __name__ == "__main__": ...@@ -106,7 +106,7 @@ if __name__ == "__main__":
config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True) config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported." assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
nextn_layer_id = get_nexn_layer_id(config) nextn_layer_id = get_nextn_layer_id(config)
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
copy_non_safetensors_files(args.input_dir, args.output_dir) copy_non_safetensors_files(args.input_dir, args.output_dir)
update_and_save_config(config, args.output_dir) update_and_save_config(config, args.output_dir)
......