Unverified Commit 1b859295 authored by Ying Sheng, committed by GitHub

[Eagle] Remove the greedy branch and some redundant code (#4363)


Co-authored-by: Sehoon Kim <sehoon@x.ai>
parent 9971dc22
@@ -43,7 +43,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.5.post1",
+    "sgl-kernel==0.0.5.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "vllm>=0.6.4.post1,<=0.7.2",
...
@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
...
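With the route switched from `@app.post` to `@app.api_route(..., methods=["GET", "POST"])`, the endpoint answers both verbs. A minimal client-side sketch (the base URL is a hypothetical local deployment, not something defined in this commit):

```python
import requests

# Hypothetical local server address; adjust to your deployment.
BASE_URL = "http://127.0.0.1:30000"

# Both verbs now reach the same handler.
for method in (requests.get, requests.post):
    resp = method(f"{BASE_URL}/flush_cache")
    print(method.__name__, resp.status_code)
```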
@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
         spec_accept_length = 0
@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
...
@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
+
+    capture_bs = list(sorted(set(capture_bs)))

     capture_bs = [
         bs
         for bs in capture_bs
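For reference, a standalone sketch that mirrors the new speculative-decoding batch-size schedule above; `pool_size` and `is_hip` are made-up stand-ins for `model_runner.req_to_token_pool.size` and `_is_hip`:

```python
# Sketch of the capture list built for the speculative-decoding branch.
def sketch_capture_bs(pool_size: int = 48, is_hip: bool = False) -> list[int]:
    # Speculative decoding captures fewer graphs to save CUDA graph memory.
    capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
    if is_hip:
        capture_bs += [i * 8 for i in range(21, 33)]
    if max(capture_bs) > pool_size:
        # Make sure the maximum runnable batch size is still captured.
        capture_bs += [pool_size - 1, pool_size]
    capture_bs = sorted(set(capture_bs))
    # Drop anything the request pool cannot hold.
    return [bs for bs in capture_bs if bs <= pool_size]


print(sketch_capture_bs())  # [1, 2, ..., 8, 9, 11, ..., 31, 47, 48]
```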
@@ -508,7 +505,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs

-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
...
@@ -285,7 +285,6 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
...
@@ -3,8 +3,13 @@
 from typing import List

 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )


 def build_tree_kernel_efficient_preprocess(
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
    top_scores_index = torch.sort(top_scores_index).values
-
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
def build_tree_kernel(
verified_id: torch.Tensor,
score_list: List[torch.Tensor],
token_list: List[torch.Tensor],
parents_list: List[torch.Tensor],
seq_lens: torch.Tensor,
seq_lens_sum: int,
topk: int,
spec_steps: int,
num_verify_tokens: int,
):
parent_list, top_scores_index, draft_tokens = (
build_tree_kernel_efficient_preprocess(
verified_id,
score_list,
token_list,
parents_list,
num_verify_tokens,
)
)
bs = seq_lens.numel()
device = seq_lens.device
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
retrive_index = torch.full(
(bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
)
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
sgl_build_tree_kernel(
parent_list,
top_scores_index,
seq_lens.to(torch.int32),
tree_mask,
positions,
retrive_index,
topk,
spec_steps,
num_verify_tokens,
)
index = retrive_index.sum(dim=-1) != -spec_steps - 2
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
retrive_cum_len = torch.zeros(
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
)
retrive_cum_len[1:] = cum_len
# TODO: this indexing cause a synchronization, optimize this
retrive_index = retrive_index[index]
return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
def test_build_tree_kernel():
def findp(p_i, index, parent_list):
pos = index // 10
index_list = index.tolist()
parent_list = parent_list.tolist()
res = [p_i]
while True:
p = pos[p_i]
if p == 0:
break
token_idx = parent_list[p]
p_i = index_list.index(token_idx)
res.append(p_i)
return res
def create_mask(seq_len, draft_token, index, parent_list, max_depth):
mask = []
positions = []
retrive_index = []
for i, lens in enumerate(seq_len.tolist()):
first_mask = torch.full((lens + draft_token,), True)
first_mask[-(draft_token - 1) :] = False
positions.append(lens)
mask.append(first_mask)
seq_order = []
first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
r_index = [first_index]
for j in range(draft_token - 1):
mask.append(torch.full((lens + 1,), True))
idx = findp(j, index, parent_list)
seq_order.append(idx)
positions.append(len(idx) + seq_len)
t = torch.full((draft_token - 1,), False)
t[idx] = True
mask.append(t)
for i in range(1, draft_token - 1):
is_leaf = 0
for j in range(draft_token - 1):
if i in seq_order[j]:
is_leaf += 1
if is_leaf == 1:
order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
for _ in range(max_depth + 1 - len(seq_order[i])):
order_list.append(-1)
order = torch.Tensor(order_list).cuda().to(torch.long)
r_index.append(order)
retrive_index.append(torch.stack(r_index))
return (
torch.cat(mask).cuda(),
torch.Tensor(positions).cuda().to(torch.long),
torch.stack(retrive_index),
)
index = (
torch.Tensor(
[
0,
1,
2,
3,
10,
11,
12,
13,
20,
21,
22,
30,
110,
130,
150,
160,
210,
211,
212,
213,
214,
215,
216,
217,
218,
219,
220,
230,
310,
311,
312,
313,
314,
315,
316,
317,
320,
321,
322,
330,
360,
380,
390,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
423,
430,
431,
440,
441,
460,
470,
]
)
.to(torch.long)
.cuda()
)
parent_list = (
torch.Tensor(
[
-1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
20,
30,
21,
13,
22,
40,
23,
110,
130,
160,
150,
190,
120,
111,
121,
200,
180,
210,
211,
212,
213,
214,
215,
216,
220,
230,
217,
310,
311,
312,
313,
320,
314,
321,
315,
316,
317,
]
)
.to(torch.long)
.cuda()
)
verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
bs = verified_seq_len.shape[0]
topk = 10
depth = 5 # depth <= 10
num_draft_token = 64
tree_mask = torch.full(
(
torch.sum(verified_seq_len).item() * num_draft_token
+ num_draft_token * num_draft_token * bs,
),
True,
).cuda()
retrive_index = torch.full(
(bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
)
positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
sgl_build_tree_kernel(
parent_list.unsqueeze(0),
index.unsqueeze(0),
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
num_draft_token,
)
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
c_mask, c_positions, c_retive_index = create_mask(
verified_seq_len, num_draft_token, index, parent_list, depth
)
assert torch.allclose(tree_mask, c_mask), "tree mask has error."
assert torch.allclose(positions, c_positions), "positions has error."
assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id=verified_id,
score_list=score_list,
token_list=token_list,
parents_list=parents_list,
seq_lens=seq_lens,
seq_lens_sum=torch.sum(seq_lens).item(),
topk=topk,
spec_steps=depth,
num_verify_tokens=num_draft_token,
)
)
from sglang.srt.utils import first_rank_print
first_rank_print("=========== build tree kernel ==========")
# first_rank_print(f"{tree_mask=}", flush=True)
first_rank_print(f"{position=}", flush=True)
first_rank_print(f"{retrive_index=}", flush=True)
first_rank_print(f"{retrive_cum_len=}", flush=True)
first_rank_print(f"{draft_tokens=}", flush=True)
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
assert retrive_index.tolist() == [
[0, -1, -1, -1, -1, -1],
[0, 2, 4, 6, -1, -1],
[0, 1, 3, 5, 7, -1],
[8, -1, -1, -1, -1, -1],
[8, 9, 10, -1, -1, -1],
[8, 9, 12, -1, -1, -1],
[8, 9, 13, -1, -1, -1],
[8, 9, 11, 14, 15, -1],
]
assert retrive_cum_len.tolist() == [0, 3, 8]
assert draft_tokens.tolist() == [
29974,
29896,
29906,
29889,
29974,
29946,
29896,
29946,
13,
13,
22550,
4136,
16492,
8439,
29871,
29941,
]
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()
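The import change at the top of this file guards the `sgl_kernel` import behind `is_cuda_available()`. A small sketch of the same guard pattern, using `torch.cuda.is_available()` and a hypothetical pure-Python fallback, showing how the module stays importable on CPU-only machines:

```python
import torch


def _reference_build_tree(*args, **kwargs):
    # Hypothetical fallback used only where the CUDA kernel is unavailable.
    raise NotImplementedError("CUDA kernel required for tree construction")


if torch.cuda.is_available():
    # Import the fused kernel only when a GPU is present, so importing this
    # module never fails on CPU-only machines.
    from sgl_kernel import build_tree_kernel_efficient as _build_tree
else:
    _build_tree = _reference_build_tree
```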
@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker

+import logging
+
+logger = logging.getLogger(__name__)
+

 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args

-        assert self.disable_padding
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
             set_global_graph_memory_pool(graph.pool())
         return graph, out

+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs

         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )

         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
+
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]

-        return self.output_buffers[bs]
+        return out
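The replay path above pads a small batch up to the nearest captured size, replays the static graph, and then slices the outputs back down to the real batch. A minimal, self-contained sketch of that pad-and-slice pattern, using plain tensors and a hypothetical `captured_sizes` list instead of the real runner state (it assumes `raw_bs` fits within the largest captured size):

```python
import bisect

import torch

# Hypothetical captured batch sizes and padded output buffer.
captured_sizes = [1, 2, 4, 8]
output_buffer = torch.zeros(8, 16)  # stands in for self.output_buffers[bs]


def replay_sketch(inputs: torch.Tensor) -> torch.Tensor:
    raw_bs = inputs.shape[0]
    # Pick the smallest captured size that fits the real batch.
    bs = captured_sizes[bisect.bisect_left(captured_sizes, raw_bs)]
    # Copy real inputs into the first raw_bs rows of the static buffer;
    # the padded rows keep whatever the graph was captured with.
    output_buffer[:raw_bs, : inputs.shape[1]] = inputs
    # ... graph.replay() would run here on the static buffers ...
    out = output_buffer
    if bs != raw_bs:
        # Slice the padded rows away so callers only see real requests.
        out = out[:raw_bs]
    return out


print(replay_sketch(torch.randn(3, 16)).shape)  # torch.Size([3, 16])
```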
 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.utils import get_available_gpu_memory
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)


+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):
     def __init__(
@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None

         # Init draft worker
-        super().__init__(
-            gpu_id=gpu_id,
-            tp_rank=tp_rank,
-            server_args=server_args,
-            nccl_port=nccl_port,
-            dp_rank=dp_rank,
-            is_draft_worker=True,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-        )
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
         self.draft_model_runner.model.set_embed_and_head(embed, head)
+
+        # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
             )
             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )
             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_runner_for_draft_extend = None

         if self.server_args.disable_cuda_graph:
             return

+        # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()

     @property
     def draft_model_runner(self):
         return self.model_runner
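The new `before_mem`/`after_mem` logging reports how much GPU memory graph capture consumed. A standalone sketch of the same measure-around pattern using `torch.cuda.mem_get_info` (the allocation in the middle is just a stand-in for capturing the graphs):

```python
import torch


def free_gib() -> float:
    # mem_get_info returns (free_bytes, total_bytes) for the current device.
    free_bytes, _total = torch.cuda.mem_get_info()
    return free_bytes / (1 << 30)


before_mem = free_gib()
buf = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # stand-in workload
after_mem = free_gib()
print(f"avail mem={after_mem:.2f} GB. mem usage={before_mem - after_mem:.2f} GB.")
```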
@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.

-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch doesn't have the same state as the input.
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
+        the final output batch have the same state as the input.

         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepeted,
             the batch id (used for overlap schedule), and number of accepeted tokens.
         """
-        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            spec_info, to_free_cache_loc = self.draft(batch)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )

             # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
             self.token_to_kv_pool_allocator.free(to_free_cache_loc)

-            # if it is None, means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)
+            # If it is None, it means all requests are finished
+            if batch.spec_info.verified_id is not None:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)
             return (
                 logits_output,
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
             )
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(
+                    ForwardBatch.init_new(
+                        model_worker_batch, self.target_worker.model_runner
+                    )
+                )
+            )
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.forward_draft_extend(
-                batch, logits_output.hidden_states, next_token_ids
-            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids
+                )
             return logits_output, next_token_ids, bid, 0
     def forward_target_extend(
@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info

+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )

         return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])

-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
             if i == self.speculative_num_steps - 1:
                 break
@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
             spec_info.hidden_states = hidden_states

             # Run forward
-            logits_output = self.model_runner.model.forward(
+            logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
             self._detect_nan_if_needed(logits_output)
@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
-        logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
-        ]
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]

         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
                 batch_next_token_ids,
             ]

-        # Add output logprobs to the request.
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Backup fileds that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob

+        # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )

+        # Run
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup
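The backup-and-restore added above clones `seq_lens` before it is modified in place during `prepare_extend_after_decode`. A small, standalone illustration with a toy tensor of why a plain reference is not a safe backup:

```python
import torch

seq_lens = torch.tensor([5, 7, 9])

# A plain reference aliases the same storage, so in-place edits leak into the "backup".
alias_backup = seq_lens
clone_backup = seq_lens.clone()

seq_lens.add_(3)  # stands in for the in-place changes during prepare_extend_after_decode

print(alias_backup)  # tensor([ 8, 10, 12]) -- corrupted backup
print(clone_backup)  # tensor([5, 7, 9])    -- safe to restore from
```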
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.warning("Detected errors during sampling! NaN in the logits.")
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")
...
@@ -36,6 +36,7 @@ import tempfile
 import threading
 import time
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
     setattr(triton, "next_power_of_2", next_power_of_2)


+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
...
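The `empty_context` helper added above lets call sites choose between a real context manager and a no-op while keeping the same `with ctx(...):` shape, which is how `draft_tp_context` is wired into `EAGLEWorker`. A condensed sketch of that selection pattern (the flag and the dummy managers are hypothetical):

```python
from contextlib import contextmanager


@contextmanager
def empty_context(*args, **kwargs):
    yield  # no-op stand-in; accepts and ignores any arguments


@contextmanager
def special_context(name):
    print(f"enter {name}")
    yield
    print(f"exit {name}")


enable_feature = False  # hypothetical flag, e.g. enable_dp_attention
ctx = special_context if enable_feature else empty_context

# The call site does not care which one was chosen; both accept the argument.
with ctx("draft-tp-group"):
    print("do work")
```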
@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
-
-# reinstall sgl-kernel
-pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
...
@@ -36,8 +36,8 @@ template <
     typename DType,
     typename IdType>
 __global__ void TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* accept_index,
+    IdType* predicts,  // mutable
+    IdType* accept_index,  // mutable
     IdType* accept_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
 template <typename DType, typename IdType>
 cudaError_t TreeSpeculativeSamplingTargetOnly(
-    IdType* predicts,
-    IdType* output_token_ids,
+    IdType* predicts,  // mutable
+    IdType* output_token_ids,  // mutable
     IdType* output_accepted_token_num,  // mutable
     IdType* candidates,
     IdType* retrive_index,
...
@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_acc_length(self, engine):
         prompt = [
-            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
-        ] * 5
+            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
+        ] * 5  # test batched generation
         sampling_params = {"temperature": 0, "max_new_tokens": 512}
         output = engine.generate(prompt, sampling_params)
         output = output[0]
...
@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
             "--enable-torch-compile",
             "--disable-cuda-graph",
             "--cuda-graph-max-bs",
-            "2",
+            "4",
             "--enable-flashinfer-mla",
             "--flashinfer-mla-disable-ragged",
         ]
@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
         other_args.extend(
             [
                 "--cuda-graph-max-bs",
-                "2",
+                "4",
                 "--disable-radix",
                 "--enable-torch-compile",
                 "--torch-compile-max-bs",
...