Unverified Commit 20a6c0a6 authored by Liangsheng Yin, committed by GitHub

Beta spec-overlap for EAGLE (#11398)


Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
parent 47c606d3
......@@ -55,6 +55,25 @@ class AttentionBackend(ABC):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()
def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers of the verify attention kernels that need to be filled after the draft pass.
Typically, these are tree mask and position buffers.
"""
return [None, None]
def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
"""
Update the buffers returned by get_verify_buffers_to_fill_after_draft if needed.
Here, we must recompute all attention-backend metadata that depends on the
tree mask and position buffers.
"""
raise NotImplementedError()
def forward(
self,
q: torch.Tensor,
......
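For context, here is a minimal sketch (hypothetical backend class and buffer names, not part of this diff) of how a concrete attention backend could implement the two hooks added above:

from typing import Optional

import torch


class ToyAttnBackend:
    """Hypothetical backend illustrating the two new spec-overlap hooks."""

    def __init__(self):
        # Persistent buffers that are reused across steps / cuda graph replays.
        self.cuda_graph_custom_mask = torch.zeros(4096, dtype=torch.bool)
        self.cuda_graph_positions = torch.zeros(4096, dtype=torch.int64)

    def get_verify_buffers_to_fill_after_draft(self):
        # The draft step writes the tree mask and positions directly into these buffers.
        return [self.cuda_graph_custom_mask, self.cuda_graph_positions]

    def update_verify_buffers_to_fill_after_draft(self, spec_info, cuda_graph_bs: Optional[int]):
        # Recompute only the planned metadata that depends on the buffers above,
        # e.g. the mask-dependent part of init_forward_metadata / replay preparation.
        pass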
......@@ -29,7 +29,6 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
get_int_env_var,
......
......@@ -162,6 +162,8 @@ class TritonAttnBackend(AttentionBackend):
# Initialize forward metadata
self.forward_metadata: ForwardMetadata = None
self.cuda_graph_custom_mask = None
def get_num_kv_splits(
self,
num_kv_splits: torch.Tensor,
......@@ -755,6 +757,19 @@ class TritonAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1
def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for the verify attention kernels that need to be filled after the draft pass.
Typically, these are tree mask and position buffers.
"""
return [self.cuda_graph_custom_mask, None]
def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass
def forward_extend(
self,
q: torch.Tensor,
......
......@@ -384,6 +384,7 @@ class LogitsProcessor(nn.Module):
if (
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
or logits_metadata.forward_mode.is_draft_extend_v2()
):
pruned_states = hidden_states
if aux_hidden_states is not None:
......
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
......@@ -27,6 +34,7 @@ class FutureMap:
self,
max_running_requests: int,
device: torch.device,
spec_algo: Optional[SpeculativeAlgorithm] = None,
):
self.future_ct = 0
# A factor of 3 is used to avoid collision in the circular buffer.
......@@ -34,9 +42,51 @@ class FutureMap:
# A factor of 5 is used to ensure the buffer is large enough.
self.future_buffer_len = max_running_requests * 5
self.device = device
self.spec_algo = spec_algo
self.buf_initialized = False
if self.spec_algo.is_none():
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=self.device
)
def _lazy_init_buf(self, draft_input: EagleDraftInput):
if self.buf_initialized or not self.spec_algo.is_eagle():
return
self.buf_initialized = True
# get the template for each tensor
topk_p0 = draft_input.topk_p[0]
topk_index0 = draft_input.topk_index[0]
hidden_states0 = draft_input.hidden_states[0]
verified_id0 = draft_input.verified_id[0]
new_seq_lens0 = draft_input.new_seq_lens[0]
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=self.device
)
self.topk_p_buf = torch.empty(
(self.future_buffer_len, *topk_p0.shape),
dtype=topk_p0.dtype,
device=self.device,
)
self.topk_index_buf = torch.empty(
(self.future_buffer_len, *topk_index0.shape),
dtype=topk_index0.dtype,
device=self.device,
)
self.hidden_states_buf = torch.empty(
(self.future_buffer_len, *hidden_states0.shape),
dtype=hidden_states0.dtype,
device=self.device,
)
self.verified_id_buf = torch.empty(
(self.future_buffer_len, *verified_id0.shape),
dtype=verified_id0.dtype,
device=self.device,
)
self.new_seq_lens_buf = torch.empty(
(self.future_buffer_len, *new_seq_lens0.shape),
dtype=new_seq_lens0.dtype,
device=self.device,
)
def alloc_future_indices(self, bs: int) -> FutureIndices:
......@@ -49,7 +99,32 @@ class FutureMap:
return FutureIndices(indices=indices, interval=slice(start, end))
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
if self.spec_algo.is_eagle():
# TODO(lsyin): write future indices into spec_info.future_indices
draft_input: EagleDraftInput = model_worker_batch.spec_info
if draft_input is None:
# FIXME(lsyin): No future exists; this happens only for a prefill batch and is not compatible with mixed mode
return
indices = draft_input.future_indices.indices
draft_input.topk_p = self.topk_p_buf[indices]
draft_input.topk_index = self.topk_index_buf[indices]
draft_input.hidden_states = self.hidden_states_buf[indices]
draft_input.verified_id = self.verified_id_buf[indices]
draft_input.new_seq_lens = self.new_seq_lens_buf[indices]
else:
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
self.token_ids_buf[future_indices.interval] = next_token_ids
def store_to_map(
self, future_indices: FutureIndices, batch_result: GenerationBatchResult
):
intv = future_indices.interval
if self.spec_algo.is_eagle():
draft_input: EagleDraftInput = batch_result.next_draft_input
self._lazy_init_buf(draft_input)
self.topk_p_buf[intv] = draft_input.topk_p
self.topk_index_buf[intv] = draft_input.topk_index
self.hidden_states_buf[intv] = draft_input.hidden_states
self.verified_id_buf[intv] = draft_input.verified_id
self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
else:
self.token_ids_buf[intv] = batch_result.next_token_ids
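A toy sketch (not the real FutureMap; shapes made up) of the eagle store/resolve round trip added above: store_to_map writes the per-request draft tensors into the circular buffers at the allotted interval, and resolve_future later gathers them back by index:

import torch

buffer_len, topk = 16, 4
topk_p_buf = torch.empty(buffer_len, topk)

# Producer side, i.e. store_to_map(future_indices, batch_result):
interval = slice(3, 5)                       # two requests got slots 3 and 4
topk_p_buf[interval] = torch.rand(2, topk)   # batch_result.next_draft_input.topk_p

# Consumer side, i.e. resolve_future(model_worker_batch):
indices = torch.tensor([3, 4])               # draft_input.future_indices.indices
resolved_topk_p = topk_p_buf[indices]        # becomes draft_input.topk_p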
......@@ -61,8 +61,12 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.common import (
alloc_for_decode,
alloc_for_extend,
alloc_token_slots,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
......@@ -71,6 +75,7 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
......@@ -1067,6 +1072,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def allocate_for_eagle_v2(self):
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
bs = self.batch_size()
assert self.spec_info.is_draft_input()
draft_input: EagleDraftInput = self.spec_info
# FIXME(lsyin): the current implementation does not enable over-allocation
# At this point, seq_lens and allocate_lens are correct
self.maybe_wait_verify_done()
new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
assign_req_to_token_pool[(bs,)](
self.req_pool_indices,
self.req_to_token_pool.req_to_token,
draft_input.allocate_lens,
new_allocate_lens,
out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
draft_input.allocate_lens = new_allocate_lens
# FIXME(lsyin): remove seq_lens_sum calculation
self.seq_lens_cpu = self.seq_lens.cpu()
self.seq_lens_sum = self.seq_lens_cpu.sum().item()
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
......@@ -1507,15 +1544,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.model_config.vocab_size,
)
@property
def is_v2_eagle(self):
# FIXME: eventually deprecate is_v2_eagle
return self.enable_overlap and self.spec_algorithm.is_eagle()
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
bs = len(self.reqs)
if (
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram()
):
if self.is_v2_eagle:
# FIXME(lsyin): make this sync optional
self.allocate_for_eagle_v2()
if not self.spec_algorithm.is_none():
# If spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running the draft model.
return
......@@ -1566,11 +1608,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
def maybe_wait_verify_done(self):
if self.is_v2_eagle:
from sglang.srt.speculative.eagle_info import EagleDraftInput
draft_input: EagleDraftInput = self.spec_info
if draft_input.verify_done is not None:
draft_input.verify_done.synchronize()
def filter_batch(
self,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
keep_indices: Optional[List[int]] = None,
):
# FIXME(lsyin): used here to get the correct seq_lens
# The batch has already been launched, but we must wait for verify to finish
# to get the correct info for the next batch.
self.maybe_wait_verify_done()
if keep_indices is None:
if isinstance(chunked_req_to_exclude, Req):
chunked_req_to_exclude = [chunked_req_to_exclude]
......@@ -1633,6 +1687,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
def merge_batch(self, other: "ScheduleBatch"):
# NOTE: in v2 eagle mode, we do not need to wait for verify here because
# 1) the current batch is always prefill, whose seq_lens and allocate_lens are not futures
# 2) the other batch is always decode, which finished in the previous step
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
......@@ -1757,6 +1815,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
is_extend_in_batch=self.is_extend_in_batch,
is_prefill_only=self.is_prefill_only,
seq_lens_cpu=self.seq_lens_cpu,
enable_overlap=self.enable_overlap,
)
def _evict_tree_cache_if_needed(self, num_tokens: int):
......
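A rough numeric sketch of the allocation math in allocate_for_eagle_v2, with made-up values; ALLOC_LEN_PER_DECODE follows the formula set in EAGLEWorkerV2.__init__ later in this diff:

import torch

speculative_num_steps, topk, speculative_num_draft_tokens = 3, 4, 8
ALLOC_LEN_PER_DECODE = max(speculative_num_steps * topk, speculative_num_draft_tokens)  # 12

seq_lens = torch.tensor([100, 37])
allocate_lens = torch.tensor([100, 37])              # draft_input.allocate_lens from the last step
new_allocate_lens = seq_lens + ALLOC_LEN_PER_DECODE  # tensor([112, 49])
num_needed_tokens = (new_allocate_lens - allocate_lens).sum().item()  # 24 KV slots to allocate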
......@@ -148,13 +148,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
process_tracing_init,
......@@ -219,6 +216,14 @@ class GenerationBatchResult:
forward_batch: Optional[ForwardBatch] = None
future_indices: Optional[FutureIndices] = None
# FIXME(lsyin): maybe move to <BetterPlace> ?
# sync path: forward stream -> output processor
accept_lens: Optional[torch.Tensor] = None
last_batch_allocate_lens: Optional[torch.Tensor] = None
# relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None
def copy_to_cpu(self, return_logprob: bool = False):
"""Copy tensors to CPU in overlap scheduling.
Only the tensors which are needed for processing results are copied,
......@@ -238,6 +243,15 @@ class GenerationBatchResult:
"cpu", non_blocking=True
)
self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
if self.last_batch_allocate_lens is not None:
self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
"cpu", non_blocking=True
)
self.copy_done.record()
@classmethod
......@@ -273,48 +287,6 @@ class Scheduler(
):
"""A scheduler that manages a tensor parallel GPU worker."""
def launch_draft_worker(
self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
self.draft_worker = EAGLEWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_standalone():
from sglang.srt.speculative.standalone_worker import StandaloneWorker
self.draft_worker = StandaloneWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_worker import NGRAMWorker
self.draft_worker = NGRAMWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None
def __init__(
self,
server_args: ServerArgs,
......@@ -454,6 +426,7 @@ class Scheduler(
)
# Launch a draft worker for speculative decoding
self.launch_draft_worker(
gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
)
......@@ -683,6 +656,51 @@ class Scheduler(
]
)
def launch_draft_worker(
self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
self.draft_worker = WorkerClass(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_standalone():
from sglang.srt.speculative.standalone_worker import StandaloneWorker
self.draft_worker = StandaloneWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_worker import NGRAMWorker
self.draft_worker = NGRAMWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None
def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference:
......@@ -965,7 +983,9 @@ class Scheduler(
self.device
).stream(self.copy_stream)
self.future_map = FutureMap(self.max_running_requests, self.device)
self.future_map = FutureMap(
self.max_running_requests, self.device, self.spec_algorithm
)
self.batch_record_buf = [None] * 2
self.batch_record_ct = 0
......@@ -2096,7 +2116,7 @@ class Scheduler(
batch_or_worker_batch = batch
if self.spec_algorithm.is_none():
if self.enable_overlap or self.spec_algorithm.is_none():
# FIXME(lsyin): remove this branch and eventually unify the abstraction
batch_or_worker_batch = batch.get_model_worker_batch()
......@@ -2120,39 +2140,49 @@ class Scheduler(
if batch.sampling_info.grammars is not None:
model_worker_batch.delay_sample_launch = True
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
model_worker_batch
)
# FIXME(lsyin): maybe move this to forward_batch_generation
batch_result.copy_done = torch.get_device_module(
self.device
).Event()
if not model_worker_batch.delay_sample_launch:
self.future_map.store_to_map(
future_indices, batch_result.next_token_ids
)
self.future_map.store_to_map(future_indices, batch_result)
batch_result.copy_to_cpu()
else:
batch_result.future_indices = future_indices
# FIXME(lsyin): move this assignment elsewhere
maybe_future_next_token_ids = -future_indices.indices
future_indices_or_next_token_ids = -future_indices.indices
if batch.is_v2_eagle:
# FIXME(lsyin): tmp code for eagle v2
# We only keep future indices for next draft input
batch.spec_info = batch_result.next_draft_input
batch.spec_info.future_indices = future_indices
# batch.spec_info = EagleDraftInput(
# future_indices=future_indices,
# verify_done=batch_result.next_draft_input.verify_done,
# # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
# allocate_lens=batch_result.next_draft_input.allocate_lens,
# )
# This is a future value, usually used for preparing the next batch.
# The current implementation strictly synchronizes seq_lens.
batch.seq_lens = batch_result.next_draft_input.new_seq_lens
else:
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
maybe_future_next_token_ids = batch_result.next_token_ids
if not self.spec_algorithm.is_none():
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
self.update_spec_metrics(
batch.batch_size(), batch_result.num_accepted_tokens
)
future_indices_or_next_token_ids = batch_result.next_token_ids
# NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
# NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
# which can probably be replaced by future_indices later [TODO(lsyin)].
# We still keep the original outputs, e.g. next_token_ids,
# in the GenerationBatchResult for processing after copy_done.
batch.output_ids = maybe_future_next_token_ids
batch.output_ids = future_indices_or_next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
......@@ -2200,7 +2230,7 @@ class Scheduler(
tmp_result.forward_batch,
)
future_indices = tmp_result.future_indices
self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
self.future_map.store_to_map(future_indices, tmp_result)
tmp_result.copy_to_cpu()
self.result_queue.appendleft((tmp_batch, tmp_result))
......
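A toy illustration of the negative-id convention used above (the real resolution happens in _resolve_future_token_ids; values here are made up): the scheduler stores -future_indices.indices in batch.output_ids, and the next step swaps any negative id for the real token from token_ids_buf:

import torch

token_ids_buf = torch.zeros(16, dtype=torch.int64)
future_indices = torch.tensor([3, 4])
token_ids_buf[future_indices] = torch.tensor([101, 202])  # filled by store_to_map

input_ids = torch.tensor([7, -3, -4])  # -3, -4 are unresolved futures
resolved = torch.where(input_ids < 0, token_ids_buf[(-input_ids).clamp(min=0)], input_ids)
# resolved == tensor([7, 101, 202])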
......@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
kv_events_config, self.attn_dp_rank
)
def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
self.spec_num_total_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens
......
from __future__ import annotations
import logging
import threading
import time
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
......@@ -200,6 +199,28 @@ class SchedulerOutputProcessorMixin:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def hacky_process_eagle_overlap_result(
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
):
# TODO(lsyin): try using a copy stream to share SMs with forward
# FIXME(lsyin): better organize this token-freeing logic in eagle-overlap
last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
accept_lens_cpu = result.accept_lens.tolist()
next_token_ids = result.next_token_ids.tolist()
predict_tokens = []
num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
for i, req in enumerate(batch.reqs):
predict_tokens.append(
next_token_ids[
i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
]
)
# FIXME(lsyin): move this update elsewhere
req.spec_verify_ct += 1
return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
def process_batch_result_decode(
self: Scheduler,
batch: ScheduleBatch,
......@@ -220,6 +241,17 @@ class SchedulerOutputProcessorMixin:
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
elif batch.is_v2_eagle:
(
last_batch_allocate_lens_cpu,
accept_lens_cpu,
next_token_ids,
) = self.hacky_process_eagle_overlap_result(result, batch)
result.num_accepted_tokens = sum(accept_lens_cpu)
# FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
if not self.spec_algorithm.is_none():
self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
self.token_to_kv_pool_allocator.free_group_begin()
......@@ -227,29 +259,74 @@ class SchedulerOutputProcessorMixin:
# NOTE: the lengths of reqs and next_token_ids don't match in spec decoding.
# We should not use next_token_ids directly in spec decoding cases.
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req: Req
if req.is_retracted:
continue
if self.enable_overlap and req.finished():
# Free the one extra delayed token
if self.page_size == 1:
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
else:
# Only free when the extra token is in a new page
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
if batch.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker_v2 import (
free_spec_dec_tokens_page_size_1,
)
free_spec_dec_tokens_page_size_1(
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
req,
last_batch_allocate_lens_cpu[i],
None,
)
else:
# Free the one extra delayed token
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
else:
if batch.spec_algorithm.is_eagle():
# TODO(lsyin): support eagle with page_size > 1
raise NotImplementedError()
else:
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
# Only free when the extra token is in a new page
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
continue
if batch.spec_algorithm.is_none():
# The speculative worker handles output_ids in speculative decoding.
req.output_ids.append(next_token_id)
elif batch.is_v2_eagle:
# FIXME(lsyin): the non-overlap spec worker handles output_ids in speculative
# decoding; unify the logic here!
req.output_ids.extend(next_token_id)
req.check_finished()
if req.finished():
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
# FIXME(lsyin): fix the messy logic here
# 1) when not overlapped (v2 impl), we free the extra tokens in the req
# 2) when overlapped and the current batch is extend, we free the extra tokens
#    in the req of the previous batch
from sglang.srt.speculative.eagle_worker_v2 import (
free_spec_dec_tokens_page_size_1,
)
new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
# FIXME(lsyin): remove this assert
assert new_seq_len == int(
batch.seq_lens_cpu[i] + accept_lens_cpu[i]
), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
free_spec_dec_tokens_page_size_1(
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
req,
last_batch_allocate_lens_cpu[i],
new_seq_len,
)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
if not self.decode_offload_manager.offload_kv_cache(req):
......
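A small worked example (made-up values) of the slicing done in hacky_process_eagle_overlap_result: each request owns a fixed num_draft_tokens-sized slice of next_token_ids, of which only the first accept_lens[i] entries were accepted:

num_draft_tokens = 4
next_token_ids = [11, 12, 13, 14, 21, 22, 23, 24]  # 2 requests x 4 draft slots
accept_lens_cpu = [2, 3]

predict_tokens = [
    next_token_ids[i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]]
    for i in range(len(accept_lens_cpu))
]
# predict_tokens == [[11, 12], [21, 22, 23]]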
......@@ -231,12 +231,21 @@ class TpModelWorker:
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
forward_batch: Optional[ForwardBatch] = None,
is_verify: bool = False,
skip_attn_backend_init=False,
) -> GenerationBatchResult:
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
# FIXME(lsyin): maybe remove skip_attn_backend_init from forward_batch_generation,
# which would require the replay preparation to always happen in this function
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
if model_worker_batch is not None:
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
else:
# FIXME(lsyin): unify the interface of forward_batch
assert forward_batch is not None
pp_proxy_tensors = None
if not self.pp_group.is_first_rank:
......@@ -248,7 +257,9 @@ class TpModelWorker:
if self.pp_group.is_last_rank:
logits_output, can_run_cuda_graph = self.model_runner.forward(
forward_batch, pp_proxy_tensors=pp_proxy_tensors
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
skip_attn_backend_init=skip_attn_backend_init,
)
batch_result = GenerationBatchResult(
logits_output=logits_output,
......@@ -290,6 +301,7 @@ class TpModelWorker:
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
skip_attn_backend_init=skip_attn_backend_init,
)
return GenerationBatchResult(
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
......
......@@ -678,8 +678,9 @@ class CudaGraphRunner:
capture_hidden_mode_required_by_forward_batch = (
forward_batch.capture_hidden_mode
)
capture_hidden_mode_required_by_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
capture_hidden_mode_required_by_spec_info = (
getattr(forward_batch.spec_info, "capture_hidden_mode", None)
or CaptureHiddenMode.NULL
)
capture_hidden_mode_required_for_returning_hidden_states = (
CaptureHiddenMode.FULL
......
......@@ -75,6 +75,8 @@ class ForwardMode(IntEnum):
# Used in speculative decoding: extend a batch in the draft model.
DRAFT_EXTEND = auto()
DRAFT_EXTEND_V2 = auto()
# Split Prefill for PD multiplexing
SPLIT_PREFILL = auto()
......@@ -107,6 +109,10 @@ class ForwardMode(IntEnum):
def is_draft_extend(self):
return self == ForwardMode.DRAFT_EXTEND
def is_draft_extend_v2(self):
# For fixed-shape logits output in the v2 eagle worker
return self == ForwardMode.DRAFT_EXTEND_V2
def is_extend_or_draft_extend_or_mixed(self):
return (
self == ForwardMode.EXTEND
......
......@@ -312,6 +312,7 @@ class ServerArgs:
nsa_decode: str = "fa3"
# Speculative decoding
enable_beta_spec: bool = False
speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None
speculative_draft_model_revision: Optional[str] = None
......@@ -1103,11 +1104,19 @@ class ServerArgs:
)
if self.max_running_requests is None:
self.max_running_requests = 48
self.disable_overlap_schedule = True
logger.warning(
"Overlap scheduler is disabled because of using "
"eagle speculative decoding."
)
if self.speculative_algorithm == "EAGLE" and self.enable_beta_spec:
self.disable_overlap_schedule = False
logger.warning(
"Beta spec is enabled for eagle speculative decoding and overlap schedule is turned on."
)
if not self.enable_beta_spec:
self.disable_overlap_schedule = True
logger.warning(
"Overlap scheduler is disabled because of using eagle3 and standalone speculative decoding."
)
if self.enable_mixed_chunk:
self.enable_mixed_chunk = False
logger.warning(
......@@ -2127,6 +2136,7 @@ class ServerArgs:
)
# Speculative decoding
parser.add_argument("--enable-beta-spec", action="store_true")
parser.add_argument(
"--speculative-algorithm",
type=str,
......
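A paraphrase (not the actual ServerArgs code path) of the overlap-schedule decision inside the speculative-decoding branch above:

def overlap_schedule_disabled(speculative_algorithm: str, enable_beta_spec: bool) -> bool:
    # Within the speculative-decoding branch of ServerArgs validation:
    if speculative_algorithm == "EAGLE" and enable_beta_spec:
        return False   # beta spec: keep the overlap scheduler on
    return True        # EAGLE3 / standalone / EAGLE without beta spec: disable overlap

assert overlap_schedule_disabled("EAGLE", enable_beta_spec=True) is False
assert overlap_schedule_disabled("EAGLE", enable_beta_spec=False) is True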
import logging
from copy import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import ClassVar, List, Optional, Tuple
import torch
import torch.nn.functional as F
......@@ -10,6 +10,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import (
......@@ -18,16 +19,20 @@ from sglang.srt.mem_cache.common import (
get_last_loc,
)
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin,
)
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
TREE_SPEC_KERNEL_AVAILABLE,
_generate_simulated_accept_index,
align_evict_mask_to_page_size,
assign_req_to_token_pool,
create_accept_length_filter,
create_extend_after_decode_spec_info,
filter_finished_cache_loc_kernel,
generate_simulated_accept_index,
get_src_tgt_cache_loc,
get_target_cache_loc,
)
......@@ -47,7 +52,7 @@ logger = logging.getLogger(__name__)
@dataclass
class EagleVerifyInput(SpecInput):
class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
draft_token: torch.Tensor
custom_mask: torch.Tensor
positions: torch.Tensor
......@@ -338,7 +343,7 @@ class EagleVerifyInput(SpecInput):
if SIMULATE_ACC_LEN > 0.0:
# Do simulation
accept_index = _generate_simulated_accept_index(
accept_index = generate_simulated_accept_index(
accept_index=accept_index,
predict=predict, # mutable
accept_length=accept_length, # mutable
......@@ -568,7 +573,7 @@ class EagleVerifyInput(SpecInput):
@dataclass
class EagleDraftInput(SpecInput):
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
# The inputs for decode
# shape: (b, topk)
topk_p: torch.Tensor = None
......@@ -598,6 +603,15 @@ class EagleDraftInput(SpecInput):
seq_lens_for_draft_extend_cpu: torch.Tensor = None
req_pool_indices_for_draft_extend: torch.Tensor = None
# Inputs for V2 overlap worker
future_indices: Optional[FutureIndices] = None
allocate_lens: Optional[torch.Tensor] = None
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
# FIXME(lsyin): remove this hack
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
......@@ -703,6 +717,11 @@ class EagleDraftInput(SpecInput):
return kv_indices, cum_kv_seq_len, qo_indptr, None
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
if self.future_indices is not None:
self.future_indices.indices = self.future_indices.indices[new_indices]
self.allocate_lens = self.allocate_lens[new_indices]
return
if has_been_filtered:
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
......@@ -722,6 +741,18 @@ class EagleDraftInput(SpecInput):
self.verified_id = self.verified_id[new_indices]
def merge_batch(self, spec_info: "EagleDraftInput"):
if self.future_indices is not None:
assert spec_info.future_indices is not None
self.future_indices = FutureIndices(
indices=torch.cat(
[self.future_indices.indices, spec_info.future_indices.indices]
)
)
self.allocate_lens = torch.cat(
[self.allocate_lens, spec_info.allocate_lens]
)
return
if self.hidden_states is None:
self.hidden_states = spec_info.hidden_states
self.verified_id = spec_info.verified_id
......
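A toy sketch (made-up values) of the v2 filter/merge paths added above: when future_indices are present, only the index tensor and allocate_lens are filtered or concatenated, since the actual draft tensors still live in the FutureMap buffers:

import torch

future_indices = torch.tensor([3, 4, 5])
allocate_lens = torch.tensor([112, 49, 80])

keep = torch.tensor([0, 2])                    # new_indices after some reqs finished
future_indices = future_indices[keep]          # tensor([3, 5])
allocate_lens = allocate_lens[keep]            # tensor([112, 80])

other = (torch.tensor([9]), torch.tensor([64]))          # another batch being merged in
future_indices = torch.cat([future_indices, other[0]])   # merge_batch path
allocate_lens = torch.cat([allocate_lens, other[1]])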
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
generate_simulated_accept_index,
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
if is_cuda():
from sgl_kernel import (
top_k_renorm_prob,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
elif is_hip():
from sgl_kernel import verify_tree_greedy
@triton.jit
def assign_draft_cache_locs_page_size_1(
req_pool_indices,
req_to_token,
seq_lens,
out_cache_loc,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
pid = tl.program_id(axis=0)
copy_len = topk * speculative_num_steps
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
# Copy from req_to_token to out_cache_loc
kv_start = tl.load(seq_lens + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = copy_offset < copy_len
data = tl.load(token_pool + kv_start + copy_offset, mask=mask)
tl.store(out_cache_ptr + copy_offset, data, mask=mask)
@dataclass
class EagleDraftInputV2Mixin:
def prepare_for_v2_draft(
self: EagleDraftInput,
req_to_token_pool: ReqToTokenPool,
batch: ModelWorkerBatch,
cuda_graph_runner: EAGLEDraftCudaGraphRunner,
draft_model_runner: ModelRunner,
topk: int,
num_steps: int,
):
bs = len(batch.seq_lens)
# Assign cache locations
batch.out_cache_loc = torch.empty(
(bs * topk * num_steps,),
dtype=torch.int64,
device=batch.input_ids.device,
)
# FIXME(lsyin): align with the default code path
assign_draft_cache_locs_page_size_1[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
topk,
num_steps,
)
# Get a forward batch
batch.capture_hidden_mode = CaptureHiddenMode.LAST
self.positions = batch.seq_lens.repeat_interleave(topk, dim=0)
forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
can_cuda_graph = cuda_graph_runner and cuda_graph_runner.can_run(forward_batch)
return forward_batch, can_cuda_graph
def prepare_for_extend_to_fill_draft_kvcache(
self,
batch: ModelWorkerBatch,
predict: torch.Tensor,
num_draft_tokens: int,
draft_model_runner: Any,
):
seq_lens_cpu_backup = batch.seq_lens_cpu
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
batch.spec_info = self
batch.input_ids = predict
batch.seq_lens = batch.seq_lens + num_draft_tokens
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
batch.seq_lens_sum += extend_num_tokens
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
batch.extend_num_tokens = extend_num_tokens
batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
return forward_batch
@dataclass
class EagleVerifyInputV2Mixin:
def prepare_for_v2_verify(
self: EagleVerifyInput,
req_to_token_pool: ReqToTokenPool,
batch: ModelWorkerBatch,
target_worker: TpModelWorker,
):
# Assign cache locations
bs = len(batch.req_pool_indices)
batch.input_ids = self.draft_token
device = batch.input_ids.device
batch.out_cache_loc = torch.empty(
(bs * self.draft_token_num,),
dtype=torch.int64,
device=device,
)
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
# Get a forward batch
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.capture_hidden_mode = CaptureHiddenMode.FULL
verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner)
# Run attention backend plan and cuda graph preparation
can_run_cuda_graph = bool(
target_worker.model_runner.graph_runner
and target_worker.model_runner.graph_runner.can_run(verify_forward_batch)
)
if can_run_cuda_graph:
target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch)
else:
target_worker.model_runner.attn_backend.init_forward_metadata(
verify_forward_batch
)
return verify_forward_batch, can_run_cuda_graph
def sample(
self: EagleVerifyInput,
batch: ModelWorkerBatch,
logits_output: LogitsProcessorOutput,
):
"""
Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information).
"""
bs = len(batch.seq_lens)
sampling_info = batch.sampling_info
next_token_logits = logits_output.next_token_logits
device = batch.input_ids.device
candidates = self.draft_token.reshape(bs, self.draft_token_num)
predict = torch.zeros(
(bs * (self.spec_steps + 1),), dtype=torch.int32, device=device
)
accept_index = torch.full(
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device
)
accept_length = torch.empty((bs,), dtype=torch.int32, device=device)
# Sample tokens
if sampling_info.is_all_greedy:
target_predict = torch.argmax(next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
verify_tree_greedy(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates,
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict,
)
else:
# Apply temperature and get target probs
expanded_temperature = torch.repeat_interleave(
sampling_info.temperatures, self.draft_token_num, dim=0
) # (bs * num_draft_tokens, 1)
target_probs = F.softmax(
next_token_logits / expanded_temperature, dim=-1
) # (bs * num_draft_tokens, vocab_size)
target_probs = top_k_renorm_prob(
target_probs,
torch.repeat_interleave(
sampling_info.top_ks, self.draft_token_num, dim=0
),
) # (bs * num_draft_tokens, vocab_size)
target_probs = top_p_renorm_prob(
target_probs,
torch.repeat_interleave(
sampling_info.top_ps, self.draft_token_num, dim=0
),
)
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
# This is currently not used
draft_probs = torch.empty_like(target_probs)
# coins for rejection sampling
coins = torch.rand_like(candidates, dtype=torch.float32, device=device)
# coins for final sampling
coins_for_final_sampling = torch.rand(
(bs,), dtype=torch.float32, device=device
)
tree_speculative_sampling_target_only(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
candidates=candidates,
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
uniform_samples=coins,
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True,
)
if SIMULATE_ACC_LEN > 0:
# Do simulation
accept_index = generate_simulated_accept_index(
accept_index=accept_index,
predict=predict, # mutable
accept_length=accept_length, # mutable
simulate_acc_len=SIMULATE_ACC_LEN,
bs=bs,
spec_steps=self.draft_token_num,
)
# Include the bonus token
accept_length.add_(1)
return predict, accept_length, accept_index
def build_tree_kernel_efficient_tmp(
verified_id: torch.Tensor,
parent_list: List[torch.Tensor],
top_scores_index: torch.Tensor,
draft_tokens: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
topk: int,
spec_steps: int,
num_verify_tokens: int,
tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
tree_mask_buf: Optional[torch.Tensor] = None,
position_buf: Optional[torch.Tensor] = None,
):
# TODO(lsyin): make it compatible with the default code path
# TODO(lsyin): support cuda graph padding for eagle
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
# seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
bs = seq_lens.numel()
device = seq_lens.device
# e.g. for bs=1, tree_mask has shape (num_draft_token, seq_lens_sum + num_draft_token), flattened,
# where each row indicates the attention pattern of one draft token;
# if use_partial_packed_tree_mask is True, tree_mask has shape (num_draft_token,), flattened and packed
if tree_mask_buf is not None:
tree_mask = tree_mask_buf
if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
tree_mask.fill_(True)
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
tree_mask.fill_(0)
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
tree_mask.fill_(True)
else:
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
tree_mask = torch.full(
(num_verify_tokens * bs * num_verify_tokens,),
True,
dtype=torch.bool,
device=device,
)
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
tree_mask = torch.zeros(
(num_verify_tokens * bs,),
dtype=packed_dtypes[packed_dtype_idx],
device=device,
)
elif tree_mask_mode == TreeMaskMode.FULL_MASK:
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
else:
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
retrive_buf = torch.full(
(3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
)
retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf
# positions: the position (index in the sequence) of each draft token
# e.g. if the depth of each draft token is [0, 1, 1, 2] and the prompt length is 7,
# then positions = [7, 8, 8, 9]
if position_buf is not None:
positions = position_buf
else:
positions = torch.empty(
(bs * num_verify_tokens,), device=device, dtype=torch.long
)
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
sgl_build_tree_kernel_efficient(
parent_list,
top_scores_index,
seq_lens,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
spec_steps,
num_verify_tokens,
tree_mask_mode,
)
return (
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
draft_tokens,
)
@torch.compile(dynamic=True)
def select_top_k_tokens_tmp(
i: int,
topk_p: torch.Tensor,
topk_index: torch.Tensor,
hidden_states: torch.Tensor,
scores: torch.Tensor,
topk: int,
):
# FIXME(lsyin): remove this duplicate code
if i == 0:
# The first step after extend
input_ids = topk_index.flatten()
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
scores = topk_p # shape: (b, topk)
tree_info = (
topk_p.unsqueeze(1), # shape: (b, 1, topk)
topk_index, # shape: (b, topk)
torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)
else:
# The later decode steps
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk)
scores = topk_cs_p # shape: (b, topk)
topk_index = topk_index.reshape(-1, topk**2)
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device=hidden_states.device
).repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index, :]
tree_info = (
expand_scores, # shape: (b, topk, topk)
topk_index, # shape: (b, topk * topk)
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
)
return input_ids, hidden_states, scores, tree_info
@triton.jit
def fill_new_verified_id(
verified_id,
accept_lens,
new_verified_id,
num_draft_tokens: tl.constexpr,
):
# NOTE: we cannot fuse any in-place operation on `accept_lens` into this kernel
# because this kernel reads accept_lens
pid = tl.program_id(axis=0)
accept_length = tl.load(accept_lens + pid)
verified_id_idx = num_draft_tokens * pid + accept_length - 1
verified_id_data = tl.load(verified_id + verified_id_idx)
tl.store(new_verified_id + pid, verified_id_data)
@triton.jit
def fill_accepted_out_cache_loc(
accept_index,
out_cache_loc,
accepted_out_cache_loc,
size_upper: tl.constexpr,
):
pid = tl.program_id(axis=0)
offset = tl.arange(0, size_upper)
masks = (tl.load(accept_index + offset, offset < pid, other=-1) != -1).to(tl.int64)
dst = tl.sum(masks)
src = tl.load(accept_index + pid)
if src > -1:
value = tl.load(out_cache_loc + src)
tl.store(accepted_out_cache_loc + dst, value)
@triton.jit
def assign_extend_cache_locs(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len: tl.constexpr,
bs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(start_offset + pid)
kv_end = tl.load(end_offset + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
length_offset = tl.arange(0, bs_upper)
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
out_offset = tl.sum(end - start, axis=0)
out_cache_ptr = out_cache_loc + out_offset
load_offset = tl.arange(0, BLOCK_SIZE) + kv_start
save_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = load_offset < kv_end
data = tl.load(token_pool + load_offset, mask=mask)
tl.store(out_cache_ptr + save_offset, data, mask=mask)
load_offset += BLOCK_SIZE
save_offset += BLOCK_SIZE
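A small worked example of the QLEN_ONLY_BITPACKING dtype selection in build_tree_kernel_efficient_tmp: one mask bit per draft token, rounded up to the smallest unsigned integer type that holds num_verify_tokens bits:

import math

packed_dtypes = ["uint8", "uint16", "uint32"]
for num_verify_tokens in (8, 16, 32):
    idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
    print(num_verify_tokens, packed_dtypes[idx])
# 8 -> uint8, 16 -> uint16, 32 -> uint32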
import logging
from typing import List, Optional
import torch
from torch.cuda import Stream as CudaStream
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info_v2 import (
assign_extend_cache_locs,
build_tree_kernel_efficient_tmp,
fill_accepted_out_cache_loc,
fill_new_verified_id,
select_top_k_tokens_tmp,
)
from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.utils.common import fast_topk, next_power_of_2
logger = logging.getLogger(__name__)
class EAGLEWorkerV2(EAGLEWorker):
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
moe_ep_rank: int,
nccl_port: int,
target_worker: TpModelWorker,
):
super().__init__(
server_args,
gpu_id,
tp_rank,
dp_rank,
moe_ep_rank,
nccl_port,
target_worker,
)
EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
)
self.tree_mask_mode = TreeMaskMode.FULL_MASK
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
# TODO(lsyin): potential bugs with a separate plan stream
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode():
# FIXME(lsyin): why should we use spec_info for both draft and verify?
draft_input: EagleDraftInput = model_worker_batch.spec_info
assert draft_input.is_draft_input()
verify_input: EagleVerifyInput = self.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
return batch_output
else:
# Target prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
)
# Draft prefill
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
batch_output.next_draft_input = self.forward_draft_extend(
model_worker_batch,
batch_output.logits_output.hidden_states,
batch_output.next_token_ids,
)
return batch_output
def draft(self, model_worker_batch: ModelWorkerBatch):
draft_input: EagleDraftInput = model_worker_batch.spec_info
forward_batch, can_cuda_graph = draft_input.prepare_for_v2_draft(
self.req_to_token_pool,
model_worker_batch,
self.cuda_graph_runner,
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
# Run draft
if can_cuda_graph:
parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
forward_batch,
)
else:
self.draft_attn_backend.init_forward_metadata(forward_batch)
parent_list, top_scores_index, draft_tokens = self.draft_forward(
forward_batch
)
# Build tree mask
# Directly write to cuda graph buffers for verify attn
tree_mask_buf, position_buf = (
self.target_worker.model_runner.attn_backend.get_verify_buffers_to_fill_after_draft()
)
(
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
draft_tokens,
) = build_tree_kernel_efficient_tmp(
draft_input.verified_id,
parent_list,
top_scores_index,
draft_tokens,
model_worker_batch.seq_lens,
model_worker_batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.speculative_num_draft_tokens,
self.tree_mask_mode,
tree_mask_buf,
position_buf,
)
return EagleVerifyInput(
draft_token=draft_tokens,
custom_mask=tree_mask,
positions=position,
retrive_index=retrive_index,
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
retrive_cum_len=None,
spec_steps=self.speculative_num_steps,
topk=self.topk,
draft_token_num=self.speculative_num_draft_tokens,
capture_hidden_mode=None,
seq_lens_sum=None,
seq_lens_cpu=None,
)
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info: EagleDraftInput = forward_batch.spec_info
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
spec_info.topk_index,
spec_info.hidden_states,
)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
out_cache_loc = out_cache_loc.reshape(
forward_batch.batch_size, self.topk, self.speculative_num_steps
)
out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
self.speculative_num_steps, -1
)
# Return values
score_list: List[torch.Tensor] = []
token_list: List[torch.Tensor] = []
parents_list: List[torch.Tensor] = []
# Forward multiple steps
scores = None
for i in range(self.speculative_num_steps):
input_ids, hidden_states, scores, tree_info = select_top_k_tokens_tmp(
i, topk_p, topk_index, hidden_states, scores, self.topk
)
score_list.append(tree_info[0])
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
# We don't need to run the last forward pass: we get 1 token from the draft prefill and (speculative_num_steps - 1) tokens here
if i == self.speculative_num_steps - 1:
break
# Set inputs
forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[i]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
hidden_states = logits_output.hidden_states
# Organize the results
score_list = torch.cat(score_list, dim=1).flatten(
1
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
ss_token_list = torch.cat(
token_list, dim=1
) # b, (self.topk + (num_steps-1) * self.topk)
top_scores = torch.topk(
score_list, self.speculative_num_draft_tokens - 1, dim=-1
)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
if len(parents_list) > 1:
parent_list = torch.cat(parents_list[:-1], dim=1)
else:
batch_size = parents_list[0].shape[0]
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
return parent_list, top_scores_index, draft_tokens
def verify(
self,
batch: ModelWorkerBatch,
pre_draft_allocate_lens: torch.Tensor,
):
# Parse args
verify_input: EagleVerifyInput = batch.spec_info
seq_lens_backup = batch.seq_lens
bs = len(batch.seq_lens)
# Batch 1: Target verify
# Prepare for target verify in a separate stream
with self.plan_stream_ctx:
verify_forward_batch, can_run_cuda_graph = (
verify_input.prepare_for_v2_verify(
self.req_to_token_pool,
batch,
self.target_worker,
)
)
# Correct some buffers due to the overlap plan
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
# Some values, such as custom_mask and positions, depend on the output of the draft,
# so the previous plan step used stale values. Here, we rerun the related
# computation to update them to the correct values.
self.target_worker.model_runner.attn_backend.update_verify_buffers_to_fill_after_draft(
verify_input,
(
self.target_worker.model_runner.graph_runner.bs
if can_run_cuda_graph
else None
),
)
# Run target verify batch in the main compute stream
forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch=None,
forward_batch=verify_forward_batch,
is_verify=True,
skip_attn_backend_init=True,
)
logits_output = forward_batch_output.logits_output
# Sample
self._detect_nan_if_needed(logits_output)
(
predict,
accept_length,
accept_index,
) = verify_input.sample(batch, logits_output)
new_seq_lens = seq_lens_backup + accept_length
verify_done = torch.cuda.Event()
# Move the accepted tokens to the target KV cache locations
batch.seq_lens = seq_lens_backup
self.move_accepted_tokens_to_target_kvcache(
batch,
accept_index,
accept_length,
)
verify_done.record()
all_verified_id = predict[accept_index]
verified_id = torch.empty_like(accept_length, dtype=torch.int32)
fill_new_verified_id[(bs,)](
all_verified_id,
accept_length,
verified_id,
self.speculative_num_draft_tokens,
)
# Batch 2: Draft extend
draft_input = EagleDraftInput(
hidden_states=logits_output.hidden_states,
)
select_index = (
torch.arange(len(batch.seq_lens), device=self.device)
* self.speculative_num_draft_tokens
+ accept_length
- 1
)
# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
batch,
predict,
self.speculative_num_draft_tokens,
self.draft_model_runner,
)
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
# Run draft extend batch in the main compute stream
draft_logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
# Reorganize the spec info for the next batch
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
select_index
]
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
select_index
]
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states
# Since seq_lens_backup's tensor is allocated on another stream, we need
# record_stream() to prevent PyTorch from freeing and reusing the GPU memory
# while the forward stream is still running.
seq_lens_backup.record_stream(torch.cuda.current_stream())
# Construct the return values
next_draft_input = EagleDraftInput(
topk_p=ret_topk_p,
topk_index=ret_topk_index,
hidden_states=ret_hidden_states,
verified_id=verified_id,
new_seq_lens=new_seq_lens,
allocate_lens=pre_draft_allocate_lens,
verify_done=verify_done,
)
return GenerationBatchResult(
logits_output=logits_output,
next_token_ids=predict,
can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input,
accept_lens=accept_length,
last_batch_allocate_lens=pre_draft_allocate_lens,
)
def forward_draft_extend(
self,
batch: ModelWorkerBatch,
target_hidden_states: torch.Tensor,
next_token_ids: torch.Tensor,
):
"""
Run draft model extend to correctly fill the KV cache.
Args:
batch: The batch to run.
target_hidden_states: Hidden states from the target model forward.
next_token_ids: Next token ids generated by the target forward.
"""
# Construct input_ids
pt = 0
for i, extend_len in enumerate(batch.extend_seq_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], next_token_ids[i].reshape(1))
)
pt += extend_len
# Construct spec_info
next_draft_input = EagleDraftInput(
hidden_states=target_hidden_states,
verified_id=next_token_ids,
new_seq_lens=batch.seq_lens,
allocate_lens=batch.seq_lens,
)
batch.spec_info = next_draft_input
# Run forward
forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
logits_output, _ = self.draft_model_runner.forward(forward_batch)
# Update spec_info for the next draft step
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
probs, self.topk, dim=-1
)
next_draft_input.hidden_states = logits_output.hidden_states
return next_draft_input
def move_accepted_tokens_to_target_kvcache(
self,
batch: ModelWorkerBatch,
accept_index: torch.Tensor,
accept_length: torch.Tensor,
):
"""
Move accepted tokens to the target KV cache.
Args:
batch: The batch to run.
accept_index: The index of the accepted tokens.
accept_length: The length of the accepted tokens.
"""
bs = len(batch.seq_lens)
size = bs * self.speculative_num_draft_tokens
tgt_cache_loc = torch.zeros(
size,
dtype=torch.int64,
device=self.device,
)
accepted_out_cache_loc = torch.zeros(
size, dtype=torch.int64, device=self.device
)
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
self.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length,
tgt_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
fill_accepted_out_cache_loc[(size,)](
accept_index,
batch.out_cache_loc,
accepted_out_cache_loc,
next_power_of_2(size),
)
self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
tgt_cache_loc, accepted_out_cache_loc
)
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.error("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
def free_spec_dec_tokens_page_size_1(
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
req: Req,
allocate_len: int,
new_seq_len: int,
):
# FIXME(lsyin): move this function elsewhere
# free extra allocated tokens
if new_seq_len is None:
# True only for overlap eagle when the current batch is decode: this seq would be part of the decode, so the final iteration's allocation is unused (i.e., this case).
start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
else:
# True for 1) non-overlap; 2) overlap eagle when the current batch is prefill: this seq will not run an extra iteration, so new_seq_len is passed in and used as the start.
start_len = new_seq_len
indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][
start_len:allocate_len
]
token_to_kv_pool_allocator.free(indices_to_free)
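A numeric sketch (made-up values) of the two freeing cases handled by free_spec_dec_tokens_page_size_1:

ALLOC_LEN_PER_DECODE = 12
allocate_len = 112          # KV slots reserved for this request after the last decode step

# Case 1: overlap eagle and the current batch is decode -> the final iteration's
# reservation is unused, so free exactly that tail.
start_len = allocate_len - ALLOC_LEN_PER_DECODE  # 100, free slots [100, 112)

# Case 2: non-overlap, or overlap eagle with a prefill batch -> the request stops
# here, so free everything past its final verified length.
new_seq_len = 103
start_len = new_seq_len                          # free slots [103, 112)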
......@@ -435,7 +435,7 @@ def select_top_k_tokens(
return input_ids, hidden_states, scores, tree_info
def _generate_simulated_accept_index(
def generate_simulated_accept_index(
accept_index,
predict,
accept_length,
......
......@@ -4,7 +4,7 @@ import copy
import dataclasses
import logging
from dataclasses import replace
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
import torch
......@@ -30,12 +30,12 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
from sglang.srt.speculative.eagle_info import EagleVerifyInput
_is_hip = is_hip()
......
......@@ -67,6 +67,7 @@ suites = {
TestFile("test_deterministic.py", 300),
TestFile("test_eagle_infer_a.py", 370),
TestFile("test_eagle_infer_b.py", 700),
TestFile("test_eagle_infer_beta.py", 300),
TestFile("test_ebnf_constrained.py", 108),
TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 376),
......
......@@ -69,6 +69,7 @@ suites = {
TestFile("test_deterministic.py", 300),
TestFile("test_eagle_infer_a.py", 370),
TestFile("test_eagle_infer_b.py", 700),
TestFile("test_eagle_infer_beta.py", 300),
TestFile("test_ebnf_constrained.py", 108),
TestFile("test_eval_fp8_accuracy.py", 303),
TestFile("test_fa3.py", 376),
......