# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py

import torch
import torch.nn.functional as F

from vllm_omni.platforms import current_omni_platform

# Flash Attention function detection with fallback chain
flash_attn_func = None
flash_attn_varlen_func = None

if current_omni_platform.is_rocm():
    # ROCm: try AITER first
    try:
        from vllm._aiter_ops import is_aiter_found_and_supported

        if is_aiter_found_and_supported():
            from aiter import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        pass
else:
    # CUDA: try the FA3 -> FA2 fallback chain
    # Try FA3 from the fa3-fwd PyPI package
    try:
        from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        pass

    # Fallback: try FA3 from a flash-attention source build
    if flash_attn_func is None:
        try:
            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
        except (ImportError, ModuleNotFoundError):
            pass

    # Fallback: try FA2 from the flash-attn package (try multiple import paths)
    if flash_attn_func is None:
        try:
            from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
        except (ImportError, ModuleNotFoundError):
            pass

    if flash_attn_func is None:
        try:
            from flash_attn.flash_attn_interface import (  # noqa: F401
                flash_attn_func,
                flash_attn_varlen_func,
            )
        except (ImportError, ModuleNotFoundError):
            pass

# If no FA backend is available, the SDPA backend will be selected at the platform level;
# flash_attn_func and flash_attn_varlen_func will remain None.
HAS_FLASH_ATTN = flash_attn_func is not None


def _index_first_axis(tensor, indices):
    """
    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
    after flattening the first two dimensions of the tensor. This is functionally equivalent to
    FA2's `index_first_axis` and replaces the need to import it.
    """
    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
    # two dimensions to get (total_tokens, ...) before indexing.
    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
    return reshaped_tensor[indices]
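
# Illustrative example (not part of the upstream module; shapes are made up): for a tensor of
# shape (2, 3, 4) and indices = torch.tensor([0, 2, 5]), `_index_first_axis` first flattens the
# tensor to shape (6, 4) and then gathers rows 0, 2 and 5, returning a (3, 4) tensor -- the same
# result as `tensor.reshape(-1, 4)[indices]`.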


def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    unpad_input function for flash attention variants that do not ship it within their own
    package, e.g. FA3.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        _index_first_axis(hidden_states, indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def _pad_input(hidden_states, indices, batch, seqlen):
    """
    pad_input function for flash attention variants that do not ship it within their own
    package, e.g. FA3.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.

    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)


def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is
            (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
    # this might cause a graph break
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
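
# Illustrative round trip (a minimal sketch, not part of the upstream module; shapes and values
# are made up): unpadding with `_unpad_input` and re-padding with `_pad_input` recovers the
# original tensor, with padded positions zeroed out:
#
#     hidden_states = torch.randn(2, 4, 8)                        # (batch, seqlen, dim)
#     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#     unpadded, indices, cu_seqlens, max_seqlen, seqused = _unpad_input(hidden_states, attention_mask)
#     # unpadded: (5, 8), cu_seqlens: tensor([0, 3, 5]), max_seqlen: 3, seqused: tensor([3, 2])
#     repadded = _pad_input(unpadded, indices, batch=2, seqlen=4)  # (2, 4, 8), zeros where mask == 0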


def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    unpad_input_func,
):
    """
    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to
    different batches.

    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of
    the same intermediary tensors for query, key, value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.
        unpad_input_func:
            The function to use for unpadding the input tensors.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    if torch.compiler.is_compiling():
        # Allow the PyTorch compiler to capture operations that return scalar values (like `.item()`)
        torch._dynamo.config.capture_scalar_outputs = True

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

    # With static caches, the k/v states may be larger than the mask ->
    # we need to slice them to avoid generating garbage
    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores
    if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
        key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]

    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = _index_first_axis(key_layer, indices_k)
    value_layer = _index_first_axis(value_layer, indices_k)
    if query_length == kv_seq_len:
        query_layer = _index_first_axis(query_layer, indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )


def _is_packed_sequence(position_ids, batch_size):
    """
    Check whether the position ids indicate packed sequences, i.e.:
        1. Position ids exist
        2. Only flattened sequences (batch_size == 1) are supported
        3. Compile-friendly version of `not (torch.diff(position_ids, dim=-1) >= 0).all()`,
           i.e. the row contains multiple increasing subsequences
    """
    if position_ids is None:
        return False

    increasing_position_sequences = (
        torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
    )
    return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
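

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the upstream module). It exercises the
    # unpad/pad helpers on an arbitrary right-padded batch and the packed-sequence check on a
    # flattened batch of two sequences; all shapes and values below are made up for the example.
    torch.manual_seed(0)
    batch_size, padded_len, num_heads, head_dim = 2, 5, 4, 8

    attention_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 0, 0, 0]], dtype=torch.int32)
    query = torch.randn(batch_size, padded_len, num_heads, head_dim)
    key = torch.randn(batch_size, padded_len, num_heads, head_dim)
    value = torch.randn(batch_size, padded_len, num_heads, head_dim)

    # Unpad q/k/v into a ragged, token-major layout using the local unpad helper.
    q, k, v, indices_q, (cu_q, cu_k), (max_q, max_k) = _upad_input(
        query, key, value, attention_mask, query_length=padded_len, unpad_input_func=_unpad_input
    )
    assert q.shape == (int(attention_mask.sum()), num_heads, head_dim)
    assert cu_q.tolist() == [0, 4, 6] and max_q == 4

    # Re-padding the unpadded keys restores the original values at the valid positions.
    repadded_k = _pad_input(k, indices_q, batch_size, padded_len)
    assert torch.equal(repadded_k[attention_mask.bool()], k)

    # Packed-sequence detection: two sequences of lengths 3 and 2 flattened into one row.
    packed_position_ids = torch.tensor([[0, 1, 2, 0, 1]])
    assert bool(_is_packed_sequence(packed_position_ids, batch_size=1))
    assert not bool(_is_packed_sequence(torch.arange(5)[None, :], batch_size=1))

    print("flash-attention padding utilities smoke test passed")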