Unverified commit 815dce05 authored by yukavio, committed by GitHub

Eagle speculative decoding part 4: Add EAGLE2 worker (#2150)


Co-authored-by: kavioyu <kavioyu@tencent.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent ad20b795
import sglang as sgl
def main():
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = {"temperature": 0, "max_new_tokens": 30}
# Create an LLM.
llm = sgl.Engine(
model_path="meta-llama/Llama-2-7b-chat-hf",
speculative_algorithm="EAGLE",
speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
speculative_num_steps=3,
speculative_eagle_topk=4,
speculative_num_draft_tokens=16,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
# The __main__ guard is necessary here because we use "spawn" to create subprocesses.
# Spawn starts a fresh program every time; without the guard, each subprocess would
# re-run sgl.Engine and keep spawning new processes in an infinite loop.
if __name__ == "__main__":
main()
import cutex
import torch
# parent_list      [bs, topk * depth + 1]
# selected_index   [bs, draft_token_num - 1]
# verified_seq_len [bs]
# tree_mask        [draft_token * (seq_len[0] + draft_token) | draft_token * (seq_len[1] + draft_token) | ...]
#                  = [sum(verified_seq_len) * draft_token + bs * draft_token * draft_token]
# positions        [bs * draft_token]
# retrive_index    [bs, draft_token, depth + 2]
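#
# Concrete sizes, taken from the self-test at the bottom of this file (illustrative
# only): with bs=1, verified_seq_len=[47], topk=10, depth=5, draft_token=64,
#   tree_mask has 47 * 64 + 1 * 64 * 64 = 7104 booleans,
#   parent_list has 10 * 5 + 1 = 51 entries per request,
#   retrive_index has shape [1, 64, 7].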
kernels = cutex.SourceModule(
"""
//cuda
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num){
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for(int i=0; i<bid; i++){
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
for(int i=0; i<draft_token_num-1; i++){
tree_mask[token_tree_idx+i] = false;
}
int position = 0;
if (tid==0){
positions[bid*draft_token_num] = seq_len;
retrive_index[bid][0][0] = bid * draft_token_num;
return;
}
int depends_order[10];
int cur_position = tid-1;
while(true){
depends_order[position] = cur_position+1;
position += 1;
tree_mask[token_tree_idx+cur_position] = true;
int parent_tb_idx = selected_index[bid][cur_position]/topk;
if(parent_tb_idx==0){
break;
}
int token_idx = parent_list[bid][parent_tb_idx];
for(cur_position=0; cur_position<draft_token_num;cur_position++){
if(selected_index[bid][cur_position]==token_idx){
break;
}
}
}
positions[bid*draft_token_num+tid] = position + seq_len;
int is_leaf = 0;
for(int i=1;i<draft_token_num;i++){
if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
{
is_leaf ++;
}
}
if(is_leaf==1){
for(int i=0; i<position; i++){
retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
}
retrive_index[bid][tid][0] = bid*draft_token_num;
}
}
//!cuda
""",
float_bits=16, # use 16-bit half precision as the `float` type in the CUDA source above.
boundscheck=True, # turn on for debugging, off for performance (to use all threads of a block); default is on.
)
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
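"""Build the attention metadata for EAGLE tree verification on the GPU.

Launches `build_tree` with one block per request and one thread per draft token
(the block size of 64 caps draft_token at 64). Returns:
  tree_mask       flattened boolean mask; per request, draft_token rows of
                  width (seq_len + draft_token).
  positions       absolute position of every draft token, [bs * draft_token].
  retrive_index   for each leaf of the draft tree, the chain of draft-token
                  indices from the root down to that leaf; rows that were never
                  written stay all -1 and are filtered out below.
  retrive_cum_len cumulative count of retained rows per request, with a leading zero.
"""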
bs = seq_lens.numel()
device = parent_list.device
tree_mask = torch.full(
(torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
True,
device=device,
)
retrive_index = torch.full(
(bs, draft_token, depth + 2), -1, device=device, dtype=torch.long
)
positions = torch.empty((bs * draft_token,), device=device, dtype=torch.long)
kernels.build_tree(
parent_list,
top_score_index,
seq_lens.to(torch.int32),
tree_mask,
positions,
retrive_index,
topk,
depth,
draft_token,
grid=(bs, 1, 1),
block=(64, 1, 1),
)
index = retrive_index.sum(dim=-1) != -depth - 2
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
retrive_cum_len = torch.zeros(
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
)
retrive_cum_len[1:] = cum_len
retrive_index = retrive_index[index]
return tree_mask, positions, retrive_index, retrive_cum_len
if __name__ == "__main__":
def findp(p_i, index, parent_list):
pos = index // 10
index_list = index.tolist()
parent_list = parent_list.tolist()
res = [p_i]
while True:
p = pos[p_i]
if p == 0:
break
token_idx = parent_list[p]
p_i = index_list.index(token_idx)
res.append(p_i)
return res
def create_mask(seq_len, draft_token, index, parent_list, max_depth):
mask = []
positions = []
retrive_index = []
for i, lens in enumerate(seq_len.tolist()):
first_mask = torch.full((lens + draft_token,), True)
first_mask[-(draft_token - 1) :] = False
positions.append(lens)
mask.append(first_mask)
seq_order = []
first_index = torch.Tensor([0] + [-1] * (max_depth + 1)).cuda().to(torch.long)
r_index = [first_index]
for j in range(draft_token - 1):
mask.append(torch.full((lens + 1,), True))
idx = findp(j, index, parent_list)
seq_order.append(idx)
positions.append(len(idx) + lens)
t = torch.full((draft_token - 1,), False)
t[idx] = True
mask.append(t)
for i in range(1, draft_token - 1):
is_leaf = 0
for j in range(draft_token - 1):
if i in seq_order[j]:
is_leaf += 1
if is_leaf == 1:
order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
for _ in range(max_depth + 1 - len(seq_order[i])):
order_list.append(-1)
order = torch.Tensor(order_list).cuda().to(torch.long)
r_index.append(order)
retrive_index.append(torch.stack(r_index))
return (
torch.cat(mask).cuda(),
torch.Tensor(positions).cuda().to(torch.long),
torch.stack(retrive_index),
)
index = (
torch.Tensor(
[
0,
1,
2,
3,
10,
11,
12,
13,
20,
21,
22,
30,
110,
130,
150,
160,
210,
211,
212,
213,
214,
215,
216,
217,
218,
219,
220,
230,
310,
311,
312,
313,
314,
315,
316,
317,
320,
321,
322,
330,
360,
380,
390,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
423,
430,
431,
440,
441,
460,
470,
]
)
.to(torch.long)
.cuda()
)
parent_list = (
torch.Tensor(
[
-1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
20,
30,
21,
13,
22,
40,
23,
110,
130,
160,
150,
190,
120,
111,
121,
200,
180,
210,
211,
212,
213,
214,
215,
216,
220,
230,
217,
310,
311,
312,
313,
320,
314,
321,
315,
316,
317,
]
)
.to(torch.long)
.cuda()
)
verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
bs = verified_seq_len.shape[0]
topk = 10
depth = 5 # depth <= 10
draft_token = 64
tree_mask = torch.full(
(
torch.sum(verified_seq_len).item() * draft_token
+ draft_token * draft_token * bs,
),
True,
).cuda()
retrive_index = torch.full(
(bs, draft_token, depth + 2), -1, device="cuda", dtype=torch.long
)
positions = torch.empty((bs * draft_token,), device="cuda", dtype=torch.long)
kernels.build_tree(
parent_list.unsqueeze(0),
index.unsqueeze(0),
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
draft_token,
grid=(bs, 1, 1),
block=(64, 1, 1),
)
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
c_mask, c_positions, c_retive_index = create_mask(
verified_seq_len, draft_token, index, parent_list, depth
)
assert torch.allclose(tree_mask, c_mask), "tree_mask mismatch."
assert torch.allclose(positions, c_positions), "positions mismatch."
assert torch.allclose(retrive_index, c_retive_index), "retrive_index mismatch."
from typing import List, Optional, Union
import torch
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
class EAGLEWorker(TpModelWorker):
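"""A TpModelWorker that runs the EAGLE draft model alongside a target worker.

For every decode step the draft model proposes a token tree; the target worker
verifies the whole tree in one forward pass and only accepted tokens are kept.
The draft model shares the target model's embedding and lm_head weights, and
its CUDA graphs are captured only after that sharing is done (see __init__).
"""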
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
target_worker: TpModelWorker,
):
# Do not capture cuda graph in `super().__init__()`
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=nccl_port,
dp_rank=dp_rank,
is_draft_worker=True,
)
self.target_worker = target_worker
self.server_args = server_args
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.model_runner.init_cuda_graphs()
def forward_draft_decode(self, batch: ScheduleBatch):
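"""Run one decode step of the draft model and capture its sampling
distribution and hidden states for the next draft step."""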
batch.spec_info.prepare_for_decode(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
def forward_draft_extend(self, batch: ScheduleBatch):
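"""Run the draft model in extend (prefill) mode right after the target
model's prefill, using the draft model's own memory pools, to initialize
the draft state for subsequent decode steps."""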
self._swap_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._swap_mem_pool(batch, self.target_worker.model_runner)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
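"""Run one step of speculative generation.

Decode: run `speculative_num_steps` draft decode steps on the draft model's
memory pools, verify the proposed token tree with the target model, then
(unless every request has finished) run a draft extend pass over the newly
verified tokens so the next round can start from them.

Extend (prefill): run the normal target-model prefill first, then feed its
hidden states and sampled tokens to the draft model via forward_draft_extend.
"""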
if batch.forward_mode.is_decode():
prev_spec_info = batch.spec_info
self._swap_mem_pool(batch, self.model_runner)
for i in range(self.server_args.speculative_num_steps):
self.forward_draft_decode(batch)
batch.spec_info.clear_draft_cache(batch)
self._swap_mem_pool(batch, self.target_worker.model_runner)
(
next_draft_input,
logits_output,
verified_id,
self.finish_extend_len,
model_worker_batch,
) = self.verify(batch)
next_draft_input.init(self.server_args)
batch.spec_info = next_draft_input
# If verified_id is None, all requests have finished.
if batch.spec_info.verified_id is not None:
self.forward_extend_after_decode(batch)
batch.spec_info = prev_spec_info
return logits_output, verified_id, model_worker_batch, next_draft_input
else:
spec_info = EAGLEDraftInput()
spec_info.init(self.server_args)
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_info = spec_info
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
model_worker_batch.spec_info.verified_id = next_token_ids
model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
batch.spec_info = spec_info
self.forward_draft_extend(batch)
batch.spec_info = None
return logits_output, next_token_ids, model_worker_batch, spec_info
def verify(self, batch: ScheduleBatch):
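"""Verify the draft token tree with the target model.

Switches the batch to TARGET_VERIFY mode, runs the target model once over
the draft tokens without sampling, and lets spec_info decide which draft
tokens are accepted. Restores DECODE mode before returning.
"""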
verify_input = batch.spec_info.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
verify_input.prepare_for_verify(batch)
batch.spec_info = verify_input
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch = batch.get_model_worker_batch()
logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
verify_input.hidden_states = logits_output.hidden_states
res = verify_input.verify(batch, logits_output)
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
def forward_extend_after_decode(self, batch: ScheduleBatch):
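"""Run the draft model in DRAFT_EXTEND mode over the freshly verified tokens
(excluding requests that have already finished) so its hidden states are
ready for the next decode round."""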
self._swap_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
if batch.spec_info.has_finished:
index = batch.spec_info.unfinished_index
seq_lens = batch.seq_lens
batch.seq_lens = batch.seq_lens[index]
batch.spec_info.prepare_extend_after_decode(batch)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
batch.spec_info.hidden_states = logits_output.hidden_states
self.capture_for_decode(logits_output, forward_batch)
batch.forward_mode = ForwardMode.DECODE
if batch.spec_info.has_finished:
batch.seq_lens = seq_lens
self._swap_mem_pool(batch, self.target_worker.model_runner)
def capture_for_decode(self, logits_output, forward_batch):
if isinstance(logits_output, LogitsProcessorOutput):
logits = logits_output.next_token_logits
sample_output = torch.softmax(
logits, dim=-1
) # TODO: Support more sampling method @kavioyu
forward_batch.spec_info.capture_for_decode(
sample_output, logits_output.hidden_states, forward_batch.forward_mode
)
# Prefix sharing is not supported yet.
def finish_request(self, reqs: Union[Req, List[Req]]):
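"""Free the draft model's KV cache entries for finished requests.

The freed length is len(origin_input_ids) + len(output_ids) minus
finish_extend_len[rid] (the tokens accepted in the final verify step, which
were never extended into the draft KV cache) minus one.
"""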
if not isinstance(reqs, List):
reqs = [reqs]
for req in reqs:
req_len = (
len(req.origin_input_ids)
+ len(req.output_ids)
- self.finish_extend_len[req.rid]
- 1
)
kv_indices = self.model_runner.req_to_token_pool.req_to_token[
req.req_pool_idx
][:req_len]
self.model_runner.token_to_kv_pool.free(kv_indices)
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
@@ -13,6 +13,7 @@ suites = {
"test_abort.py",
"test_chunked_prefill.py",
"test_double_sparsity.py",
"test_eagle_infer.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_get_weights_by_name.py",
...
import unittest
import sglang as sgl
class TestEAGLEEngine(unittest.TestCase):
def test_eagle_accuracy(self):
prompt = "Today is a sunny day and I like"
target_model_path = "meta-llama/Llama-2-7b-chat-hf"
speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
sampling_params = {"temperature": 0, "max_new_tokens": 8}
engine = sgl.Engine(
model_path=target_model_path,
speculative_draft_model_path=speculative_draft_model_path,
speculative_algorithm="EAGLE",
speculative_num_steps=3,
speculative_eagle_topk=4,
speculative_num_draft_tokens=16,
)
out1 = engine.generate(prompt, sampling_params)["text"]
engine.shutdown()
engine = sgl.Engine(model_path=target_model_path)
out2 = engine.generate(prompt, sampling_params)["text"]
engine.shutdown()
print("==== Answer 1 ====")
print(out1)
print("==== Answer 2 ====")
print(out2)
self.assertEqual(out1, out2)
if __name__ == "__main__":
unittest.main()