kv_cache_view.py 1.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import torch

from vllm.platforms import current_platform


def paged_k_cache_view_for_triton_gather(
    *,
    key_cache: torch.Tensor,
    block_size: int,
) -> torch.Tensor:
    """Return a KV-cache key view in [num_blocks, H, block_size, D] layout.

    Supports both:
    - [num_blocks, block_size, H, D] (typical CUDA FlashAttention v1 layout)
    - [num_blocks, H, block_size, D] (ROCm FlashAttention v1, or external
      connectors that expose the cache in HND shape)

    The layout is inferred from which of dims 1 and 2 equals ``block_size``.
    When that is inconclusive -- both dims match (H == block_size) or neither
    does -- the platform default is used instead: ROCm caches are returned
    unchanged, all other platforms are permuted.

    Args:
        key_cache: 4D key cache tensor in one of the layouts above.
        block_size: Number of token slots per cache block.

    Returns:
        ``key_cache`` itself when it is treated as already being in
        [num_blocks, H, block_size, D] layout, otherwise
        ``key_cache.permute(0, 2, 1, 3)``. No data is copied.

    Raises:
        ValueError: If ``key_cache`` is not 4-dimensional.
    """
    if key_cache.ndim != 4:
        raise ValueError("key_cache must be a 4D tensor.")

    tokens_per_block = int(block_size)
    dim1_matches = int(key_cache.shape[1]) == tokens_per_block
    dim2_matches = int(key_cache.shape[2]) == tokens_per_block

    # Only trust the shape when exactly one candidate dim matches. If both
    # match (H == block_size) the shape cannot tell the layouts apart, and
    # blindly permuting would swap the head/token axes of an HND cache.
    if dim1_matches != dim2_matches:
        if dim1_matches:
            # [B, T, H, D] -> [B, H, T, D]
            return key_cache.permute(0, 2, 1, 3)
        # Already in [B, H, T, D] (ROCm / HND-shaped external caches).
        return key_cache

    # Ambiguous or unrecognized shape: preserve the historical
    # platform-based behavior.
    if current_platform.is_rocm():
        return key_cache
    return key_cache.permute(0, 2, 1, 3)