"""EAGLE-style speculative decoding worker.

Defines ``EAGLEWorker``, a ``TpModelWorker`` that hosts the draft model,
shares the request/KV-slot allocators with a target worker, and drives the
draft -> target verify -> draft extend cycle, plus small helpers for loading
a hot-token map and computing paged KV allocation layouts.
"""

import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import (
    GroupCoordinator,
    get_tp_group,
    patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
    EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    fast_topk,
    generate_token_bitmask,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_bool_env_var,
    is_cuda,
    next_power_of_2,
)

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)

# When the RETURN_ORIGINAL_LOGPROB env var is set, `add_logprob_values`
# computes log-probs from the raw logits instead of temperature-scaled logits.
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Run the enclosed code with the tensor-parallel group patched to `tp_group`.

    Used around draft-model work when DP attention is enabled (see
    `EAGLEWorker.__init__`).
    """
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    # NOTE(review): nothing visible here touches mscclpp; presumably that is
    # handled inside `patch_tensor_parallel_group` — confirm.
    with patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):
    """Draft-model worker for EAGLE-style speculative decoding.

    The draft model runs in this worker (its own `model_runner`, exposed as
    `draft_model_runner`). It shares `req_to_token_pool` and
    `token_to_kv_pool_allocator` with `target_worker`, and orchestrates:

    * extend batches: target extend, then draft extend
      (`forward_target_extend` + `forward_draft_extend`);
    * decode batches: multi-step draft (`draft`), target verification
      (`verify`), then draft extend over accepted tokens
      (`forward_draft_extend_after_decode`).
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Build the draft worker on top of an already-initialized target worker.

        Side effects on `server_args`: `context_length` is overwritten with the
        target model's context length; `disable_cuda_graph` is forced to True
        during the base-class init and restored afterwards; and when a
        non-EAGLE3 `speculative_token_map` is given, `json_model_override_args`
        is overwritten with the hot vocab size.

        Args:
            server_args: Server configuration (mutated as described above).
            gpu_id: GPU index used for memory reporting and the base worker.
            tp_rank: Tensor-parallel rank.
            dp_rank: Data-parallel rank, if any.
            moe_ep_rank: MoE expert-parallel rank.
            nccl_port: Port forwarded to the base worker.
            target_worker: The target-model worker whose memory pools and
                embedding / lm_head are shared with the draft model.
        """
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with a target worker.
        # Draft and target worker own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            # Tell the draft model config how large its (reduced) vocab is.
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # most cases EAGLE3 models don't share lm_head
            # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
            if (
                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
                and self.draft_model_runner.model.load_lm_head_from_target
            ):
                self.draft_model_runner.model.set_embed_and_head(embed, head)
            else:
                self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            if self.draft_model_runner.model.hot_token_id is not None:
                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
                    embed.device
                )

        else:
            if self.hot_token_id is not None:
                # Clone first so the target's lm_head is left untouched, then
                # keep only the rows for the hot tokens.
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        # NOTE(review): the restore goes through `draft_model_runner.server_args`
        # while `init_cuda_graphs` reads `self.server_args`; presumably these
        # are the same object — confirm.
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        # Placeholders passed to `assign_draft_cache_locs`; they are only
        # overwritten by `_draft_preprocess_decode` when page_size > 1 and topk > 1.
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)

    def init_attention_backend(self):
        """Create the multi-step draft decode backend and the draft-extend backend.

        Also resets `has_prefill_wrapper_verify`; some decode factories set it
        back to True.
        """
        # Create multi-step attn backends and cuda graph runners

        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        # Initialize decode attention backend
        self.draft_attn_backend = self._create_decode_backend()

        # Initialize prefill attention backend
        self.draft_extend_attn_backend = self._create_draft_extend_backend()

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def _create_backend(
        self, backend_name: str, backend_map: dict, error_template: str
    ):
        """Instantiate an attention backend chosen by a server-args field.

        Args:
            backend_name: Name of the `server_args` attribute to read; if it is
                None, `server_args.attention_backend` is used instead.
            backend_map: Backend type name -> zero-argument factory.
            error_template: Message format with a `{backend_type}` placeholder.

        Raises:
            ValueError: If the selected backend type is not in `backend_map`.
        """
        backend_type = getattr(self.server_args, backend_name)
        if backend_type is None:
            backend_type = self.server_args.attention_backend

        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        return backend_map[backend_type]()

    def _create_decode_backend(self):
        """Return the multi-step draft decode backend (`decode_attention_backend`)."""
        backend_map = {
            "flashinfer": self._create_flashinfer_decode_backend,
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
        }

        return self._create_backend(
            "decode_attention_backend",
            backend_map,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )

    def _create_draft_extend_backend(self):
        """Return the draft-extend (prefill) backend (`prefill_attention_backend`)."""
        # NOTE(review): "flashmla" has a decode factory but no entry here, so
        # with attention_backend="flashmla" and no prefill_attention_backend
        # override this raises ValueError; `init_cuda_graphs` still guards on
        # `self.draft_extend_attn_backend` being falsy — confirm intended.
        backend_map = {
            "flashinfer": self._create_flashinfer_prefill_backend,
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
        }

        return self._create_backend(
            "prefill_attention_backend",
            backend_map,
            "EAGLE is not supported in prefill attention backend {backend_type}",
        )

    def _create_flashinfer_decode_backend(self):
        """FlashInfer multi-step draft backend (MLA variant when use_mla_backend)."""
        if not global_server_args_dict["use_mla_backend"]:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMLAMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )

    def _create_triton_decode_backend(self):
        """Triton multi-step draft backend."""
        from sglang.srt.layers.attention.triton_backend import (
            TritonMultiStepDraftBackend,
        )

        return TritonMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_aiter_decode_backend(self):
        """AITER multi-step draft backend."""
        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

        return AiterMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_fa3_decode_backend(self):
        """FlashAttention (fa3) multi-step draft backend."""
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionMultiStepBackend,
        )

        return FlashAttentionMultiStepBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashmla_decode_backend(self):
        """FlashMLA multi-step draft backend."""
        from sglang.srt.layers.attention.flashmla_backend import (
            FlashMLAMultiStepDraftBackend,
        )

        return FlashMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mha_decode_backend(self):
        """TRT-LLM MHA multi-step draft backend."""
        from sglang.srt.layers.attention.trtllm_mha_backend import (
            TRTLLMHAAttnMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMHAAttnMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mla_decode_backend(self):
        """TRT-LLM MLA multi-step draft backend.

        Raises:
            ValueError: If the model does not use the MLA backend.
        """
        if not global_server_args_dict["use_mla_backend"]:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashinfer_prefill_backend(self):
        """FlashInfer draft-extend backend (MLA variant when use_mla_backend)."""
        if not global_server_args_dict["use_mla_backend"]:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_triton_prefill_backend(self):
        """Triton draft-extend backend."""
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_aiter_prefill_backend(self):
        """AITER draft-extend backend."""
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_fa3_prefill_backend(self):
        """FlashAttention (fa3) draft-extend backend."""
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mha_prefill_backend(self):
        """TRT-LLM MHA draft-extend backend."""
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mla_prefill_backend(self):
        """TRT-LLM MLA draft-extend backend.

        Raises:
            ValueError: If the model does not use the MLA backend.
        """
        if not global_server_args_dict["use_mla_backend"]:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

    def init_cuda_graphs(self):
        """Capture cuda graphs.

        Does nothing (runners stay None) when `server_args.disable_cuda_graph`
        is set. Otherwise captures the multi-step draft graph and, if a
        draft-extend backend exists, the draft-extend graph, logging elapsed
        time and GPU memory usage for each.
        """
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            tic = time.perf_counter()
            before_mem = get_available_gpu_memory(self.device, self.gpu_id)
            logger.info(
                f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
            )
            self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
                self
            )
            after_mem = get_available_gpu_memory(self.device, self.gpu_id)
            logger.info(
                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
            )

    @property
    def draft_model_runner(self):
        """The draft model's runner (this worker's own `model_runner`)."""
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
        """Run speculative decoding forward.

        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
        the final output batch have the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A 5-tuple:
              - logits_output: the final logit output of the target model;
              - next token ids (for decode: the verified/accepted token ids);
              - the batch id (used for overlap schedule);
              - the number of accepted tokens (0 for extend batches);
              - can_run_cuda_graph: whether the target verify forward could use
                a CUDA graph (always False for extend batches).
        """
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
            logits_output, next_token_ids, bid, seq_lens_cpu = (
                self.forward_target_extend(batch)
            )
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                )
            return logits_output, next_token_ids, bid, 0, False
        else:
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                self.verify(batch, spec_info)
            )

            with self.draft_tp_context(self.draft_model_runner.tp_group):
                # NOTE: We should use `check_forward_draft_extend_after_decode`
                # when DP attention is enabled, but it is slow. Skip it for now.
                if (
                    self.server_args.enable_dp_attention
                    or batch.spec_info.verified_id.shape[0] > 0
                ):
                    # decode is not finished
                    self.forward_draft_extend_after_decode(batch)

            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
                can_run_cuda_graph,
            )

    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch) -> bool:
        """Return whether a draft extend is needed after decode.

        Without DP attention this is just "this rank has verified ids left".
        With DP attention the local flag is all-reduced (default op: sum) over
        the TP CPU group, so the result is True if any rank needs it.
        """
        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
        if not self.server_args.enable_dp_attention:
            return local_need_forward

        global_need_forward = torch.tensor(
            [
                (local_need_forward),
            ],
            dtype=torch.int64,
        )
        torch.distributed.all_reduce(
            global_need_forward, group=get_tp_group().cpu_group
        )
        global_need_forward_cnt = global_need_forward[0].item()
        need_forward = global_need_forward_cnt > 0
        return need_forward

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            bid: The model batch ID. Used for overlap schedule.
            seq_lens_cpu: `model_worker_batch.seq_lens_cpu`, forwarded to the
                draft extend.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return (
            logits_output,
            next_token_ids,
            model_worker_batch.bid,
            model_worker_batch.seq_lens_cpu,
        )

    def _draft_preprocess_decode(self, batch: ScheduleBatch):
        """Prepare a non-idle batch for the multi-step draft.

        Accumulates penalties, allocates `topk * speculative_num_steps` KV
        slots per request, writes them into `req_to_token` via
        `assign_draft_cache_locs`, sets `batch.out_cache_loc`,
        `batch.seq_lens_sum` and `spec_info.positions`, and finally restores
        the allocator state captured at allocation time.
        """
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        # Layout of the out_cache_loc
        # [       topk 0         ] [       topk 1         ]
        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
            )
        else:
            if self.topk == 1:
                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                )
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO(lmzheng): The current implementation is still a fake support
                # for page size > 1. In the `assign_draft_cache_locs` below,
                # we directly move the indices instead of the real kv cache.
                # This only works when the kernel backend runs with page size = 1.
                # If the kernel backend runs with page size > 1, we need to
                # duplicate the real KV cache. The overhead of duplicating KV
                # cache seems okay because the draft KV cache only has one layer.
                # see a related copy operation in MHATokenToKVPool::move_kv_cache.

                (
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    self.num_new_pages_per_topk,
                    self.extend_lens,
                ) = get_last_loc_large_page_size_large_top_k(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                    self.topk,
                    self.page_size,
                )

                # TODO(lmzheng): remove this device sync
                extend_num_tokens = torch.sum(self.extend_lens).item()

            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            self.extend_lens,
            self.num_new_pages_per_topk,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
        )

        if self.page_size > 1 and self.topk > 1:
            # Remove padded slots
            out_cache_loc = out_cache_loc[
                : num_seqs * self.topk * self.speculative_num_steps
            ]

        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        batch.return_hidden_states = False
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
        # Roll the allocator back to its state before the allocation above.
        # NOTE(review): presumably the draft slots are only scratch space for
        # this step and must not stay allocated — confirm.
        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

    def _draft_preprocess_idle(self, batch: ScheduleBatch):
        """Replace `batch.spec_info` with an idle `EagleDraftInput` for IDLE batches."""
        batch.spec_info = EagleDraftInput.create_idle_input(
            device=self.device,
            hidden_size=self.model_config.hidden_size,
            dtype=self.model_config.dtype,
            topk=self.topk,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )

    def draft(self, batch: ScheduleBatch):
        """Run the multi-step draft and build the verification tree.

        Uses CUDA graph replay when the draft graph runner can handle the
        batch, otherwise the eager `draft_forward` loop.

        Returns:
            An `EagleVerifyInput` describing the draft tree (an idle one from
            `EagleVerifyInput.create_idle_input` for IDLE batches).
        """
        # Parse args
        if batch.forward_mode.is_idle():
            self._draft_preprocess_idle(batch)
        else:
            self._draft_preprocess_decode(batch)

        spec_info = batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)

        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        spec_info.num_tokens_per_batch = self.topk
        spec_info.num_tokens_for_logprob_per_batch = self.topk
        batch.return_hidden_states = False

        # Get forward batch
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                # Initialize attention backend
                self.draft_attn_backend.init_forward_metadata(forward_batch)
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        if batch.forward_mode.is_idle():
            return EagleVerifyInput.create_idle_input(
                self.topk,
                self.speculative_num_steps,
                self.speculative_num_draft_tokens,
            )

        (
            tree_mask,
            position,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.speculative_num_draft_tokens,
        )

        return EagleVerifyInput(
            draft_token=draft_tokens,
            custom_mask=tree_mask,
            positions=position,
            retrive_index=retrive_index,
            retrive_next_token=retrive_next_token,
            retrive_next_sibling=retrive_next_sibling,
            retrive_cum_len=None,
            spec_steps=self.speculative_num_steps,
            topk=self.topk,
            draft_token_num=self.server_args.speculative_num_draft_tokens,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=forward_batch.seq_lens_sum,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
        )

    def draft_forward(self, forward_batch: ForwardBatch):
        """Eager multi-step draft loop.

        Each step expands the tree with `select_top_k_tokens`. The draft model
        itself runs `speculative_num_steps - 1` times, because the last
        step's top-k already comes from the previous forward.

        Returns:
            (score_list, token_list, parents_list), each with one tensor per
            step taken from `select_top_k_tokens`' `tree_info`.
        """
        # Parse args
        spec_info = forward_batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            # Map hot-vocab indices back to full-vocab token ids.
            topk_index = self.hot_token_id[topk_index]

        # Regroup slots per step: (bs, topk, steps) -> (steps, bs * topk), so
        # row i holds the KV slots written by draft step i.
        out_cache_loc = out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
            self.speculative_num_steps, -1
        )

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            # This is a temporary fix for the case that the user is using standalone
            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
            # rope kernel needs cache_loc to be contiguous.
            if (
                self.server_args.speculative_algorithm == "STANDALONE"
                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
            ):
                out_cache_loc = out_cache_loc.contiguous()
            forward_batch.out_cache_loc = out_cache_loc[i]
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
        """Verify the draft tree with the target model.

        Runs the target forward in TARGET_VERIFY mode (or IDLE), applies the
        grammar bitmask when the batch has grammars, accepts tokens via
        `spec_info.verify`, and leaves the batch in DECODE (or IDLE) mode with
        `batch.spec_info` set to the next draft input.

        Returns:
            (logits_output restricted to accepted rows, EagleVerifyOutput,
            model_worker_batch, can_run_cuda_graph).
        """
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.return_hidden_states = False
        batch.forward_mode = (
            ForwardMode.TARGET_VERIFY
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )
        batch.spec_info = spec_info

        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=spec_info.seq_lens_cpu
        )
        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

        if batch.has_grammar:
            # Copy the tree structure to CPU before the forward so bitmask
            # generation below can run on CPU.
            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
            draft_tokens_cpu = spec_info.draft_token.view(
                spec_info.retrive_next_token.shape
            ).cpu()

        # Forward
        logits_output, _, can_run_cuda_graph = (
            self.target_worker.forward_batch_generation(
                model_worker_batch, skip_sample=True
            )
        )

        vocab_mask = None
        if batch.has_grammar:
            # Generate the logit mask for structured output.
            # Overlap the CPU operations for bitmask generation with the forward pass.
            vocab_mask = generate_token_bitmask(
                batch.reqs,
                spec_info,
                retrieve_next_token_cpu,
                retrieve_next_sibling_cpu,
                draft_tokens_cpu,
                batch.sampling_info.vocab_size,
            )

            if vocab_mask is not None:
                assert spec_info.grammar is not None
                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
                # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
                # and will be applied to produce wrong results
                batch.sampling_info.vocab_mask = None

        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
            vocab_mask,
        )

        # Post process based on verified outputs.
        # Pick indices that we care (accepted)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = (
            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
        )
        batch.spec_info = res.draft_input

        return logits_output, res, model_worker_batch, can_run_cuda_graph

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
        """Compute log-probs for accepted tokens and append them to each request.

        Log-probs use temperature-scaled logits unless RETURN_ORIGINAL_LOGPROB
        is set. Top-k and token-id log-probs are filled in on `logits_output`
        when any request asks for them; each request with `return_logprob`
        gets `accept_length + 1` entries appended.
        """
        # Extract args
        # NOTE(review): the `logits_output` argument is immediately rebound to
        # `res.logits_output`; this relies on both referring to the object
        # `verify` just sliced to accepted rows — confirm.
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        accepted_indices = res.accepted_indices
        assert len(accepted_indices) == len(logits_output.next_token_logits)
        temperatures = batch.sampling_info.temperatures
        num_draft_tokens = batch.spec_info.draft_token_num
        # acceptance indices are the indices in a "flattened" batch.
        # dividing it to num_draft_tokens will yield the actual batch index.
        temperatures = temperatures[accepted_indices // num_draft_tokens]
        if RETURN_ORIGINAL_LOGPROB:
            logprobs = torch.nn.functional.log_softmax(
                logits_output.next_token_logits, dim=-1
            )
        else:
            logprobs = torch.nn.functional.log_softmax(
                logits_output.next_token_logits / temperatures, dim=-1
            )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(
                logprobs,
                top_logprobs_nums_repeat_interleaved,
            )

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(
                logprobs,
                token_ids_logprobs_repeat_interleaved,
            )

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                # Advance for every accepted token, whether or not this
                # request returns logprobs.
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward
            next_token_ids: Next token ids generated from the target forward.
            seq_lens_cpu: Optional CPU copy of sequence lengths, passed to
                `get_model_worker_batch` as `seq_lens_cpu_cache`.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
            num_tokens_per_batch=1,
            num_tokens_for_logprob_per_batch=1,
        )
        batch.return_hidden_states = False
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=seq_lens_cpu
        )
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)
        # Drop spec_info rows for requests that already finished so it stays
        # aligned with the requests that will continue to decode.
        has_finished, unfinished_req_index = False, []
        for i, req in enumerate(batch.reqs):
            if req.finished():
                has_finished = True
            else:
                unfinished_req_index.append(i)
        if has_finished:
            unfinished_index_device = torch.tensor(
                unfinished_req_index,
                dtype=torch.int64,
                device=batch.spec_info.topk_p.device,
            )
            batch.spec_info.filter_batch(
                unfinished_index_device, has_been_filtered=False
            )

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        """Run the draft model over the tokens accepted in the last verify.

        Refreshes `topk_p`, `topk_index` and `hidden_states` on the spec info
        for the next draft round. Afterwards restores `forward_mode` (DECODE or
        IDLE), `seq_lens`, `req_pool_indices`, `spec_info.accept_length` and
        `return_logprob` on the batch object used for the run.
        """
        assert isinstance(batch.spec_info, EagleDraftInput)
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        input_is_idle = batch.forward_mode.is_idle()

        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
            # Nothing left to extend on this rank: run an idle forward on a copy
            # of the batch instead (presumably so this rank still takes part in
            # collective ops, e.g. under DP attention — confirm).
            batch = batch.copy()
            batch.prepare_for_idle()
            # NOTE(review): EAGLE3 draft input here is 3x hidden_size
            # (presumably concatenated target hidden states — confirm).
            hidden_size = (
                self.model_config.hidden_size * 3
                if self.speculative_algorithm.is_eagle3()
                else self.model_config.hidden_size
            )
            batch.spec_info = EagleDraftInput.create_idle_input(
                device=self.device,
                hidden_size=hidden_size,
                dtype=self.model_config.dtype,
                topk=self.topk,
                capture_hidden_mode=CaptureHiddenMode.LAST,
            )

        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
        batch.spec_info.num_tokens_for_logprob_per_batch = 1
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.forward_mode = (
            ForwardMode.DRAFT_EXTEND
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )

        batch.return_hidden_states = False
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        if forward_batch.seq_lens_cpu is not None:
            forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
        else:
            forward_batch.seq_lens_sum = batch.seq_lens.sum().item()

        # Run
        can_cuda_graph = (
            self.cuda_graph_runner_for_draft_extend
            and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
        )
        if can_cuda_graph:
            logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                forward_batch
            )
            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
                logits_output.topk_p,
                logits_output.topk_index,
            )
            forward_batch.spec_info.hidden_states = logits_output.hidden_states
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                self.draft_model_runner.attn_backend.init_forward_metadata(
                    forward_batch
                )
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self.capture_for_decode(logits_output, forward_batch.spec_info)

        self._detect_nan_if_needed(logits_output)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
        batch.forward_mode = (
            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
        )
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        """Store softmax top-k (probs, indices) and hidden states on `draft_input`."""
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        """Raise if NaN detection is enabled and `next_token_logits` contains NaN.

        Raises:
            ValueError: When `enable_nan_detection` is set and a NaN is found.
        """
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
    """Load the hot-token id map used to shrink the draft lm_head.

    If `token_map_path` does not exist locally, its directory part is treated
    as a Hugging Face Hub repo id and downloaded (skipping `*.bin` and
    `*.safetensors` files), and the file name is resolved inside that snapshot.

    Returns:
        An int64 tensor of token ids.
    """
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    # NOTE(review): if the file stores a tensor, `torch.tensor(tensor)` emits a
    # copy-construct UserWarning; `torch.as_tensor` would avoid it — confirm
    # what the map files contain.
    return torch.tensor(hot_token_id, dtype=torch.int64)


@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
):
    """Paged-allocation layout for a draft step with topk == 1.

    Returns:
        (prefix_lens, seq_lens, last_loc): `prefix_lens` is the input
        `seq_lens`, the new `seq_lens` adds `speculative_num_steps`, and
        `last_loc` comes from `get_last_loc` at the prefix lengths.
    """
    prefix_lens = seq_lens
    seq_lens = prefix_lens + speculative_num_steps
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )
    return prefix_lens, seq_lens, last_loc


# Disable torch.compile for this function because it will be
# even slower.
# @torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
    topk: int,
    page_size: int,
):
    """Paged-allocation layout for a draft step with topk > 1.

    Each of the ``topk`` branches gets enough whole pages to hold the tokens
    already in the request's partially filled last page plus
    ``speculative_num_steps`` new tokens. The branches are laid out one after
    another, starting from the page-aligned prefix.

    Args:
        req_to_token: Request-to-token-slot table, forwarded to ``get_last_loc``.
        req_pool_indices: Rows of ``req_to_token`` for the batch's requests.
        seq_lens: Current sequence lengths; used as the prefix lengths.
        speculative_num_steps: Number of draft tokens to reserve per branch.
        topk: Number of draft branches per request.
        page_size: KV cache page size.

    Returns:
        ``(prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens)``
        where ``seq_lens`` is the new padded length, ``num_new_pages_per_topk``
        the pages each branch needs, and ``extend_lens`` equals
        ``seq_lens - prefix_lens``.
    """
    prefix_lens = seq_lens
    last_loc = get_last_loc(req_to_token, req_pool_indices, prefix_lens)

    # Tokens already sitting in the partially filled last page.
    tokens_in_last_page = prefix_lens % page_size
    # Ceil-divide (partial page + new draft tokens) into whole pages.
    pages_per_branch = (
        tokens_in_last_page + speculative_num_steps + page_size - 1
    ) // page_size
    # Prefix rounded down to a page boundary (== prefix // page_size * page_size).
    aligned_prefix = prefix_lens - tokens_in_last_page
    padded_seq_lens = aligned_prefix + pages_per_branch * (page_size * topk)

    return (
        prefix_lens,
        padded_seq_lens,
        last_loc,
        pages_per_branch,
        padded_seq_lens - prefix_lens,
    )