layout_bridge.py 5.9 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Bridge vLLM paged KV layout to compactor Triton kernels.

vLLM FlashAttention KV cache is shaped
  [num_blocks, block_size, num_kv_heads, head_dim].
Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table
  global_page_table[batch, kv_head, logical_page] -> physical_page_id
where each physical page holds ``block_size`` consecutive rows belonging to that
KV head only.

When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows:
  row_index = physical_block_id * block_size + offset_in_block.

When ``num_kv_heads > 1``, we permute to head-major
``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to
``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies
a disjoint row range in the flat buffer. The page table is built so each
logical compression page maps to ``global_row // PAGE_SIZE`` in that layout
(see ``build_page_table_head_major``).
"""

from __future__ import annotations

import torch


def _cdiv(n: int, d: int) -> int:
    return (n + d - 1) // d


def flatten_kv_cache_head_major(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Flatten ``[nb, bs, H, D]`` K/V caches into head-major ``[H*nb*bs, D]`` buffers.

    Rows ``h*nb*bs`` through ``(h+1)*nb*bs - 1`` of each result belong to KV
    head ``h``, ordered by (block, offset-in-block). The permute followed by
    ``contiguous()`` usually materializes a new tensor (always when the head
    axis is not trivially placed), so do not rely on writes to the returned
    buffers reaching the caches; ``write_head_major_flat_to_interleaved``
    copies them back.

    Raises:
        ValueError: if ``key_cache`` and ``value_cache`` have different shapes.
    """
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    total_rows = num_heads * num_blocks * block_size

    def _to_head_major(cache: torch.Tensor) -> torch.Tensor:
        # [nb, bs, H, D] -> [H, nb, bs, D]; densify so the reshape is a pure view.
        dense = cache.permute(2, 0, 1, 3).contiguous()
        return dense.reshape(total_rows, head_dim)

    return _to_head_major(key_cache), _to_head_major(value_cache)


