Unverified Commit 1b859295 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Eagle] Remove the greedy branch and some redundant code (#4363)


Co-authored-by: default avatarSehoon Kim <sehoon@x.ai>
parent 9971dc22
......@@ -43,7 +43,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.0.5.post1",
"sgl-kernel==0.0.5.post2",
"flashinfer_python==0.2.3",
"torch==2.5.1",
"vllm>=0.6.4.post1,<=0.7.2",
......
......@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
@app.post("/flush_cache")
@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
"""Flush the radix cache."""
_global_state.tokenizer_manager.flush_cache()
......
......@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"largest-len: {self._largest_prefill_decode_len}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
spec_accept_length = 0
......@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"accept len: {spec_accept_length:.2f}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"largest-len: {self._largest_prefill_decode_len}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
......
......@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
else:
capture_bs = list(range(1, 33))
# Since speculative decoding requires more cuda graph memory, we
# capture less.
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
if _is_hip:
capture_bs += [i * 8 for i in range(21, 33)]
......@@ -125,16 +127,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs.
capture_bs = list(
sorted(
set(
capture_bs
+ [model_runner.req_to_token_pool.size - 1]
+ [model_runner.req_to_token_pool.size]
)
)
)
capture_bs += [model_runner.req_to_token_pool.size - 1] + [
model_runner.req_to_token_pool.size
]
capture_bs = list(sorted(set(capture_bs)))
capture_bs = [
bs
for bs in capture_bs
......@@ -508,7 +505,9 @@ class CudaGraphRunner:
self.raw_num_token = raw_num_token
self.bs = bs
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
def replay(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
) -> LogitsProcessorOutput:
if not skip_attn_backend_init:
self.replay_prepare(forward_batch)
else:
......
......@@ -285,7 +285,6 @@ class ServerArgs:
if self.speculative_algorithm == "EAGLE":
if self.max_running_requests is None:
self.max_running_requests = 32
self.disable_cuda_graph_padding = True
self.disable_overlap_schedule = True
logger.info(
"Overlap scheduler is disabled because of using "
......
......@@ -3,8 +3,13 @@
from typing import List
import torch
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
def build_tree_kernel_efficient_preprocess(
......@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
......@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
)
def build_tree_kernel(
verified_id: torch.Tensor,
score_list: List[torch.Tensor],
token_list: List[torch.Tensor],
parents_list: List[torch.Tensor],
seq_lens: torch.Tensor,
seq_lens_sum: int,
topk: int,
spec_steps: int,
num_verify_tokens: int,
):
parent_list, top_scores_index, draft_tokens = (
build_tree_kernel_efficient_preprocess(
verified_id,
score_list,
token_list,
parents_list,
num_verify_tokens,
)
)
bs = seq_lens.numel()
device = seq_lens.device
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
retrive_index = torch.full(
(bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
)
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
sgl_build_tree_kernel(
parent_list,
top_scores_index,
seq_lens.to(torch.int32),
tree_mask,
positions,
retrive_index,
topk,
spec_steps,
num_verify_tokens,
)
index = retrive_index.sum(dim=-1) != -spec_steps - 2
cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
retrive_cum_len = torch.zeros(
(cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
)
retrive_cum_len[1:] = cum_len
# TODO: this indexing cause a synchronization, optimize this
retrive_index = retrive_index[index]
return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
def test_build_tree_kernel():
def findp(p_i, index, parent_list):
pos = index // 10
index_list = index.tolist()
parent_list = parent_list.tolist()
res = [p_i]
while True:
p = pos[p_i]
if p == 0:
break
token_idx = parent_list[p]
p_i = index_list.index(token_idx)
res.append(p_i)
return res
def create_mask(seq_len, draft_token, index, parent_list, max_depth):
mask = []
positions = []
retrive_index = []
for i, lens in enumerate(seq_len.tolist()):
first_mask = torch.full((lens + draft_token,), True)
first_mask[-(draft_token - 1) :] = False
positions.append(lens)
mask.append(first_mask)
seq_order = []
first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
r_index = [first_index]
for j in range(draft_token - 1):
mask.append(torch.full((lens + 1,), True))
idx = findp(j, index, parent_list)
seq_order.append(idx)
positions.append(len(idx) + seq_len)
t = torch.full((draft_token - 1,), False)
t[idx] = True
mask.append(t)
for i in range(1, draft_token - 1):
is_leaf = 0
for j in range(draft_token - 1):
if i in seq_order[j]:
is_leaf += 1
if is_leaf == 1:
order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
for _ in range(max_depth + 1 - len(seq_order[i])):
order_list.append(-1)
order = torch.Tensor(order_list).cuda().to(torch.long)
r_index.append(order)
retrive_index.append(torch.stack(r_index))
return (
torch.cat(mask).cuda(),
torch.Tensor(positions).cuda().to(torch.long),
torch.stack(retrive_index),
)
index = (
torch.Tensor(
[
0,
1,
2,
3,
10,
11,
12,
13,
20,
21,
22,
30,
110,
130,
150,
160,
210,
211,
212,
213,
214,
215,
216,
217,
218,
219,
220,
230,
310,
311,
312,
313,
314,
315,
316,
317,
320,
321,
322,
330,
360,
380,
390,
410,
411,
412,
413,
414,
415,
416,
417,
418,
419,
420,
421,
422,
423,
430,
431,
440,
441,
460,
470,
]
)
.to(torch.long)
.cuda()
)
parent_list = (
torch.Tensor(
[
-1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
20,
30,
21,
13,
22,
40,
23,
110,
130,
160,
150,
190,
120,
111,
121,
200,
180,
210,
211,
212,
213,
214,
215,
216,
220,
230,
217,
310,
311,
312,
313,
320,
314,
321,
315,
316,
317,
]
)
.to(torch.long)
.cuda()
)
verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
bs = verified_seq_len.shape[0]
topk = 10
depth = 5 # depth <= 10
num_draft_token = 64
tree_mask = torch.full(
(
torch.sum(verified_seq_len).item() * num_draft_token
+ num_draft_token * num_draft_token * bs,
),
True,
).cuda()
retrive_index = torch.full(
(bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
)
positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
sgl_build_tree_kernel(
parent_list.unsqueeze(0),
index.unsqueeze(0),
verified_seq_len,
tree_mask,
positions,
retrive_index,
topk,
depth,
num_draft_token,
)
retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
c_mask, c_positions, c_retive_index = create_mask(
verified_seq_len, num_draft_token, index, parent_list, depth
)
assert torch.allclose(tree_mask, c_mask), "tree mask has error."
assert torch.allclose(positions, c_positions), "positions has error."
assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
def test_build_tree_kernel_efficient():
verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
score_list = [
......@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
depth = 4
num_draft_token = 8
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id=verified_id,
score_list=score_list,
token_list=token_list,
parents_list=parents_list,
seq_lens=seq_lens,
seq_lens_sum=torch.sum(seq_lens).item(),
topk=topk,
spec_steps=depth,
num_verify_tokens=num_draft_token,
)
)
from sglang.srt.utils import first_rank_print
first_rank_print("=========== build tree kernel ==========")
# first_rank_print(f"{tree_mask=}", flush=True)
first_rank_print(f"{position=}", flush=True)
first_rank_print(f"{retrive_index=}", flush=True)
first_rank_print(f"{retrive_cum_len=}", flush=True)
first_rank_print(f"{draft_tokens=}", flush=True)
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
assert retrive_index.tolist() == [
[0, -1, -1, -1, -1, -1],
[0, 2, 4, 6, -1, -1],
[0, 1, 3, 5, 7, -1],
[8, -1, -1, -1, -1, -1],
[8, 9, 10, -1, -1, -1],
[8, 9, 12, -1, -1, -1],
[8, 9, 13, -1, -1, -1],
[8, 9, 11, 14, 15, -1],
]
assert retrive_cum_len.tolist() == [0, 3, 8]
assert draft_tokens.tolist() == [
29974,
29896,
29906,
29889,
29974,
29946,
29896,
29946,
13,
13,
22550,
4136,
16492,
8439,
29871,
29941,
]
(
tree_mask,
position,
......@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
if __name__ == "__main__":
test_build_tree_kernel_efficient()
test_build_tree_kernel()
......@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
import logging
logger = logging.getLogger(__name__)
class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
......@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
server_args = model_runner.server_args
assert self.disable_padding
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.num_tokens_per_bs = server_args.speculative_eagle_topk
......@@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner:
set_global_graph_memory_pool(graph.pool())
return graph, out
def _postprocess_output_to_raw_bs(self, out, raw_bs):
score_list, token_list, parents_list = out
score_list = [x[:raw_bs] for x in score_list]
token_list = [x[:raw_bs] for x in token_list]
parents_list = [x[:raw_bs] for x in parents_list]
return (score_list, token_list, parents_list)
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
......@@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner:
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
self.positions.zero_()
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
......@@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner:
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# Attention backend
if bs != raw_bs:
forward_batch.batch_size = bs
forward_batch.seq_lens = self.seq_lens[:bs]
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch, forward_batch.batch_size
forward_batch, bs
)
# Replay
self.graphs[bs].replay()
out = self.output_buffers[bs]
if bs != raw_bs:
out = self._postprocess_output_to_raw_bs(out, raw_bs)
forward_batch.batch_size = raw_bs
forward_batch.positions = self.positions[:raw_num_token]
forward_batch.seq_lens = self.seq_lens[:raw_bs]
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
return self.output_buffers[bs]
return out
import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk,
select_top_k_tokens,
)
from sglang.srt.utils import get_available_gpu_memory
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
if is_cuda_available():
from sgl_kernel import segment_packbits
logger = logging.getLogger(__name__)
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
yield
class EAGLEWorker(TpModelWorker):
def __init__(
......@@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker):
self.hot_token_id = None
# Init draft worker
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=nccl_port,
dp_rank=dp_rank,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
with empty_context():
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=nccl_port,
dp_rank=dp_rank,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
......@@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker):
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
self.draft_model_runner.model.set_embed_and_head(embed, head)
# Init attention backend and cuda graphs
self.draft_model_runner.server_args.disable_cuda_graph = (
backup_disable_cuda_graph
)
self.init_attention_backend()
self.init_cuda_graphs()
self.draft_tp_context = (
draft_tp_context if server_args.enable_dp_attention else empty_context
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.init_attention_backend()
self.init_cuda_graphs()
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
......@@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker):
)
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.model_runner,
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
TritonMultiStepDraftBackend,
)
self.draft_attn_backend = TritonMultiStepDraftBackend(
self.model_runner,
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashinfer_mla":
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAMultiStepDraftBackend,
)
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
self.model_runner,
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
)
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
if self.server_args.disable_cuda_graph:
return
# Capture draft
tic = time.time()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
)
# Capture extend
if self.draft_extend_attn_backend:
raise NotImplementedError()
@property
def draft_model_runner(self):
return self.model_runner
......@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
"""Run speculative decoding forward.
NOTE: Many states of batch is modified as you go through. It is not guaranteed
the final output batch doesn't have the same state as the input.
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
the final output batch have the same state as the input.
Args:
batch: The batch to run forward. The state of the batch is modified as it runs.
......@@ -173,30 +211,42 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepeted,
the batch id (used for overlap schedule), and number of accepeted tokens.
"""
assert not batch.spec_algorithm.is_none()
if batch.forward_mode.is_decode():
spec_info, to_free_cache_loc = self.draft(batch)
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info, to_free_cache_loc = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
)
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# if it is None, means all requests are finished
if batch.spec_info.verified_id is not None:
self.forward_draft_extend_after_decode(batch)
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids, _ = (
self.target_worker.forward_batch_generation(
ForwardBatch.init_new(
model_worker_batch, self.target_worker.model_runner
)
)
)
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else:
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids
)
return logits_output, next_token_ids, bid, 0
def forward_target_extend(
......@@ -226,6 +276,13 @@ class EAGLEWorker(TpModelWorker):
num_seqs = batch.batch_size()
spec_info = batch.spec_info
# Accumulate penalty
if batch.sampling_info.penalizer_orchestrator.is_required:
# This is a relaxed version of penalties for speculative decoding.
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
spec_info.verified_id.to(torch.int64)
)
# Allocate cache locations
out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps
......@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
batch.sampling_info.is_all_greedy,
)
return ret, out_cache_loc
def draft_forward(self, forward_batch: ForwardBatch):
......@@ -307,7 +362,7 @@ class EAGLEWorker(TpModelWorker):
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
# we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
# We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
if i == self.speculative_num_steps - 1:
break
......@@ -322,7 +377,7 @@ class EAGLEWorker(TpModelWorker):
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.model_runner.model.forward(
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
......@@ -351,11 +406,10 @@ class EAGLEWorker(TpModelWorker):
# Post process based on verified outputs.
# Pick indices that we care (accepeted)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepeted_indices_cpu
]
logits_output.hidden_states = logits_output.hidden_states[
res.accepeted_indices_cpu
res.accepeted_indices
]
logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input
......@@ -407,7 +461,7 @@ class EAGLEWorker(TpModelWorker):
batch_next_token_ids,
]
# Add output logprobs to the request.
# Add output logprobs to the request
pt = 0
next_token_logprobs = logits_output.next_token_logprobs.tolist()
verified_ids = batch_next_token_ids.tolist()
......@@ -456,27 +510,38 @@ class EAGLEWorker(TpModelWorker):
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens
# Backup fileds that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
# Prepare metadata
batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
original_return_logprob = batch.return_logprob
batch.return_logprob = False
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
# Run
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.return_logprob = original_return_logprob
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.accept_length = accept_length_backup
batch.return_logprob = return_logprob_backup
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
......@@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
......
......@@ -36,6 +36,7 @@ import tempfile
import threading
import time
import warnings
from contextlib import contextmanager
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
......@@ -1577,6 +1578,16 @@ def next_power_of_2(n: int):
setattr(triton, "next_power_of_2", next_power_of_2)
@contextmanager
def empty_context(*args, **kwargs):
try:
# Setup code goes here
yield
finally:
# Cleanup code goes here
pass
def add_prefix(name: str, prefix: str) -> str:
"""Add a weight path prefix to a module name.
......
......@@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa
# For compling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12
# reinstall sgl-kernel
pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
......@@ -36,8 +36,8 @@ template <
typename DType,
typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* accept_index,
IdType* predicts, // mutable
IdType* accept_index, // mutable
IdType* accept_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
......@@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts,
IdType* output_token_ids,
IdType* predicts, // mutable
IdType* output_token_ids, // mutable
IdType* output_accepted_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
......
......@@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase):
def _test_acc_length(self, engine):
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
] * 5
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
] * 5 # test batched generation
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
......
......@@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
"--enable-torch-compile",
"--disable-cuda-graph",
"--cuda-graph-max-bs",
"2",
"4",
"--enable-flashinfer-mla",
"--flashinfer-mla-disable-ragged",
]
......@@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase):
other_args.extend(
[
"--cuda-graph-max-bs",
"2",
"4",
"--disable-radix",
"--enable-torch-compile",
"--torch-compile-max-bs",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment