# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations import torch from vllm.platforms import current_platform def paged_k_cache_view_for_triton_gather( *, key_cache: torch.Tensor, block_size: int, ) -> torch.Tensor: """Return a KV-cache key view in [num_blocks, H, block_size, D] layout. Supports both: - [num_blocks, block_size, H, D] (typical CUDA FlashAttention v1 layout) - [num_blocks, H, block_size, D] (ROCm FlashAttention v1, or external connectors that expose the cache in HND shape) """ if key_cache.ndim != 4: raise ValueError("key_cache must be a 4D tensor.") # Common case: [B, T, H, D] -> [B, H, T, D] if int(key_cache.shape[1]) == int(block_size): return key_cache.permute(0, 2, 1, 3) # Already in [B, H, T, D] (ROCm / HND-shaped external caches). if int(key_cache.shape[2]) == int(block_size): return key_cache # Fallback: preserve historical behavior. if current_platform.is_rocm(): return key_cache return key_cache.permute(0, 2, 1, 3)