# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch

from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                            AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

logger = init_logger(__name__)

PAD_SLOT_ID = -1

# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256


def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    return (isinstance(block_tables, dict)
            and all(value is None for value in block_tables.values()))


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """
    Compute the start index of slot mapping.

    Returns `max(0, query_len - sliding_window)` for a prompt when a sliding
    window is configured (i.e. `sliding_window` is not None), and 0 otherwise.
    `context_len` is accepted for interface compatibility but is not used.
    """
    start_idx = 0
    if is_prompt and sliding_window is not None:
        start_idx = max(0, query_len - sliding_window)
    return start_idx


def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    """
    Append the slot of every token index in [range_start, range_end) to
    `slot_mapping` using a plain Python loop.

    The slot of token `i` is
    `block_table[i // block_size] * block_size + i % block_size`.
    """
    for i in range(range_start, range_end):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)


def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    """
    Vectorized equivalent of `_compute_slot_mapping_python`.

    Extends `slot_mapping` with the slots of all token indices in
    [range_start, range_end). The appended elements are numpy integer
    scalars rather than Python ints.
    """
    block_table_array = np.array(block_table)
    idx = np.arange(range_start, range_end)
    block_offset = idx % block_size
    # Reuse `idx` in place: it now holds the per-token block-table index.
    idx //= block_size
    # Integer-array indexing yields a new array, so the in-place arithmetic
    # below does not modify `block_table_array`.
    seq_slot_mapping_array = block_table_array[idx]
    seq_slot_mapping_array *= block_size
    seq_slot_mapping_array += block_offset
    slot_mapping.extend(seq_slot_mapping_array)


def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """
    Compute slot mapping.

    Extends `slot_mapping` in place; nothing is returned.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    padding_mask_len = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

    range_start = max(start_idx, context_len)
    range_end = seq_len
    numel = range_end - range_start
    block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
    if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
                                     range_end, block_size)
    else:
        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
                                    range_end, block_size)


TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')


class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """
    Accumulates per-sequence attention inputs from an input builder and
    assembles them into an instance of `_metadata_cls`.
    """

    # Concrete metadata class instantiated by `build()`.
    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset the per-batch accumulators filled by `_add_seq_group`."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
        """
        Fold every sequence of `inter_data` into the accumulators: token
        counts, sequence/context lengths, block tables and slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks inside the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            An instance of `_metadata_cls` populated from the accumulated
            state and the given lengths.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        # Exclusive prefix sums with a leading 0.
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with one empty block table per padding entry. The previous
            # `[] * cuda_graph_pad_size` evaluated to `[]` and added nothing.
            # Empty tables are skipped when filling below, so the produced
            # tensor is unaffected.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class CommonAttentionState(AttentionState):
    """
    Attention state shared by backends: owns the buffers used while
    capturing CUDA graphs and copies metadata into graph input buffers.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """
        Context manager that allocates the capture buffers on
        `self.runner.device` for the duration of the `with` body.

        The capturing flag is cleared and the buffers are released on exit,
        including when setup or the `with` body raises.
        """
        self._is_graph_capturing = True
        try:
            self._graph_slot_mapping = torch.full((max_batch_size, ),
                                                  PAD_SLOT_ID,
                                                  dtype=torch.long,
                                                  device=self.runner.device)
            self._graph_seq_lens = torch.ones(max_batch_size,
                                              dtype=torch.int32,
                                              device=self.runner.device)
            self._graph_block_tables = torch.from_numpy(
                self.runner.graph_block_tables).to(device=self.runner.device)
            yield
        finally:
            self._is_graph_capturing = False
            # A buffer may be missing if its allocation was what failed.
            for buffer_name in ("_graph_slot_mapping", "_graph_seq_lens",
                                "_graph_block_tables"):
                if hasattr(self, buffer_name):
                    delattr(self, buffer_name)

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        """Return a fresh state of the same class bound to the same runner."""
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def _assert_enc_dec_backend_supported(self) -> None:
        """Assert that the runner's backend is XFORMERS or FLASH_ATTN."""
        # The encoder decoder model works only with XFormers and
        # Flash Attention backend. Assert the same.
        backend_name = self.runner.attn_backend.get_name()
        assert backend_name in ["XFORMERS", "FLASH_ATTN"], (
            f"Expected attn_backend name to be either 'XFORMERS' or "
            f"'FLASH_ATTN', but got '{backend_name}'")

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """
        Build decode-only metadata for capturing a graph of `batch_size`,
        backed by slices of the capture buffers.
        """
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_model_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """
        Collect the tensors of `attn_metadata` that serve as graph inputs,
        keyed by name.
        """
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._add_additional_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        """
        Copy the decode sequence lengths and block tables of `attn_metadata`
        into the matching `input_buffers` entries (non-blocking copies).
        """
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        """No-op in this implementation."""
        return

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the
        AttentionMetadata object.
        """
        # During decode phase the cross_slot_mapping will be empty. Hence set
        # an empty tensor for CUDA Graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_model_len
        attn_metadata.num_encoder_tokens = 0

    def _add_additional_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        Copies the encoder sequence lengths, cross-slot mapping and
        cross-block tables from `attn_metadata.decode_metadata` into the
        corresponding entries of `input_buffers` (non-blocking copies).
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)


def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    All attention metadata required for encoder attention is set.
    '''
    return ((attn_metadata.encoder_seq_lens is not None)
            and (attn_metadata.encoder_seq_lens_tensor is not None)
            and (attn_metadata.max_encoder_seq_len is not None))


def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    All attention metadata required for enc/dec cross-attention is set.

    Superset of encoder attention required metadata.
    '''
    # Reads the `is_all_encoder_attn_metadata_set` attribute of the metadata
    # object itself, not the module-level function of the same name.
    return (attn_metadata.is_all_encoder_attn_metadata_set
            and (attn_metadata.cross_slot_mapping is not None)
            and (attn_metadata.cross_block_tables is not None))


def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError: if attn_type is none of the three types above
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Calculate the number of prefill and decode tokens for query, key/value
    based on the attention metadata and the specified attention type.

    Args:
        attn_metadata (AttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    num_prefill_query_tokens = 0
    num_decode_query_tokens = 0
    num_prefill_kv_tokens = 0
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during prefill phase.
        # The same input serves as both query and key.
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_encoder_tokens
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = 0
    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        # The key is the encoder/cross-attention.
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens
    else:  # attn_type == AttentionType.DECODER or
        # attn_type == AttentionType.ENCODER_ONLY
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens

    return (num_prefill_query_tokens, num_prefill_kv_tokens,
            num_decode_query_tokens)


@dataclass
class MLADims:
    """Dimensions read from a model's HF text config by `get_mla_dims`."""
    # Optional: read with a default of None when the config lacks it.
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


def get_mla_dims(model_config: ModelConfig) -> MLADims:
    """
    Build an `MLADims` from `model_config.hf_text_config`.

    `q_lora_rank` falls back to None when absent; the remaining attributes
    are read directly, so a missing one raises AttributeError.
    """
    hf_text_config = model_config.hf_text_config

    return MLADims(
        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
        kv_lora_rank=hf_text_config.kv_lora_rank,
        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
        v_head_dim=hf_text_config.v_head_dim,
    )