def write_head_major_flat_to_interleaved(
    k_flat: torch.Tensor,
    v_flat: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> None:
    """Copy head-major ``[H*nb*bs, D]`` buffers back into ``[nb, bs, H, D]`` caches.

    Inverse of ``flatten_kv_cache_head_major``. Both flats are interpreted using
    ``key_cache``'s shape and must be ``view``-compatible with
    ``[H, nb, bs, D]``. The caches are updated in place via ``copy_``.
    """
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    head_major_shape = (num_heads, num_blocks, block_size, head_dim)
    # Reinterpret both flats first so a bad shape fails before any cache is touched.
    staged = [flat.view(head_major_shape) for flat in (k_flat, v_flat)]
    for src, dst in zip(staged, (key_cache, value_cache)):
        # [H, nb, bs, D] -> [nb, bs, H, D] to match the interleaved layout.
        dst.copy_(src.permute(1, 2, 0, 3))


def build_page_table_head_major(
    block_table: torch.Tensor,
    num_kv_heads: int,
    num_blocks: int,
    block_size: int,
    page_size: int,
    max_batches: int,
) -> torch.Tensor:
    """Build a ``[max_batches, H, max_chain]`` int32 page table for head-major flat KV.

    For each (batch, head), the non-negative entries of ``block_table`` row
    ``batch`` are chained in order. Each block contributes
    ``block_size // page_size`` consecutive page ids. Every id equals
    ``global_row // page_size``, where
    ``global_row = h * num_blocks * block_size + block_id * block_size + p * page_size``
    indexes rows in the head-major flat buffer (see
    ``flatten_kv_cache_head_major``). Negative block ids are skipped. Unused tail
    slots, and batch rows ``>= block_table.shape[0]``, stay 0.

    Args:
        block_table: 2-D ``[batch, max_blocks]`` table of physical block ids.
        num_kv_heads: number of KV heads (``H``).
        num_blocks: number of physical blocks per head in the flat buffer.
        block_size: rows per physical block.
        page_size: rows per compactor page; must be a positive divisor of
            ``block_size``.
        max_batches: first dimension of the returned table.

    Returns:
        int32 tensor on ``block_table.device`` with
        ``max_chain = block_table.shape[1] * (block_size // page_size)``.

    Raises:
        ValueError: if the batch exceeds ``max_batches``, if ``page_size`` is not
            a positive divisor of ``block_size`` (a page would otherwise span two
            physical blocks or two heads), or if a block id is ``>= num_blocks``.
    """
    bsz, max_blocks = block_table.shape
    if bsz > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")
    if page_size <= 0 or block_size % page_size != 0:
        raise ValueError(
            f"page_size={page_size} must be a positive divisor of "
            f"block_size={block_size}"
        )
    pages_per_block = block_size // page_size
    max_chain = max_blocks * pages_per_block
    # Distance, in pages, between the same block in consecutive heads.
    head_stride = num_blocks * pages_per_block

    # Build on the host and upload once. Per-element ``.item()`` reads and
    # scalar writes on an accelerator each force a sync / kernel launch.
    out = torch.zeros(
        (max_batches, num_kv_heads, max_chain), dtype=torch.int32, device="cpu"
    )
    head_offsets = (
        torch.arange(num_kv_heads, dtype=torch.int64, device="cpu").unsqueeze(1)
        * head_stride
    )
    rows = block_table.to(device="cpu", dtype=torch.int64).tolist()
    for b, row in enumerate(rows):
        chain: list[int] = []
        for blk_i, bid in enumerate(row):
            if bid < 0:
                continue  # negative ids are not part of the chain
            if bid >= num_blocks:
                raise ValueError(
                    f"block_table[{b},{blk_i}]={bid} out of range "
                    f"num_blocks={num_blocks}"
                )
            first_page = bid * pages_per_block
            chain.extend(range(first_page, first_page + pages_per_block))
        if chain:
            # [L] + [H, 1] -> [H, L]: same chain, shifted into each head's row range.
            pages = torch.tensor(chain, dtype=torch.int64, device="cpu") + head_offsets
            out[b, :, : len(chain)] = pages.to(torch.int32)
    return out.to(block_table.device)


def flatten_kv_cache_plane(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Collapse single-head ``[num_blocks, block_size, 1, D]`` caches to ``[num_blocks*block_size, D]``.

    Only valid for one KV head, where vLLM row order already matches compactor
    row indexing (see module doc). Each result aliases its cache when
    ``reshape`` yields a contiguous view; otherwise it is a contiguous copy.

    Raises:
        ValueError: if ``num_kv_heads != 1``, the two caches' shapes differ, or
            the cache's head axis is not of size 1.
    """
    if num_kv_heads != 1:
        raise ValueError(
            "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout"
        )
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")
    num_blocks, block_size, heads_in_cache, head_dim = key_cache.shape
    if heads_in_cache != 1:
        raise ValueError("expected num_kv_heads==1")
    total_rows = num_blocks * block_size
    # ``contiguous()`` returns the tensor itself when it is already dense.
    return tuple(
        cache.reshape(total_rows, head_dim).contiguous()
        for cache in (key_cache, value_cache)
    )


def block_table_to_global_page_table(
    block_table: torch.Tensor,
    num_kv_heads: int,
    max_batches: int,
) -> torch.Tensor:
    """Expand a ``[num_reqs, max_num_blocks]`` block table to ``[max_batches, HKV, max_num_blocks]`` int32.

    Each KV-head slice receives the same ids, copied verbatim from
    ``block_table`` (cast to int32), so page ids equal vLLM block ids. This is
    the MQA mapping described in the module doc. Ids are not offset per head,
    unlike ``build_page_table_head_major``. Rows ``>= block_table.shape[0]``
    are zero-filled.

    Raises:
        ValueError: if ``block_table`` has more rows than ``max_batches``.
    """
    num_rows, num_cols = block_table.shape
    if num_rows > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")
    table = torch.zeros(
        (max_batches, num_kv_heads, num_cols),
        dtype=torch.int32,
        device=block_table.device,
    )
    # [rows, cols] -> [rows, 1, cols] broadcasts across every head in one write.
    table[:num_rows] = block_table.to(torch.int32).unsqueeze(1)
    return table


def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor:
    """Return the int32 identity map ``[0, 1, ..., num_reqs - 1]`` (local -> global batch row)."""
    identity = torch.arange(0, num_reqs, dtype=torch.int32, device=device)
    return identity