Unverified Commit 1b859295 authored by Ying Sheng, committed by GitHub

[Eagle] Remove the greedy branch and some redundant code (#4363)


Co-authored-by: Sehoon Kim <sehoon@x.ai>
parent 9971dc22
...@@ -43,7 +43,7 @@ runtime_common = [ ...@@ -43,7 +43,7 @@ runtime_common = [
srt = [ srt = [
"sglang[runtime_common]", "sglang[runtime_common]",
"sgl-kernel==0.0.5.post1", "sgl-kernel==0.0.5.post2",
"flashinfer_python==0.2.3", "flashinfer_python==0.2.3",
"torch==2.5.1", "torch==2.5.1",
"vllm>=0.6.4.post1,<=0.7.2", "vllm>=0.6.4.post1,<=0.7.2",
......
...@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): ...@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e) return _create_error_response(e)
@app.post("/flush_cache") @app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache(): async def flush_cache():
"""Flush the radix cache.""" """Flush the radix cache."""
_global_state.tokenizer_manager.flush_cache() _global_state.tokenizer_manager.flush_cache()
......
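After the `@app.api_route("/flush_cache", methods=["GET", "POST"])` change above, the cache-flush endpoint accepts both verbs. A minimal client sketch, assuming a locally running server (the host and port below are placeholders, not part of this diff):

```python
import requests

# Placeholder address; substitute whatever host/port your sglang server listens on.
BASE_URL = "http://127.0.0.1:30000"

# Either verb works once the route is registered with methods=["GET", "POST"].
resp_get = requests.get(f"{BASE_URL}/flush_cache")
resp_post = requests.post(f"{BASE_URL}/flush_cache")
print(resp_get.status_code, resp_post.status_code)
```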
...@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"#token: {num_used}, " f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"largest-len: {self._largest_prefill_decode_len}, "
f"#queue-req: {len(self.waiting_queue)}, " f"#queue-req: {len(self.waiting_queue)}, "
) )
spec_accept_length = 0 spec_accept_length = 0
...@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"accept len: {spec_accept_length:.2f}, " f"accept len: {spec_accept_length:.2f}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"largest-len: {self._largest_prefill_decode_len}, "
f"#queue-req: {len(self.waiting_queue)}, " f"#queue-req: {len(self.waiting_queue)}, "
) )
......
...@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ...@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
else: else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
else: else:
capture_bs = list(range(1, 33)) # Since speculative decoding requires more cuda graph memory, we
# capture less.
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
if _is_hip: if _is_hip:
capture_bs += [i * 8 for i in range(21, 33)] capture_bs += [i * 8 for i in range(21, 33)]
...@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ...@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
if max(capture_bs) > model_runner.req_to_token_pool.size: if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs. # is very small. We add more values here to make sure we capture the maximum bs.
capture_bs = list( capture_bs += [model_runner.req_to_token_pool.size - 1] + [
sorted( model_runner.req_to_token_pool.size
set( ]
capture_bs
+ [model_runner.req_to_token_pool.size - 1]
+ [model_runner.req_to_token_pool.size]
)
)
)
capture_bs = list(sorted(set(capture_bs)))
capture_bs = [ capture_bs = [
bs bs
for bs in capture_bs for bs in capture_bs
...@@ -508,7 +505,9 @@ class CudaGraphRunner: ...@@ -508,7 +505,9 @@ class CudaGraphRunner:
self.raw_num_token = raw_num_token self.raw_num_token = raw_num_token
self.bs = bs self.bs = bs
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False): def replay(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
) -> LogitsProcessorOutput:
if not skip_attn_backend_init: if not skip_attn_backend_init:
self.replay_prepare(forward_batch) self.replay_prepare(forward_batch)
else: else:
......
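To make the capture-list change above concrete, here is a small standalone sketch that mirrors (but does not import) the new logic in `get_batch_sizes_to_capture`: with speculative decoding enabled, fewer and sparser batch sizes are captured to save cuda graph memory, and the list is deduplicated once at the end. `req_to_token_pool_size` is a made-up value standing in for `model_runner.req_to_token_pool.size`:

```python
# Dense sizes 1..8, every other size up to 32, then a few large buckets.
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

# Hypothetical pool size; the real value comes from model_runner.req_to_token_pool.size.
req_to_token_pool_size = 48

if max(capture_bs) > req_to_token_pool_size:
    # Make sure the maximum runnable batch size itself gets captured.
    capture_bs += [req_to_token_pool_size - 1, req_to_token_pool_size]

# Single dedup-and-sort pass, replacing the old nested list(sorted(set(...))).
capture_bs = sorted(set(capture_bs))
print(capture_bs)
```

The real function then filters this list further (that filter's condition lies outside the hunk shown above).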
...@@ -285,7 +285,6 @@ class ServerArgs: ...@@ -285,7 +285,6 @@ class ServerArgs:
if self.speculative_algorithm == "EAGLE": if self.speculative_algorithm == "EAGLE":
if self.max_running_requests is None: if self.max_running_requests is None:
self.max_running_requests = 32 self.max_running_requests = 32
self.disable_cuda_graph_padding = True
self.disable_overlap_schedule = True self.disable_overlap_schedule = True
logger.info( logger.info(
"Overlap scheduler is disabled because of using " "Overlap scheduler is disabled because of using "
......
...@@ -3,8 +3,13 @@ ...@@ -3,8 +3,13 @@
from typing import List from typing import List
import torch import torch
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient from sglang.srt.utils import is_cuda_available
if is_cuda_available():
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
def build_tree_kernel_efficient_preprocess( def build_tree_kernel_efficient_preprocess(
...@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess( ...@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1) top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
top_scores_index = top_scores.indices top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
...@@ -108,296 +112,6 @@ def build_tree_kernel_efficient( ...@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
) )
def build_tree_kernel(
verified_id: torch.Tensor,
score_list: List[torch.Tensor],
token_list: List[torch.Tensor],
parents_list: List[torch.Tensor],
seq_lens: torch.Tensor,
seq_lens_sum: int,
topk: int,
spec_steps: int,
num_verify_tokens: int,
):
parent_list, top_scores_index, draft_tokens = (
build_tree_kernel_efficient_preprocess(
verified_id,
score_list,
token_list,
parents_list,
num_verify_tokens,
)
)
bs = seq_lens.numel()
device = seq_lens.device
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
retrive_index = torch.full(
(bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
)
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
sgl_build_tree_kernel(
parent_list,
top_scores_index,
seq_lens.to(torch.int32),
tree_mask,
positions,
retrive_index,
topk,
spec_steps,
num_verify_tokens,
)
index = retrive_index.sum(dim=-1) != -spec_steps - 2
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
retrive_cum_len = torch.zeros(
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
)
retrive_cum_len[1:] = cum_len
# TODO: this indexing cause a synchronization, optimize this
retrive_index = retrive_index[index]
return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
def test_build_tree_kernel():
def findp(p_i, index, parent_list):
pos = index // 10
index_list = index.tolist()
parent_list = parent_list.tolist()
res = [p_i]
while True:
p = pos[p_i]
if p == 0:
break
token_idx = parent_list[p]
p_i = index_list.index(token_idx)
res.append(p_i)
return res
def create_mask(seq_len, draft_token, index, parent_list, max_depth):
mask = []
positions = []
retrive_index = []
for i, lens in enumerate(seq_len.tolist()):
first_mask = torch.full((lens + draft_token,), True)
first_mask[-(draft_token - 1) :] = False
positions.append(lens)
mask.append(first_mask)
seq_order = []
first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
r_index = [first_index]
for j in range(draft_token - 1):
mask.append(torch.full((lens + 1,), True))
idx = findp(j, index, parent_list)
seq_order.append(idx)
positions.append(len(idx) + seq_len)
t = torch.full((draft_token - 1,), False)
t[idx] = True
mask.append(t)
for i in range(1, draft_token - 1):
is_leaf = 0
for j in range(draft_token - 1):
if i in seq_order[j]:
is_leaf += 1
if is_leaf == 1:
order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
for _ in range(max_depth + 1 - len(seq_order[i])):
order_list.append(-1)
order = torch.Tensor(order_list).cuda().to(torch.long)
r_index.append(order)
retrive_index.append(torch.stack(r_index))
return (
torch.cat(mask).cuda(),
torch.Tensor(positions).cuda().to(torch.long),
torch.stack(retrive_index),
)
index = (
torch.Tensor(
[
0,
1,
2,
3,
10,
11,
12,
13,
20,
21,
22,
30,
110,
130,
150,
160,
210,
211,
212,
213,
214,
215,
216,
217,
218,
219,
220,
230,
310,
311,
312,
313,
314,
315,
316,
317,
320,
321,
322,
330,
360,
380,
390,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
423,
430,
431,
440,
441,
460,
470,
]
)
.to(torch.long)
.cuda()
)
parent_list = (
torch.Tensor(
[
-1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
20,
30,
21,
13,
22,
40,
23,
110,
130,
160,
150,
190,
120,
111,
121,
200,
180,
210,
211,
212,
213,
214,
215,
216,
220,
230,
217,
310,
311,
312,
313,
320,
314,
321,
315,
316,
317,
]
)
.to(torch.long)
.cuda()
)
verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
bs = verified_seq_len.shape[0]
topk = 10
depth = 5 # depth <= 10
num_draft_token = 64
tree_mask = torch.full(
(
torch.sum(verified_seq_len).item() * num_draft_token
+ num_draft_token * num_draft_token * bs,
),
True,
).cuda()
retrive_index = torch.full(
(bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
)
positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
sgl_build_tree_kernel(
parent_list.unsqueeze(0),
index.unsqueeze(0),
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
num_draft_token,
)
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
c_mask, c_positions, c_retive_index = create_mask(
verified_seq_len, num_draft_token, index, parent_list, depth
)
assert torch.allclose(tree_mask, c_mask), "tree mask has error."
assert torch.allclose(positions, c_positions), "positions has error."
assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
def test_build_tree_kernel_efficient(): def test_build_tree_kernel_efficient():
verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32) verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
score_list = [ score_list = [
...@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient(): ...@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
depth = 4 depth = 4
num_draft_token = 8 num_draft_token = 8
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id=verified_id,
score_list=score_list,
token_list=token_list,
parents_list=parents_list,
seq_lens=seq_lens,
seq_lens_sum=torch.sum(seq_lens).item(),
topk=topk,
spec_steps=depth,
num_verify_tokens=num_draft_token,
)
)
from sglang.srt.utils import first_rank_print
first_rank_print("=========== build tree kernel ==========")
# first_rank_print(f"{tree_mask=}", flush=True)
first_rank_print(f"{position=}", flush=True)
first_rank_print(f"{retrive_index=}", flush=True)
first_rank_print(f"{retrive_cum_len=}", flush=True)
first_rank_print(f"{draft_tokens=}", flush=True)
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
assert retrive_index.tolist() == [
[0, -1, -1, -1, -1, -1],
[0, 2, 4, 6, -1, -1],
[0, 1, 3, 5, 7, -1],
[8, -1, -1, -1, -1, -1],
[8, 9, 10, -1, -1, -1],
[8, 9, 12, -1, -1, -1],
[8, 9, 13, -1, -1, -1],
[8, 9, 11, 14, 15, -1],
]
assert retrive_cum_len.tolist() == [0, 3, 8]
assert draft_tokens.tolist() == [
29974,
29896,
29906,
29889,
29974,
29946,
29896,
29946,
13,
13,
22550,
4136,
16492,
8439,
29871,
29941,
]
( (
tree_mask, tree_mask,
position, position,
...@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient(): ...@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
if __name__ == "__main__": if __name__ == "__main__":
test_build_tree_kernel_efficient() test_build_tree_kernel_efficient()
test_build_tree_kernel()
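The guarded import added near the top of this file wraps the `sgl_kernel` import in an `is_cuda_available()` check so the module still imports on machines without CUDA. A hedged, self-contained sketch of the same pattern, using `torch.cuda.is_available()` and a package check as rough stand-ins for the helper in `sglang.srt.utils`:

```python
import importlib.util

import torch

# Only load the CUDA kernel when both CUDA and the sgl_kernel package are present,
# so CPU-only environments can still import this module. The fallback value is
# illustrative and not part of the diff.
if torch.cuda.is_available() and importlib.util.find_spec("sgl_kernel") is not None:
    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )
else:
    sgl_build_tree_kernel_efficient = None

print("tree kernel available:", sgl_build_tree_kernel_efficient is not None)
```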
...@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput ...@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker from sglang.srt.speculative.eagle_worker import EAGLEWorker
import logging
logger = logging.getLogger(__name__)
class EAGLEDraftCudaGraphRunner: class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker): def __init__(self, eagle_worker: EAGLEWorker):
...@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner: ...@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.topk = model_runner.server_args.speculative_eagle_topk self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps self.speculative_num_steps = model_runner.server_args.speculative_num_steps
server_args = model_runner.server_args server_args = model_runner.server_args
assert self.disable_padding
# Batch sizes to capture # Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.num_tokens_per_bs = server_args.speculative_eagle_topk self.num_tokens_per_bs = server_args.speculative_eagle_topk
...@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner: ...@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
set_global_graph_memory_pool(graph.pool()) set_global_graph_memory_pool(graph.pool())
return graph, out return graph, out
def _postprocess_output_to_raw_bs(self, out, raw_bs):
score_list, token_list, parents_list = out
score_list = [x[:raw_bs] for x in score_list]
token_list = [x[:raw_bs] for x in token_list]
parents_list = [x[:raw_bs] for x in parents_list]
return (score_list, token_list, parents_list)
def replay(self, forward_batch: ForwardBatch): def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size raw_bs = forward_batch.batch_size
...@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner: ...@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.fill_(1) self.seq_lens.fill_(1)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
self.positions.zero_()
num_tokens = bs * self.num_tokens_per_bs
# Common inputs # Common inputs
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
...@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner: ...@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# Attention backend # Attention backend
if bs != raw_bs:
forward_batch.batch_size = bs
forward_batch.seq_lens = self.seq_lens[:bs]
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch, forward_batch.batch_size forward_batch, bs
) )
# Replay # Replay
self.graphs[bs].replay() self.graphs[bs].replay()
out = self.output_buffers[bs]
if bs != raw_bs:
out = self._postprocess_output_to_raw_bs(out, raw_bs)
forward_batch.batch_size = raw_bs
forward_batch.positions = self.positions[:raw_num_token]
forward_batch.seq_lens = self.seq_lens[:raw_bs]
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
return self.output_buffers[bs] return out
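The `_postprocess_output_to_raw_bs` helper added above just slices each per-step output back to the unpadded batch size after a padded cuda-graph replay. A minimal sketch of the same idea with dummy tensors (the shapes are illustrative):

```python
import torch

def postprocess_output_to_raw_bs(out, raw_bs):
    # Drop the rows that only exist because the batch was padded up to the
    # captured cuda-graph batch size.
    score_list, token_list, parents_list = out
    score_list = [x[:raw_bs] for x in score_list]
    token_list = [x[:raw_bs] for x in token_list]
    parents_list = [x[:raw_bs] for x in parents_list]
    return score_list, token_list, parents_list

# Toy example: graph captured at bs=4, but the real batch has 3 requests.
captured_bs, raw_bs, topk = 4, 3, 2
out = (
    [torch.randn(captured_bs, topk)],                    # scores per draft step
    [torch.zeros(captured_bs, topk, dtype=torch.long)],  # draft tokens
    [torch.zeros(captured_bs, topk, dtype=torch.long)],  # parent indices
)
scores, tokens, parents = postprocess_output_to_raw_bs(out, raw_bs)
print(scores[0].shape)  # torch.Size([3, 2])
```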
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -13,18 +13,24 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput ...@@ -13,18 +13,24 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import ( from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
build_tree_kernel,
build_tree_kernel_efficient,
)
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import is_cuda_available
if is_cuda_available(): if is_cuda_available():
from sgl_kernel import tree_speculative_sampling_target_only from sgl_kernel import (
top_k_renorm_prob,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
import logging
logger = logging.getLogger(__name__)
@dataclass @dataclass
class EagleDraftInput: class EagleDraftInput:
...@@ -47,12 +53,9 @@ class EagleDraftInput: ...@@ -47,12 +53,9 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None kv_indices: torch.Tensor = None
# indices of unfinished requests during extend-after-decode all_padding_lens: Optional[torch.Tensor] = None
# e.g. [0, 2, 3, 4] if only the 1st request is finished
keep_indices: List[int] = None
def prepare_for_extend(self, batch: ScheduleBatch): def prepare_for_extend(self, batch: ScheduleBatch):
assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
# Prefill only generate 1 token. # Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens) assert len(self.verified_id) == len(batch.seq_lens)
...@@ -64,27 +67,18 @@ class EagleDraftInput: ...@@ -64,27 +67,18 @@ class EagleDraftInput:
) )
pt += extend_len pt += extend_len
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): def prepare_extend_after_decode(
assert self.verified_id.numel() == batch.out_cache_loc.shape[0] self,
batch: ScheduleBatch,
speculative_num_steps: int,
):
assert len(self.verified_id) == len(batch.out_cache_loc)
accept_length_cpu = batch.spec_info.accept_length_cpu accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens) batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist() seq_lens_cpu = batch.seq_lens.tolist()
assert len(batch.req_pool_indices) == len(batch.reqs)
pt = 0
i = 0
self.keep_indices = []
for idx, req in enumerate(batch.reqs):
if req.finished():
continue
self.keep_indices.append(idx)
# assert seq_len - pre_len == req.extend_input_len
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
pt += input_len
i += 1
self.positions = torch.empty_like(self.verified_id, dtype=torch.long) self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
...@@ -112,10 +106,6 @@ class EagleDraftInput: ...@@ -112,10 +106,6 @@ class EagleDraftInput:
req_to_token: torch.Tensor, req_to_token: torch.Tensor,
): ):
bs = self.accept_length.numel() bs = self.accept_length.numel()
keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
req_pool_indices = req_pool_indices[keep_indices]
assert req_pool_indices.shape[0] == bs
assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
...@@ -172,7 +162,7 @@ class EagleVerifyOutput: ...@@ -172,7 +162,7 @@ class EagleVerifyOutput:
# Accepeted token length per sequence in a batch in CPU. # Accepeted token length per sequence in a batch in CPU.
accept_length_per_req_cpu: List[int] accept_length_per_req_cpu: List[int]
# Accepeted indices from logits_output.next_token_logits # Accepeted indices from logits_output.next_token_logits
accepeted_indices_cpu: List[int] accepeted_indices: torch.Tensor
@dataclass @dataclass
...@@ -200,36 +190,7 @@ class EagleVerifyInput: ...@@ -200,36 +190,7 @@ class EagleVerifyInput:
topk: int, topk: int,
spec_steps: int, spec_steps: int,
num_verify_tokens: int, num_verify_tokens: int,
is_all_greedy: bool,
): ):
if is_all_greedy:
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id,
score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk
token_list,
parents_list,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_tokens,
)
)
return cls(
draft_tokens,
tree_mask,
position,
retrive_index,
None,
None,
retrive_cum_len,
num_verify_tokens,
spec_steps,
CaptureHiddenMode.FULL,
)
else:
( (
tree_mask, tree_mask,
position, position,
...@@ -291,7 +252,6 @@ class EagleVerifyInput: ...@@ -291,7 +252,6 @@ class EagleVerifyInput:
dtype=torch.int32, dtype=torch.int32,
device="cuda", device="cuda",
) )
cum_kv_seq_len = torch.zeros( cum_kv_seq_len = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda" (batch_size + 1,), dtype=torch.int32, device="cuda"
) )
...@@ -304,7 +264,6 @@ class EagleVerifyInput: ...@@ -304,7 +264,6 @@ class EagleVerifyInput:
dtype=torch.int32, dtype=torch.int32,
device="cuda", device="cuda",
) )
create_flashinfer_kv_indices_triton[(batch_size,)]( create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token, req_to_token,
req_pool_indices, req_pool_indices,
...@@ -322,65 +281,79 @@ class EagleVerifyInput: ...@@ -322,65 +281,79 @@ class EagleVerifyInput:
logits_output: torch.Tensor, logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
) -> torch.Tensor: ) -> torch.Tensor:
"""WARNING: This API in-place modifies the states of logits_output """
Verify and find accepted tokens based on logits output and batch Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information). (which contains spec decoding information).
WARNING: This API in-place modifies the states of logits_output
This API updates values inside logits_output based on the accepted This API updates values inside logits_output based on the accepted
tokens. I.e., logits_output.next_token_logits only contains tokens. I.e., logits_output.next_token_logits only contains
accepeted token logits. accepeted token logits.
""" """
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
dim=-1,
)
candidates = draft_token[self.retrive_index]
if batch.sampling_info.is_all_greedy:
# temp == 0
bs = self.retrive_cum_len.numel() - 1
predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict = torch.cat(
[predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
)
target_predict = predict[self.retrive_index]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask = candidates[:, 1:] == target_predict[:, :-1]
accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
max_draft_len = self.retrive_index.shape[-1]
accept_index = torch.full(
(bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
)
accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
eagle_verify_retrive[(bs,)](
self.retrive_index.contiguous(),
accept_mask.contiguous(),
self.retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_draft_len,
self.draft_token_num,
triton.next_power_of_2(max_draft_len),
)
else:
# temp > 0
bs = self.retrive_index.shape[0] bs = self.retrive_index.shape[0]
candidates = self.draft_token.reshape(bs, self.draft_token_num)
sampling_info = batch.sampling_info
predict_shape = list(logits_output.next_token_logits.shape)[:-1] predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1 predict_shape[-1] += 1
target_logits = logits_output.next_token_logits[self.retrive_index] predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
accept_index = torch.full( accept_index = torch.full(
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda" (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
) )
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) if sampling_info.penalizer_orchestrator.is_required:
draft_probs = torch.full_like( # This is a relaxed version of penalties for speculative decoding.
target_probs, 0, dtype=torch.float32, device="cuda" linear_penalty = torch.zeros(
(bs, logits_output.next_token_logits.shape[1]),
dtype=torch.float32,
device="cuda",
)
sampling_info.apply_logits_bias(linear_penalty)
logits_output.next_token_logits.add_(
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
)
if batch.sampling_info.is_all_greedy:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
verify_tree_greedy(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates.to(torch.int32),
retrive_index=self.retrive_index.to(torch.int32),
retrive_next_token=self.retrive_next_token.to(torch.int32),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
target_predict=target_predict.to(torch.int32),
)
else:
# apply temperature and get target probs
expanded_temperature = torch.repeat_interleave(
sampling_info.temperatures, self.draft_token_num, dim=0
) # (bs * draft_token_num, 1)
target_probs = F.softmax(
logits_output.next_token_logits / expanded_temperature, dim=-1
) # (bs * draft_token_num, vocab_size)
target_probs = top_k_renorm_prob(
target_probs,
torch.repeat_interleave(
sampling_info.top_ks, self.draft_token_num, dim=0
),
) # (bs * draft_token_num, vocab_size)
target_probs = top_p_renorm_prob(
target_probs,
torch.repeat_interleave(
sampling_info.top_ps, self.draft_token_num, dim=0
),
)
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
draft_probs = torch.zeros(
target_probs.shape, dtype=torch.float32, device="cuda"
) )
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda") coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
tree_speculative_sampling_target_only( tree_speculative_sampling_target_only(
...@@ -394,6 +367,12 @@ class EagleVerifyInput: ...@@ -394,6 +367,12 @@ class EagleVerifyInput:
uniform_samples=coins, uniform_samples=coins,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )
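As a shape-level illustration of the non-greedy verification path above: the target logits come in with one row per draft token (`bs * draft_token_num` rows), and every draft token of a request shares that request's temperature / top-k / top-p, hence the `repeat_interleave` calls. The sketch below uses a naive per-row top-k renormalization in place of the `sgl_kernel` `top_k_renorm_prob` / `top_p_renorm_prob` kernels, so it mirrors the intent rather than the exact kernels:

```python
import torch
import torch.nn.functional as F

def renorm_top_k(probs: torch.Tensor, top_ks: torch.Tensor) -> torch.Tensor:
    # Naive stand-in for sgl_kernel.top_k_renorm_prob: zero everything outside
    # each row's top-k and renormalize. top_ks holds one k per row.
    out = torch.zeros_like(probs)
    for i, k in enumerate(top_ks.tolist()):
        vals, idx = torch.topk(probs[i], int(k))
        out[i, idx] = vals
    return out / out.sum(dim=-1, keepdim=True)

bs, draft_token_num, vocab = 2, 4, 32  # toy sizes
next_token_logits = torch.randn(bs * draft_token_num, vocab)

# One temperature / top-k per request, repeated for each of its draft tokens.
temperatures = torch.tensor([[0.7], [1.0]])
top_ks = torch.tensor([5, 10])

expanded_temperature = torch.repeat_interleave(temperatures, draft_token_num, dim=0)
target_probs = F.softmax(next_token_logits / expanded_temperature, dim=-1)
target_probs = renorm_top_k(
    target_probs, torch.repeat_interleave(top_ks, draft_token_num, dim=0)
)
target_probs = target_probs.reshape(bs, draft_token_num, -1)
print(target_probs.shape)  # torch.Size([2, 4, 32])
```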
...@@ -425,10 +404,45 @@ class EagleVerifyInput: ...@@ -425,10 +404,45 @@ class EagleVerifyInput:
new_accept_index.extend(new_accept_index_) new_accept_index.extend(new_accept_index_)
unfinished_index.append(i) unfinished_index.append(i)
req.spec_verify_ct += 1 req.spec_verify_ct += 1
accept_length = (accept_index != -1).sum(dim=1) - 1
if not has_finished:
accept_index = accept_index[accept_index != -1] accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
token_to_kv_pool_allocator.free(mem_need_free_idx)
batch.out_cache_loc = batch.out_cache_loc[accept_index]
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist() accept_length_cpu = accept_length.tolist()
draft_input = EagleDraftInput()
draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
draft_input.verified_id = verified_id
draft_input.accept_length = accept_length
draft_input.accept_length_cpu = accept_length_cpu
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
return EagleVerifyOutput(
draft_input=draft_input,
logits_output=logits_output,
verified_id=verified_id,
accept_length_per_req_cpu=accept_length_cpu,
accepeted_indices=accept_index,
)
else:
accept_length = (accept_index != -1).sum(dim=1) - 1
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index] verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False evict_mask[accept_index] = False
...@@ -444,20 +458,31 @@ class EagleVerifyInput: ...@@ -444,20 +458,31 @@ class EagleVerifyInput:
triton.next_power_of_2(bs), triton.next_power_of_2(bs),
) )
batch.seq_lens.add_(accept_length + 1) batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist()
draft_input = EagleDraftInput() draft_input = EagleDraftInput()
if len(new_accept_index) > 0: if len(new_accept_index) > 0:
new_accept_index = torch.tensor(new_accept_index, device="cuda") new_accept_index = torch.tensor(new_accept_index, device="cuda")
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[
new_accept_index
]
draft_input.verified_id = predict[new_accept_index] draft_input.verified_id = predict[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index] draft_input.accept_length = accept_length[unfinished_index]
draft_input.accept_length_cpu = [ draft_input.accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index accept_length_cpu[i] for i in unfinished_index
] ]
if has_finished: if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] draft_input.seq_lens_for_draft_extend = batch.seq_lens[
unfinished_index
]
draft_input.req_pool_indices_for_draft_extend = (
batch.req_pool_indices[unfinished_index]
)
else: else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = (
batch.req_pool_indices
)
batch.out_cache_loc = batch.out_cache_loc[new_accept_index] batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
return EagleVerifyOutput( return EagleVerifyOutput(
...@@ -465,80 +490,9 @@ class EagleVerifyInput: ...@@ -465,80 +490,9 @@ class EagleVerifyInput:
logits_output=logits_output, logits_output=logits_output,
verified_id=verified_id, verified_id=verified_id,
accept_length_per_req_cpu=accept_length_cpu, accept_length_per_req_cpu=accept_length_cpu,
accepeted_indices_cpu=accept_index, accepeted_indices=accept_index,
)
@triton.jit
def eagle_verify_retrive(
retrive_index,
accept_mask,
retrive_cum_len,
accept_index,
accept_length,
extract_index,
max_len: tl.constexpr,
draft_token_num: tl.constexpr,
max_len_upper: tl.constexpr,
):
"""
Args:
retrive_index: Pointer to indices of draft tokens
accept_mask: Mask indicating which tokens were accepted
retrive_cum_len: Cumulative lengths of token sequences in a batch
accept_index (out): Accept token indices
accept_length (out): Length of accepted tokens per sequence in a batch
extract_index (out): Index for last accepted tokens
max_len: Maximum length in a batch
draft_token_num: Number of tokens speculatively generated
max_len_upper An upper bound for token sequence length
"""
pid = tl.program_id(axis=0)
retrive_end = tl.load(retrive_cum_len + pid + 1)
retrive_start = tl.load(retrive_cum_len + pid)
retrive_len = retrive_end - retrive_start
accept_ptr = accept_mask + retrive_start
accept_offset = tl.arange(0, draft_token_num)
accept_load_mask = accept_offset < retrive_len
accept_len_list = tl.load(
accept_ptr + accept_offset, mask=accept_load_mask, other=-1
)
accept_len = tl.max(accept_len_list)
max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
# triton is not support argmax with tie_break_right, so I need implement it by some way
mask_max = accept_len_list == accept_len
count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
count = tl.sum(tl.where(mask_max, 1, count_mask))
if count > 1:
index = tl.arange(0, draft_token_num)
mask_left = index != max_index
remained_index = tl.where(mask_max and mask_left, index, 0)
max_index = tl.max(remained_index)
tl.store(accept_length + pid, accept_len)
retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
retrive_offset = tl.arange(0, max_len_upper)
retrive_load_mask = retrive_offset < accept_len + 1
data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
tl.store(
accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
) )
extract_load_ptr = accept_index + pid * max_len + accept_len
if accept_len == max_len - 1:
extract_data = tl.load(extract_load_ptr - 1)
tl.store(extract_index + pid * 2, extract_data)
extract_data = tl.load(extract_load_ptr)
tl.store(extract_index + pid * 2 + 1, extract_data)
else:
extract_data = tl.load(extract_load_ptr)
tl.store(extract_index + pid * 2, extract_data)
@triton.jit @triton.jit
def create_extend_spec_info( def create_extend_spec_info(
......
import logging import logging
import os import os
import time import time
from contextlib import contextmanager
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
...@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import ( ...@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk, fast_topk,
select_top_k_tokens, select_top_k_tokens,
) )
from sglang.srt.utils import get_available_gpu_memory from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
if is_cuda_available():
from sgl_kernel import segment_packbits
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
yield
class EAGLEWorker(TpModelWorker): class EAGLEWorker(TpModelWorker):
def __init__( def __init__(
...@@ -76,6 +90,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -76,6 +90,7 @@ class EAGLEWorker(TpModelWorker):
self.hot_token_id = None self.hot_token_id = None
# Init draft worker # Init draft worker
with empty_context():
super().__init__( super().__init__(
gpu_id=gpu_id, gpu_id=gpu_id,
tp_rank=tp_rank, tp_rank=tp_rank,
...@@ -94,10 +109,15 @@ class EAGLEWorker(TpModelWorker): ...@@ -94,10 +109,15 @@ class EAGLEWorker(TpModelWorker):
self.hot_token_id = self.hot_token_id.to(head.device) self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id] head.data = head.data[self.hot_token_id]
self.draft_model_runner.model.set_embed_and_head(embed, head) self.draft_model_runner.model.set_embed_and_head(embed, head)
# Init attention backend and cuda graphs
self.draft_model_runner.server_args.disable_cuda_graph = ( self.draft_model_runner.server_args.disable_cuda_graph = (
backup_disable_cuda_graph backup_disable_cuda_graph
) )
self.draft_tp_context = (
draft_tp_context if server_args.enable_dp_attention else empty_context
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.init_attention_backend() self.init_attention_backend()
self.init_cuda_graphs() self.init_cuda_graphs()
...@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker): ...@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
) )
self.draft_attn_backend = FlashInferMultiStepDraftBackend( self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.model_runner, self.draft_model_runner,
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
) )
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton": elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import ( from sglang.srt.layers.attention.triton_backend import (
TritonMultiStepDraftBackend, TritonMultiStepDraftBackend,
) )
self.draft_attn_backend = TritonMultiStepDraftBackend( self.draft_attn_backend = TritonMultiStepDraftBackend(
self.model_runner, self.draft_model_runner,
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
) )
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashinfer_mla": elif self.server_args.attention_backend == "flashinfer_mla":
from sglang.srt.layers.attention.flashinfer_mla_backend import ( from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAMultiStepDraftBackend, FlashInferMLAMultiStepDraftBackend,
) )
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend( self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
self.model_runner, self.draft_model_runner,
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
) )
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
else: else:
raise ValueError( raise ValueError(
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}" f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
) )
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
def init_cuda_graphs(self): def init_cuda_graphs(self):
"""Capture cuda graphs.""" """Capture cuda graphs."""
self.cuda_graph_runner = None self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
if self.server_args.disable_cuda_graph: if self.server_args.disable_cuda_graph:
return return
# Capture draft
tic = time.time() tic = time.time()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info( logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
) )
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info( logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
) )
# Capture extend
if self.draft_extend_attn_backend:
raise NotImplementedError()
@property @property
def draft_model_runner(self): def draft_model_runner(self):
return self.model_runner return self.model_runner
...@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker): ...@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
) -> Tuple[LogitsProcessorOutput, List[int], int, int]: ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
"""Run speculative decoding forward. """Run speculative decoding forward.
NOTE: Many states of batch is modified as you go through. It is not guaranteed NOTE: Many states of batch is modified as you go through. It is not guaranteed that
the final output batch doesn't have the same state as the input. the final output batch have the same state as the input.
Args: Args:
batch: The batch to run forward. The state of the batch is modified as it runs. batch: The batch to run forward. The state of the batch is modified as it runs.
...@@ -173,27 +211,39 @@ class EAGLEWorker(TpModelWorker): ...@@ -173,27 +211,39 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepeted, A tuple of the final logit output of the target model, next tokens accepeted,
the batch id (used for overlap schedule), and number of accepeted tokens. the batch id (used for overlap schedule), and number of accepeted tokens.
""" """
assert not batch.spec_algorithm.is_none()
if batch.forward_mode.is_decode(): if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info, to_free_cache_loc = self.draft(batch) spec_info, to_free_cache_loc = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify( logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info batch, spec_info
) )
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.) # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
self.token_to_kv_pool_allocator.free(to_free_cache_loc) self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# if it is None, means all requests are finished
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None: if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(batch) self.forward_draft_extend_after_decode(batch)
return ( return (
logits_output, logits_output,
verify_output.verified_id, verify_output.verified_id,
model_worker_batch.bid, model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu), sum(verify_output.accept_length_per_req_cpu),
) )
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids, _ = (
self.target_worker.forward_batch_generation(
ForwardBatch.init_new(
model_worker_batch, self.target_worker.model_runner
)
)
)
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else: else:
logits_output, next_token_ids, bid = self.forward_target_extend(batch) logits_output, next_token_ids, bid = self.forward_target_extend(batch)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend( self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids batch, logits_output.hidden_states, next_token_ids
) )
...@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker): ...@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
num_seqs = batch.batch_size() num_seqs = batch.batch_size()
spec_info = batch.spec_info spec_info = batch.spec_info
# Accumulate penalty
if batch.sampling_info.penalizer_orchestrator.is_required:
# This is a relaxed version of penalties for speculative decoding.
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
spec_info.verified_id.to(torch.int64)
)
# Allocate cache locations # Allocate cache locations
out_cache_loc = batch.alloc_token_slots( out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps num_seqs * self.topk * self.speculative_num_steps
...@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens, self.server_args.speculative_num_draft_tokens,
batch.sampling_info.is_all_greedy,
) )
return ret, out_cache_loc return ret, out_cache_loc
def draft_forward(self, forward_batch: ForwardBatch): def draft_forward(self, forward_batch: ForwardBatch):
...@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
token_list.append(tree_info[1]) token_list.append(tree_info[1])
parents_list.append(tree_info[2]) parents_list.append(tree_info[2])
# we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
if i == self.speculative_num_steps - 1: if i == self.speculative_num_steps - 1:
break break
...@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
spec_info.hidden_states = hidden_states spec_info.hidden_states = hidden_states
# Run forward # Run forward
logits_output = self.model_runner.model.forward( logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch forward_batch.input_ids, forward_batch.positions, forward_batch
) )
self._detect_nan_if_needed(logits_output) self._detect_nan_if_needed(logits_output)
...@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker): ...@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
# Post process based on verified outputs. # Post process based on verified outputs.
# Pick indices that we care (accepeted) # Pick indices that we care (accepeted)
logits_output.next_token_logits = logits_output.next_token_logits[ logits_output.next_token_logits = logits_output.next_token_logits[
res.accepeted_indices_cpu res.accepeted_indices
]
logits_output.hidden_states = logits_output.hidden_states[
res.accepeted_indices_cpu
] ]
logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
# Prepare the batch for the next draft forwards. # Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input batch.spec_info = res.draft_input
...@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
batch_next_token_ids, batch_next_token_ids,
] ]
# Add output logprobs to the request. # Add output logprobs to the request
pt = 0 pt = 0
next_token_logprobs = logits_output.next_token_logprobs.tolist() next_token_logprobs = logits_output.next_token_logprobs.tolist()
verified_ids = batch_next_token_ids.tolist() verified_ids = batch_next_token_ids.tolist()
...@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker): ...@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
self.capture_for_decode(logits_output, forward_batch.spec_info) self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch): def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens # Backup fileds that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
# Prepare metadata
batch.forward_mode = ForwardMode.DRAFT_EXTEND batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps) batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
original_return_logprob = batch.return_logprob
batch.return_logprob = False batch.return_logprob = False
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new( forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner model_worker_batch, self.draft_model_runner
) )
# Run
logits_output = self.draft_model_runner.forward(forward_batch) logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output) self._detect_nan_if_needed(logits_output)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info) self.capture_for_decode(logits_output, forward_batch.spec_info)
# Restore backup. # Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode` # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.return_logprob = original_return_logprob
batch.forward_mode = ForwardMode.DECODE batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.accept_length = accept_length_backup
batch.return_logprob = return_logprob_backup
def capture_for_decode( def capture_for_decode(
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
...@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
if self.enable_nan_detection: if self.enable_nan_detection:
logits = logits_output.next_token_logits logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)): if torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.") logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.") raise ValueError("Detected errors during sampling! NaN in the logits.")
......
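The `forward_draft_extend_after_decode` change above follows a backup / mutate / restore pattern: fields of the batch that `prepare_extend_after_decode` rewrites in place (`seq_lens`, `req_pool_indices`, `accept_length`, `return_logprob`) are saved up front and put back once the draft extend forward is done. A tiny sketch of that pattern with a stand-in batch object (the toy fields are assumptions, not the real `ScheduleBatch`):

```python
from dataclasses import dataclass

import torch

@dataclass
class ToyBatch:
    # Stand-in for ScheduleBatch with only the fields relevant to the pattern.
    seq_lens: torch.Tensor
    return_logprob: bool

def run_draft_extend(batch: ToyBatch) -> None:
    # Back up the fields that the extend step mutates.
    seq_lens_backup = batch.seq_lens.clone()
    return_logprob_backup = batch.return_logprob

    # ... prepare + run the draft extend forward; these fields get overwritten ...
    batch.seq_lens = batch.seq_lens + 1
    batch.return_logprob = False

    # Restore, so the caller sees the batch in its pre-extend state.
    batch.seq_lens = seq_lens_backup
    batch.return_logprob = return_logprob_backup

batch = ToyBatch(seq_lens=torch.tensor([5, 7]), return_logprob=True)
run_draft_extend(batch)
print(batch.seq_lens.tolist(), batch.return_logprob)  # [5, 7] True
```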
...@@ -36,6 +36,7 @@ import tempfile ...@@ -36,6 +36,7 @@ import tempfile
import threading import threading
import time import time
import warnings import warnings
from contextlib import contextmanager
from functools import lru_cache from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec from importlib.util import find_spec
...@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int): ...@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
setattr(triton, "next_power_of_2", next_power_of_2) setattr(triton, "next_power_of_2", next_power_of_2)
@contextmanager
def empty_context(*args, **kwargs):
try:
# Setup code goes here
yield
finally:
# Cleanup code goes here
pass
def add_prefix(name: str, prefix: str) -> str: def add_prefix(name: str, prefix: str) -> str:
"""Add a weight path prefix to a module name. """Add a weight path prefix to a module name.
......
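`empty_context` above is a no-op context manager, used (for example when selecting between it and `draft_tp_context` in eagle_worker.py) so that call sites keep the same `with ctx(...):` shape whether or not any real setup is required. A short usage sketch; `noisy_context` is a hypothetical stand-in for a context that actually does something:

```python
from contextlib import contextmanager

@contextmanager
def empty_context(*args, **kwargs):
    yield  # no setup or cleanup

@contextmanager
def noisy_context(name):
    # Hypothetical "real" context, standing in for something like draft_tp_context.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

use_real_context = False  # e.g. server_args.enable_dp_attention
ctx = noisy_context if use_real_context else empty_context

with ctx("draft tp group"):
    # The call site is identical regardless of which context was selected.
    pass
```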
...@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa ...@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
# For compling xgrammar kernels # For compling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12 pip install cuda-python nvidia-cuda-nvrtc-cu12
# reinstall sgl-kernel
pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
...@@ -36,8 +36,8 @@ template < ...@@ -36,8 +36,8 @@ template <
typename DType, typename DType,
typename IdType> typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly( __global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts, IdType* predicts, // mutable
IdType* accept_index, IdType* accept_index, // mutable
IdType* accept_token_num, // mutable IdType* accept_token_num, // mutable
IdType* candidates, IdType* candidates,
IdType* retrive_index, IdType* retrive_index,
...@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly( ...@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
template <typename DType, typename IdType> template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly( cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts, IdType* predicts, // mutable
IdType* output_token_ids, IdType* output_token_ids, // mutable
IdType* output_accepted_token_num, // mutable IdType* output_accepted_token_num, // mutable
IdType* candidates, IdType* candidates,
IdType* retrive_index, IdType* retrive_index,
......
...@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase): ...@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):
def _test_acc_length(self, engine): def _test_acc_length(self, engine):
prompt = [ prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:" "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
] * 5 ] * 5 # test batched generation
sampling_params = {"temperature": 0, "max_new_tokens": 512} sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params) output = engine.generate(prompt, sampling_params)
output = output[0] output = output[0]
......
...@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
"--enable-torch-compile", "--enable-torch-compile",
"--disable-cuda-graph", "--disable-cuda-graph",
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--enable-flashinfer-mla", "--enable-flashinfer-mla",
"--flashinfer-mla-disable-ragged", "--flashinfer-mla-disable-ragged",
] ]
...@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase): ...@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
other_args.extend( other_args.extend(
[ [
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--disable-radix", "--disable-radix",
"--enable-torch-compile", "--enable-torch-compile",
"--torch-compile-max-bs", "--torch-compile-max-bs",
......