{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Speculative Decoding\n", "\n", "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", "\n", "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", "> ```bash\n", "> pip install cutex\n", "> ```\n", "\n", "### Performance Highlights\n", "\n", "- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n", "- Standard SGLang Decoding: ~156 tokens/s\n", "- EAGLE Decoding in SGLang: ~297 tokens/s\n", "- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n", "\n", "All benchmarks below were run on a single H100." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EAGLE Decoding\n", "\n", "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# EAGLE decoding\n", "from sglang.utils import (\n", " execute_shell_command,\n", " wait_for_server,\n", " terminate_process,\n", " print_highlight,\n", ")\n", "\n", "server_process = execute_shell_command(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020\n", "\"\"\"\n", ")\n", "\n", "wait_for_server(\"http://localhost:30020\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import openai\n", "\n", "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "response = client.chat.completions.create(\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " messages=[\n", " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", " ],\n", " temperature=0,\n", " max_tokens=64,\n", ")\n", "\n", "print_highlight(f\"Response: {response}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### EAGLE Decoding with `torch.compile`\n", "\n", "You can also enable `torch.compile` for further optimizations and optionally set `--cuda-graph-max-bs`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "server_process = execute_shell_command(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n", " --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n", "\"\"\"\n", ")\n", "\n", "wait_for_server(\"http://localhost:30020\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark Script\n", "\n", "The following code example shows how to measure the decoding speed when generating tokens:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "import requests\n", "\n", "tic = time.time()\n", "response = 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark Script\n", "\n", "The following code example measures the decoding speed (tokens per second) of a single generation request:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "import requests\n", "\n", "# Time a single generation request against the native /generate endpoint.\n", "tic = time.time()\n", "response = requests.post(\n", "    \"http://localhost:30020/generate\",\n", "    json={\n", "        \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. [/INST]\",\n", "        \"sampling_params\": {\n", "            \"temperature\": 0,\n", "            \"max_new_tokens\": 256,\n", "        },\n", "    },\n", ")\n", "latency = time.time() - tic\n", "ret = response.json()\n", "completion_text = ret[\"text\"]\n", "# Decoding speed = completion tokens reported by the server / wall-clock latency.\n", "speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n", "\n", "print_highlight(completion_text)\n", "print_highlight(f\"speed: {speed:.2f} token/s\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] },
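{ "cell_type": "markdown", "metadata": {}, "source": [ "To reproduce the standard SGLang decoding number from the performance highlights, you can relaunch the server without the speculative flags and rerun the benchmark cell above. The sketch below assumes the same model, memory fraction, and port as the earlier launches; remember to call `terminate_process(server_process)` once you are done." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Baseline for comparison: same model and port, with the speculative flags removed.\n", "server_process = execute_shell_command(\n", "    \"\"\"\n", "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --mem-fraction 0.7 --port=30020\n", "\"\"\"\n", ")\n", "\n", "wait_for_server(\"http://localhost:30020\")" ] }
], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 2 }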