test_block_table.py 4.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# SPDX-License-Identifier: Apache-2.0
import os

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    """Return True if ``n`` is a positive power of two, otherwise False."""
    if not n > 0:
        return False
    # A power of two has exactly one bit set, so clearing its lowest set
    # bit with ``n & (n - 1)`` leaves zero.
    return (n & (n - 1)) == 0


def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    """Kernel wrapper exercising the block-table load and transform helpers.

    Loads ``block_tables`` with ``load_block_tables``, feeds the result to
    ``transform_block_tables_for_indirect_load`` and copies the transformed
    table into a freshly allocated int32 tensor in shared HBM, which is
    returned so the caller can inspect it.

    Args:
        block_tables: block table tensor handed to ``load_block_tables``.
        num_tiles: number of tiles, forwarded to ``load_block_tables``.
        num_blocks_per_tile: blocks per tile; must be a power of 2.
        num_head: head count forwarded to the transform helper.
        head_id: integer head id; it is turned into an index expression
            before being forwarded to the transform helper.
        block_size_tiling_factor: tiling factor forwarded to the transform
            helper.

    Returns:
        An int32 shared-HBM tensor with the same shape as the transformed
        block tables.
    """
    B_P_SIZE = 128
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"

    # NOTE(review): presumably the loaded tables live in SBUF (the original
    # local was named ``block_tables_sbuf``) -- confirm against the helper.
    loaded_tables = load_block_tables(block_tables, num_tiles,
                                      num_blocks_per_tile)

    # The transform helper needs the head id as an Index, not a plain int.
    head_index = nl.arange(1)[None, :] + head_id

    transposed = transform_block_tables_for_indirect_load(
        loaded_tables, block_size_tiling_factor, num_head, head_index)
    assert transposed.shape[1] == B_P_SIZE

    result = nl.ndarray(transposed.shape,
                        dtype=nl.int32,
                        buffer=nl.shared_hbm)
    # Copy the transformed table out one leading-dimension slice at a time.
    for row in nl.affine_range(transposed.shape[0]):
        nl.store(dst=result[row], value=transposed[row])
    return result


def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    """PyTorch reference for the block-table transform checked by the test.

    Steps:
      1. View the flat ``block_tables`` as (num_tiles, num_blocks_per_tile)
         and zero-pad the tile dimension up to a multiple of 128.
      2. Map every entry ``b`` to ``b * num_head + head_id``.
      3. Expand every entry ``e`` into ``block_size_tiling_factor``
         consecutive values ``e * block_size_tiling_factor + k``.
      4. Transpose so the tile dimension is last and split the leading
         dimension into groups of 128.

    Args:
        block_tables: tensor with ``num_tiles * num_blocks_per_tile``
            elements.
        num_tiles: number of tiles (rows before padding).
        num_blocks_per_tile: number of blocks in each tile.
        num_head: multiplier applied to every block id.
        head_id: offset added after multiplying by ``num_head``.
        block_size_tiling_factor: number of sub-entries per block.

    Returns:
        Tensor of shape
        (num_blocks_per_tile * block_size_tiling_factor // 128, 128,
        padded_num_tiles).
    """
    B_F_SIZE = 128
    assert block_tables.numel() == num_tiles * num_blocks_per_tile

    # Rows needed to round the tile count up to a multiple of B_F_SIZE.
    extra_tiles = (-num_tiles) % B_F_SIZE
    padded_tiles = num_tiles + extra_tiles

    table = F.pad(
        block_tables.view(num_tiles, num_blocks_per_tile),
        (0, 0, 0, extra_tiles),
        mode="constant",
        value=0,
    )

    # Select the head, then fan each block out into its tiled sub-blocks.
    per_head = (table * num_head + head_id).view(padded_tiles,
                                                 num_blocks_per_tile, 1)
    sub_block = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    expanded = per_head * block_size_tiling_factor + sub_block

    transposed = expanded.view(padded_tiles, -1).t()
    total_rows = transposed.shape[0]
    assert total_rows % B_F_SIZE == 0
    return transposed.view(total_rows // B_F_SIZE, B_F_SIZE, padded_tiles)


@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    """Compare the NKI block-table kernel with the PyTorch reference.

    Random block tables are generated from a fixed seed, run through the
    jitted ``nki_load_and_transform_block_tables`` kernel on the XLA device
    and through ``ref_block_tables_transform`` on the host; both outputs
    must match in shape and in every element.
    """
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    os.environ["NEURON_CC_FLAGS"] = " ".join([
        "-O1",
        "--retry_failed_compilation",
    ])

    torch.manual_seed(10000)
    torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient, so
    # smaller tiles are split by a tiling factor until they reach 128.
    B_P_SIZE = 128
    block_size_tiling_factor = 1
    if num_blocks_per_tile < B_P_SIZE:
        assert B_P_SIZE % num_blocks_per_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile

    max_num_blocks = 100000
    total_entries = num_tiles * num_blocks_per_tile
    block_tables = torch.randint(
        0,
        max_num_blocks,
        (total_entries, ),
        dtype=torch.int32,
    )

    # Arguments shared by the kernel and the reference implementation.
    shared_args = (
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    )

    kernel = nki.jit(nki_load_and_transform_block_tables)[1, 1]
    nki_out = kernel(block_tables.to(device=device), *shared_args).cpu()
    ref_out = ref_block_tables_transform(block_tables, *shared_args)

    assert (nki_out.shape == ref_out.shape
            ), f"{nki_out.shape=} != {ref_out.shape=}"
    assert torch.all(nki_out == ref_out)