# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
from typing_extensions import override

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import triton
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.spec_decode.llm_base_proposer import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel

logger = init_logger(__name__)


class DFlashProposer(SpecDecodeBaseProposer):
    """Draft-token proposer for the ``"dflash"`` speculative decoding method.

    Target-model hidden states are treated as *context*: they are only turned
    into K/V entries through ``model.precompute_and_store_context_kv``. The
    draft model's forward pass itself runs over ``1 + num_speculative_tokens``
    query tokens per request (the next token plus the mask tokens).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Set up the base proposer and allocate DFlash-specific buffers.

        Args:
            vllm_config: Engine configuration. Its ``speculative_config`` must
                be set and must select the ``"dflash"`` method.
            device: Device on which every buffer below is allocated.
            runner: Forwarded unchanged to the base-class constructor.
        """
        spec_config = vllm_config.speculative_config
        assert spec_config is not None
        assert spec_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

        def _int64_zeros(length: int) -> torch.Tensor:
            # All position / slot-mapping buffers share dtype and device.
            return torch.zeros(length, dtype=torch.int64, device=device)

        # Queries are just next_token_ids plus the mask tokens; every other
        # token only contributes K/V as context.
        queries_per_req = 1 + self.num_speculative_tokens
        self.max_query_tokens = self.max_batch_size * queries_per_req
        # Upper bound on positions: context states and query states together.
        self.max_positions = self.max_num_tokens + self.max_query_tokens

        # Context data lives in its own buffers so that the query buffers keep
        # stable addresses for CUDA graphs.
        self._context_slot_mapping_buffer = _int64_zeros(self.max_num_tokens)
        self._slot_mapping_buffer = _int64_zeros(self.max_query_tokens)
        self._context_positions_buffer = _int64_zeros(self.max_num_tokens)
        self.positions = _int64_zeros(self.max_query_tokens)

        self.arange = torch.arange(
            self.max_positions + 1,
            dtype=torch.int32,
            device=device,
        )

        # The mask token is embedded via the input embeddings in DFlash, so no
        # dedicated hidden-state tensor is kept for parallel drafting.
        self.parallel_drafting_hidden_state_tensor = None

    @override
    def _raise_if_multimodal(self):
        """Deliberately do nothing so that multimodal inputs are not rejected.

        The check is disabled because DFlash supports Qwen3.5 models. Note
        that support for multimodal inputs has not been tested.
        """
        return None

    @override
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill the query/context buffers and build query-only attention metadata.

        DFlash uses cross-attention: context K/V come from the target hidden
        states, while Q comes from the query embeddings (bonus + mask tokens).

        Args:
            target_token_ids: Only its leading dimension is read; it gives the
                number of context tokens.
            next_token_ids: Passed to the expansion kernel.
            target_positions: Passed to the expansion kernel.
            target_hidden_states: Stashed on ``self`` for
                ``build_model_inputs_first_pass``.
            token_indices_to_sample: Ignored; a freshly allocated tensor is
                filled by the kernel and returned instead.
            cad: Attention metadata describing the incoming (context) batch.
            num_rejected_tokens_gpu: Optional per-request rejected-token
                counts. When given, they are passed to the kernel and
                subtracted from ``cad.seq_lens``.

        Returns:
            ``(num_query_total, token_indices_to_sample, new_cad)`` where
            ``new_cad`` is non-causal metadata covering only the query tokens.
        """
        n_reqs = cad.batch_size()
        ctx_len = target_token_ids.shape[0]
        queries_per_req = 1 + self.num_speculative_tokens
        total_queries = n_reqs * queries_per_req

        # Remembered for build_model_inputs_first_pass.
        self._dflash_num_context = ctx_len
        # Kept by reference rather than copied into a buffer: the context
        # preprocessing does not run inside a CUDA graph.
        self._dflash_hidden_states = target_hidden_states

        sample_indices = torch.empty(
            n_reqs * self.num_speculative_tokens,
            dtype=torch.int32,
            device=self.device,
        )

        # One fused triton launch writes input_ids, positions, slot_mapping
        # and the indices to sample. The grid is (request, token block).
        tokens_per_req_cap = cad.max_query_len + queries_per_req
        tile = min(256, triton.next_power_of_2(tokens_per_req_cap))
        launch_grid = (n_reqs, triton.cdiv(tokens_per_req_cap, tile))

        has_num_rejected = num_rejected_tokens_gpu is not None
        # The kernel receives a placeholder 0 when no rejection counts exist.
        rejected_arg = num_rejected_tokens_gpu if has_num_rejected else 0

        copy_and_expand_dflash_inputs_kernel[launch_grid](
            # Inputs
            next_token_ids_ptr=next_token_ids,
            target_positions_ptr=target_positions,
            # Outputs
            out_input_ids_ptr=self.input_ids,
            out_context_positions_ptr=self._context_positions_buffer,
            out_query_positions_ptr=self.positions,
            out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
            out_query_slot_mapping_ptr=self._slot_mapping_buffer,
            out_token_indices_ptr=sample_indices,
            # Block table
            block_table_ptr=cad.block_table_tensor,
            block_table_stride=cad.block_table_tensor.stride(0),
            # Metadata
            query_start_loc_ptr=cad.query_start_loc,
            num_rejected_tokens_ptr=rejected_arg,
            # Scalars
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            block_size=self.block_size,
            num_query_per_req=queries_per_req,
            num_speculative_tokens=self.num_speculative_tokens,
            total_input_tokens=ctx_len,
            BLOCK_SIZE=tile,
            HAS_NUM_REJECTED=has_num_rejected,
        )

        query_slot_mapping = self._slot_mapping_buffer[:total_queries]
        # Every request owns exactly queries_per_req consecutive query tokens.
        query_start_loc = self.arange[: n_reqs + 1] * queries_per_req

        # In padded mode cad.seq_lens still counts rejected tokens; removing
        # them restricts attention to the valid prefix of the context states.
        if has_num_rejected:
            context_seq_lens = cad.seq_lens - num_rejected_tokens_gpu
        else:
            context_seq_lens = cad.seq_lens
        seq_lens_with_queries = context_seq_lens + queries_per_req

        host_arange = torch.from_numpy(self.token_arange_np[: n_reqs + 1])
        query_start_loc_cpu = host_arange.clone() * queries_per_req

        new_cad = CommonAttentionMetadata(
            query_start_loc=query_start_loc,
            seq_lens=seq_lens_with_queries,
            query_start_loc_cpu=query_start_loc_cpu,
            _seq_lens_cpu=None,
            _num_computed_tokens_cpu=None,
            num_reqs=cad.num_reqs,
            num_actual_tokens=total_queries,
            max_query_len=queries_per_req,
            max_seq_len=cad.max_seq_len + queries_per_req,
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=query_slot_mapping,
            causal=False,  # DFlash requires non-causal attention
        )

        return total_queries, sample_indices, new_cad

    @override
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run one dummy context precomputation plus one dummy forward pass.

        How this differs from the default dummy_run:
        - Parallel drafting means there is only a single forward pass.
        - Context states act as unpadded metadata for DFlash, so the hidden
          states are sliced with the unpadded ``num_tokens`` rather than
          ``num_input_tokens``.
        - ``max_query_tokens`` is fairly small because only spec tokens are
          seen as queries.
        - Multimodal inputs are not currently supported.

        Args:
            num_tokens: Number of context tokens to simulate; the query count
                is this value capped at ``self.max_query_tokens``.
            use_cudagraphs: Forwarded to
                ``_determine_batch_execution_and_padding``.
            is_graph_capturing: Not used by this override.
            slot_mappings: Optional slot mappings keyed by layer name.
        """
        query_count = min(num_tokens, self.max_query_tokens)
        padding_plan = self._determine_batch_execution_and_padding(
            query_count, use_cudagraphs=use_cudagraphs
        )
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            padding_plan
        )

        # The slot mapping is sized to num_input_tokens (queries only) so it
        # matches the K/V tensor size produced by the model forward. Context
        # KVs are inserted separately and never flow through the model.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            forward_slot_mapping = self._get_slot_mapping(num_input_tokens)
        elif slot_mappings:
            forward_slot_mapping = slot_mappings
        else:
            forward_slot_mapping = {}

        # Context and query positions have separate buffers, so slicing is
        # enough here.
        dummy_context_positions = self._context_positions_buffer[:num_tokens]
        # Real runs hand context states straight to the precomputation (no
        # CUDA graph is involved there); the dummy run borrows the
        # hidden-state buffer instead.
        dummy_context_states = self.hidden_states[:num_tokens]

        # Exercise the KV projection (GEMM + norms + RoPE) for memory profiling.
        self.model.precompute_and_store_context_kv(
            dummy_context_states, dummy_context_positions
        )

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=forward_slot_mapping,
        ):
            self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            )

    @override
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Store the context K/V in the cache and return the model kwargs.

        Relies on ``_dflash_num_context`` and ``_dflash_hidden_states`` having
        been set by ``set_inputs_first_pass``.

        Args:
            num_tokens: Not used by this override.
            num_input_tokens: Number of query tokens to slice from the input
                buffers; also returned unchanged.
            mm_embed_inputs: Not used by this override.

        Returns:
            ``(model_kwargs, num_input_tokens)``.
        """
        # The kernel already wrote context and query positions/slots into
        # their own buffers, so only slicing is needed here.
        ctx_len = self._dflash_num_context

        # Insert the context KVs directly into the cache before the forward.
        self.model.precompute_and_store_context_kv(
            self._dflash_hidden_states,  # already [num_context, hidden_size]
            self._context_positions_buffer[:ctx_len],
            self._context_slot_mapping_buffer[:ctx_len],
        )

        model_kwargs = {
            "input_ids": self.input_ids[:num_input_tokens],
            "positions": self._get_positions(num_input_tokens),
            "inputs_embeds": None,
        }
        return model_kwargs, num_input_tokens

    @override
    def build_per_group_and_layer_attn_metadata(
        self, cad: CommonAttentionMetadata, draft_index: int = 0
    ) -> tuple[list[object], dict[str, object]]:
        """Build attention metadata via the base class and verify it is non-causal.

        Args:
            cad: Common attention metadata forwarded to the base class.
            draft_index: Forwarded to the base class.

        Returns:
            The ``(per_group, per_layer)`` pair produced by the base class.

        Raises:
            AssertionError: If any per-layer metadata object does not expose
                ``causal`` set to ``False``.
        """
        built = super().build_per_group_and_layer_attn_metadata(cad, draft_index)
        per_group, per_layer = built

        # A missing attribute (None) counts as a failure, exactly like True.
        for layer_name, attn_metadata in per_layer.items():
            assert getattr(attn_metadata, "causal", None) is False, (
                f"Attention metadata for layer {layer_name} does not have"
                " non-causal support, which is required for DFlash."
                " Consider using a different attention backend, such as FlashAttention."
            )

        return per_group, per_layer

    @override
    def _get_eagle3_use_aux_hidden_state_from_config(self):
        """Return the ``use_aux_hidden_state`` setting, defaulting to ``True``.

        The value is looked up in the ``dflash_config`` attribute of the draft
        model's HF config. ``True`` is returned when that attribute is absent
        or ``None``, or when the key is missing from it.
        """
        hf_config = self.draft_model_config.hf_config
        dflash_config = getattr(hf_config, "dflash_config", None)
        if dflash_config is None:
            return True
        return dflash_config.get("use_aux_hidden_state", True)