{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Speculative Decoding\n",
    "\n",
    "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n",
    "\n",
    "**Note:** Currently, Speculative Decoding in SGLang does not support radix cache.\n",
    "\n",
    "### Performance Highlights\n",
    "\n",
    "- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n",
    "- Standard SGLang Decoding: ~156 tokens/s\n",
    "- EAGLE Decoding in SGLang: ~297 tokens/s\n",
    "- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n",
    "\n",
    "All benchmarks below were run on a single H100."
   ]
  },
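  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a quick sanity check, the cell below computes the rough speedups implied by the throughput figures quoted above. The numbers are copied directly from the highlights and are only indicative; actual results depend on hardware, model, and batch size.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough speedups implied by the throughput figures quoted above (single H100).\n",
    "baseline = 156  # tokens/s, standard SGLang decoding\n",
    "eagle = 297  # tokens/s, EAGLE decoding in SGLang\n",
    "eagle_compile = 316  # tokens/s, EAGLE decoding with torch.compile\n",
    "\n",
    "print(f\"EAGLE speedup vs. standard decoding: {eagle / baseline:.2f}x\")\n",
    "print(f\"EAGLE + torch.compile speedup vs. standard decoding: {eagle_compile / baseline:.2f}x\")"
   ]
  },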
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EAGLE Decoding\n",
    "\n",
    "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EAGLE decoding\n",
    "from sglang.utils import (\n",
    "    execute_shell_command,\n",
    "    wait_for_server,\n",
    "    terminate_process,\n",
    "    print_highlight,\n",
    ")\n",
    "\n",
    "server_process = execute_shell_command(\n",
    "    \"\"\"\n",
    "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algo EAGLE \\\n",
    "    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
    "    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020 --cuda-graph-max-bs 32\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(\"http://localhost:30020\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "\n",
    "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
    "\n",
    "response = client.chat.completions.create(\n",
    "    model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
    "    ],\n",
    "    temperature=0,\n",
    "    max_tokens=64,\n",
    ")\n",
    "\n",
    "print_highlight(f\"Response: {response}\")"
   ]
  },
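  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same OpenAI-compatible endpoint also supports streaming. The cell below is a minimal sketch that streams a short completion from the server launched above; the prompt and token limit are arbitrary choices for illustration.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal streaming sketch against the same server (still running on port 30020).\n",
    "stream = client.chat.completions.create(\n",
    "    model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": \"Name two prime numbers.\"},\n",
    "    ],\n",
    "    temperature=0,\n",
    "    max_tokens=32,\n",
    "    stream=True,\n",
    ")\n",
    "\n",
    "chunks = []\n",
    "for chunk in stream:\n",
    "    # Some chunks may carry no content (e.g., the final chunk), so guard the access.\n",
    "    if chunk.choices and chunk.choices[0].delta.content:\n",
    "        chunks.append(chunk.choices[0].delta.content)\n",
    "\n",
    "print_highlight(\"Streamed response: \" + \"\".join(chunks))"
   ]
  },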
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### EAGLE Decoding with `torch.compile`\n",
    "\n",
    "You can also enable `torch.compile` for further optimizations and optionally set `--cuda-graph-max-bs`:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_process = execute_shell_command(\n",
    "    \"\"\"\n",
    "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algo EAGLE \\\n",
    "    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
    "        --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n",
    "            --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "wait_for_server(\"http://localhost:30020\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark Script\n",
    "\n",
    "The following code example shows how to measure the decoding speed when generating tokens:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import requests\n",
    "\n",
    "tic = time.time()\n",
    "response = requests.post(\n",
    "    \"http://localhost:30020/generate\",\n",
    "    json={\n",
    "        \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. [/INST]\",\n",
    "        \"sampling_params\": {\n",
    "            \"temperature\": 0,\n",
    "            \"max_new_tokens\": 256,\n",
    "        },\n",
    "    },\n",
    ")\n",
    "latency = time.time() - tic\n",
    "ret = response.json()\n",
    "completion_text = ret[\"text\"]\n",
    "speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n",
    "\n",
    "print_highlight(completion_text)\n",
    "print_highlight(f\"speed: {speed:.2f} token/s\")"
   ]
  },
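  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional follow-up, the cell below inspects the `meta_info` returned with the response and compares the measured speed against the ~156 token/s standard-decoding figure quoted earlier. The exact fields in `meta_info` may vary across SGLang versions, and the comparison is only a rough indication.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the metadata returned by the server (field names may vary by SGLang version).\n",
    "print_highlight(ret[\"meta_info\"])\n",
    "\n",
    "# Rough comparison against the ~156 token/s standard-decoding figure quoted above.\n",
    "print_highlight(f\"approx. speedup vs. standard decoding: {speed / 156:.2f}x\")"
   ]
  },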
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}