Unverified Commit 7fbab730 authored by Zhousx's avatar Zhousx Committed by GitHub
Browse files

[feat] add small vocab table for eagle's draft model[1]. (#3822)


Co-authored-by: default avatarAchazwl <323163497@qq.com>
Co-authored-by: default avatarChayenne <zhaochen20@outlook.com>
parent b7e274f2
......@@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `speculative_num_steps`: How many draft passes we run before verifying.
* `speculative_num_draft_tokens`: The number of tokens proposed in a draft.
* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1).
## Double Sparsity
......
......@@ -26,7 +26,7 @@
"source": [
"## EAGLE Decoding\n",
"\n",
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:"
]
},
{
......@@ -46,8 +46,8 @@
"\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
"\"\"\"\n",
")\n",
......@@ -103,8 +103,8 @@
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
" --enable-torch-compile --cuda-graph-max-bs 2\n",
"\"\"\"\n",
......@@ -135,6 +135,77 @@
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n",
"\n",
"By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, checkout [the paper](https://arxiv.org/pdf/arXiv:2502.14856).\n",
"\n",
"In our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency token in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency token by directly downloading these token from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n",
"\n",
"Thanks for the contribution from [Weilin Zhao](https://github.com/https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
......
......@@ -117,9 +117,14 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
if hasattr(config, "hot_vocab_size"):
self.lm_head = ParallelLMHead(
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
)
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -128,6 +128,7 @@ class ServerArgs:
speculative_num_steps: int = 5
speculative_eagle_topk: int = 8
speculative_num_draft_tokens: int = 64
speculative_token_map: Optional[str] = None
# Double Sparsity
enable_double_sparsity: bool = False
......@@ -751,6 +752,12 @@ class ServerArgs:
help="The number of token sampled from draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens,
)
parser.add_argument(
"--speculative-token-map",
type=str,
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)
# Double Sparsity
parser.add_argument(
......
import logging
import os
import time
from typing import List, Optional, Union
import torch
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
......@@ -44,6 +46,23 @@ class EAGLEWorker(TpModelWorker):
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
if server_args.speculative_token_map is not None:
if os.path.exists(server_args.speculative_token_map):
self.hot_token_id = torch.load(server_args.speculative_token_map)
else:
cache_dir = snapshot_download(
os.path.dirname(server_args.speculative_token_map),
ignore_patterns=["*.bin", "*.safetensors"],
)
file_path = os.path.join(
cache_dir, os.path.basename(server_args.speculative_token_map)
)
self.hot_token_id = torch.load(file_path)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
......@@ -66,7 +85,21 @@ class EAGLEWorker(TpModelWorker):
# Share the embedding and lm_head
if not self.speculative_algorithm.is_nextn():
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if server_args.speculative_token_map is not None:
head = head.clone()
self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device
)
head.data = head.data[self.hot_token_id]
else:
self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head)
else:
if server_args.speculative_token_map is not None:
raise NotImplementedError(
"NEXTN does not support speculative-token-map now"
)
self.hot_token_id = None
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners
......@@ -223,6 +256,8 @@ class EAGLEWorker(TpModelWorker):
spec_info.topk_index,
spec_info.hidden_states,
)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
# Return values
score_list: List[torch.Tensor] = []
......@@ -262,6 +297,8 @@ class EAGLEWorker(TpModelWorker):
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
hidden_states = logits_output.hidden_states
return score_list, token_list, parents_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment