Unverified Commit 7fbab730 authored by Zhousx, committed by GitHub

[feat] add small vocab table for eagle's draft model[1]. (#3822)


Co-authored-by: Achazwl <323163497@qq.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent b7e274f2
@@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `speculative_num_steps`: How many draft passes we run before verifying.
* `speculative_num_draft_tokens`: The number of tokens proposed in a draft.
* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1).
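For illustration, here is a minimal launch sketch that combines these flags. It mirrors the notebook example added in this commit; the model names and the `freq_32768.pt` token map are taken from that example and are not required values.

```python
# A minimal sketch using the sglang.utils helpers shown in the EAGLE notebook.
from sglang.utils import launch_server_cmd, wait_for_server

server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct \
    --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B \
    --speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 \
    --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt
"""
)
wait_for_server(f"http://localhost:{port}")
```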
## Double Sparsity
......
@@ -26,7 +26,7 @@
"source": [
"## EAGLE Decoding\n",
"\n",
-"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
+"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:"
]
},
{
@@ -46,8 +46,8 @@
"\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
-"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
-" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
+"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
+" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
"\"\"\"\n",
")\n",
@@ -103,8 +103,8 @@
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
-"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
-" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
+"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
+" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
" --enable-torch-compile --cuda-graph-max-bs 2\n",
"\"\"\"\n",
@@ -135,6 +135,77 @@
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n",
"\n",
"By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, checkout [the paper](https://arxiv.org/pdf/arXiv:2502.14856).\n",
"\n",
"In our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency token in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency token by directly downloading these token from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n",
"\n",
"Thanks for the contribution from [Weilin Zhao](https://github.com/https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx). "
]
},
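As a side note, the index remapping this optimization relies on can be sketched in a few lines of standalone PyTorch (toy values, not the sglang implementation): the draft model's `lm_head` scores only the hot vocabulary, and the resulting top-k indices are translated back to full-vocabulary token ids before verification, mirroring the `topk_index = self.hot_token_id[topk_index]` lines in the worker diff further below.

```python
import torch

# Toy sketch of the FR-Spec remapping; all values are hypothetical.
# hot_token_id holds the full-vocabulary ids of the high-frequency tokens.
hot_token_id = torch.tensor([5, 17, 42, 99, 1234, 8000])

# The draft model's lm_head is sized to the hot vocabulary, so its logits
# have one entry per hot token instead of one per full-vocabulary token.
draft_logits = torch.randn(1, hot_token_id.numel())
probs = torch.softmax(draft_logits, dim=-1)

# Top-k over the hot vocabulary, then translate back to full-vocabulary ids.
topk_p, topk_index = torch.topk(probs, k=3, dim=-1)
full_vocab_ids = hot_token_id[topk_index]
print(full_vocab_ids)  # token ids drawn from hot_token_id
```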
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
......
@@ -116,6 +116,11 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
# Llama 3.1 8B Instruct set tie_word_embeddings to False
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
if hasattr(config, "hot_vocab_size"):
self.lm_head = ParallelLMHead(
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
)
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
......
@@ -128,6 +128,7 @@ class ServerArgs:
speculative_num_steps: int = 5
speculative_eagle_topk: int = 8
speculative_num_draft_tokens: int = 64
speculative_token_map: Optional[str] = None
# Double Sparsity
enable_double_sparsity: bool = False
@@ -751,6 +752,12 @@ class ServerArgs:
help="The number of token sampled from draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens,
)
parser.add_argument(
"--speculative-token-map",
type=str,
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)
# Double Sparsity
parser.add_argument(
......
import logging
import os
import time
from typing import List, Optional, Union
import torch
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -44,6 +46,23 @@ class EAGLEWorker(TpModelWorker):
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
if server_args.speculative_token_map is not None:
if os.path.exists(server_args.speculative_token_map):
self.hot_token_id = torch.load(server_args.speculative_token_map)
else:
cache_dir = snapshot_download(
os.path.dirname(server_args.speculative_token_map),
ignore_patterns=["*.bin", "*.safetensors"],
)
file_path = os.path.join(
cache_dir, os.path.basename(server_args.speculative_token_map)
)
self.hot_token_id = torch.load(file_path)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
@@ -66,7 +85,21 @@ class EAGLEWorker(TpModelWorker):
# Share the embedding and lm_head
if not self.speculative_algorithm.is_nextn():
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if server_args.speculative_token_map is not None:
head = head.clone()
self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device
)
head.data = head.data[self.hot_token_id]
else:
self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head)
else:
if server_args.speculative_token_map is not None:
raise NotImplementedError(
"NEXTN does not support speculative-token-map now"
)
self.hot_token_id = None
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners
@@ -223,6 +256,8 @@
spec_info.topk_index,
spec_info.hidden_states,
)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
# Return values
score_list: List[torch.Tensor] = []
@@ -262,6 +297,8 @@
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
hidden_states = logits_output.hidden_states
return score_list, token_list, parents_list
......
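For completeness: in the worker code above, the file passed via `--speculative-token-map` is read with `torch.load` and then used as a 1-D index of token ids. A compatible table could therefore be produced with a sketch like the one below; the frequency-counting step and file name are illustrative rather than the FR-Spec procedure, and the published `freq_*.pt` files can be used as-is.

```python
import torch
from collections import Counter

# Illustrative sketch: build a small hot-vocabulary table compatible with the
# torch.load(...) call in EAGLEWorker. `token_id_stream` stands in for token
# ids observed on a representative corpus (hypothetical data).
token_id_stream = [1, 5, 5, 17, 42, 42, 42, 99, 1234, 5, 17]

counts = Counter(token_id_stream)
hot_size = 4  # the published FR-Spec tables use much larger sizes, e.g. 32768
hot_token_id = [token for token, _ in counts.most_common(hot_size)]

# The worker loads this file and converts it to an int32 index tensor.
torch.save(torch.tensor(hot_token_id, dtype=torch.int32), "freq_demo.pt")
```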