metadata_kernels.py 8.95 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import triton

import triton.language as tl

from loguru import logger
from typing import List, Optional
from torch.utils._triton import has_triton as has_triton_torch

from text_generation_server.utils.import_utils import (
    SYSTEM,
)
from text_generation_server.utils.log import log_master

# Lazily-populated cache of the Triton availability probe.
_HAS_TRITON: Optional[bool] = None


def has_triton():
    """Return True when the optimized Triton indexing kernels can be used.

    The probe runs at most once; the outcome is memoized in ``_HAS_TRITON``.
    """
    global _HAS_TRITON
    if _HAS_TRITON is not None:
        return _HAS_TRITON

    # FIXME: it seems that has_triton_torch is bugged on RocM
    #        For now, only accept cuda
    _HAS_TRITON = SYSTEM == "cuda" and has_triton_torch()
    if _HAS_TRITON:
        log_master(logger.info, "Using optimized Triton indexing kernels.")
    return _HAS_TRITON


def block_tables_to_padded(
    max_blocks: int,
    cu_seqlen: torch.Tensor,
    block_tables: torch.Tensor,
    block_tables_ragged: torch.Tensor,
):
    """Scatter a ragged block table back into the padded 2D layout, in place.

    Launches ``triton_block_tables_to_padded`` over a 2D grid: chunks of
    ``max_blocks`` along axis 0 and one program per batch entry along axis 1.
    """
    batch_size = len(block_tables)

    def grid(meta):
        return (triton.cdiv(max_blocks, meta["BLOCK_SIZE"]), batch_size)

    triton_block_tables_to_padded[grid](
        cu_seqlen,
        block_tables,
        block_tables_ragged,
        block_tables.shape[1],
        BLOCK_SIZE=256,
    )


def block_tables_to_ragged(
    *,
    block_tables: torch.Tensor,
    input_lengths: List[int],
    cache_lengths: List[int],
    input_lengths_tensor: torch.Tensor,
    cache_lengths_tensor: torch.Tensor,
    max_current_length: int,
) -> torch.Tensor:
    """Convert block table to ragged format compatible with FlashInfer."""
    assert len(input_lengths) == len(cache_lengths)

    # One entry per token across the whole batch.
    total_len = sum(input_lengths) + sum(cache_lengths)
    block_tables_ragged = torch.empty(
        total_len, dtype=torch.int32, device=block_tables.device
    )

    if has_triton():
        # Prefix sums of per-request sequence lengths, with a leading zero so
        # cu_seqlen[i] is request i's start offset in the ragged tensor.
        cu_seqlen = torch.nn.functional.pad(
            torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
        )

        def grid(meta):
            return (
                triton.cdiv(max_current_length, meta["BLOCK_SIZE"]),
                len(cache_lengths),
            )

        triton_block_tables_to_ragged[grid](
            cu_seqlen,
            block_tables,
            block_tables_ragged,
            block_tables.shape[1],
            BLOCK_SIZE=256,
        )
    else:
        # Pure-PyTorch fallback: copy each request's row prefix sequentially.
        cursor = 0
        for row, (in_len, cached) in enumerate(zip(input_lengths, cache_lengths)):
            seq_len = in_len + cached
            block_tables_ragged[cursor : cursor + seq_len] = block_tables[row][:seq_len]
            cursor += seq_len

    return block_tables_ragged


def copy_next_input_ids_inplace(
    max_next_input_ids: int,
    all_input_ids: torch.Tensor,
    cache_lengths: torch.Tensor,
    input_lengths: torch.Tensor,
    prompt_lengths: torch.Tensor,
    next_input_ids: torch.Tensor,
    cu_accepted_ids: torch.Tensor,
):
    """Scatter each request's newly accepted token ids into its row of
    ``all_input_ids``, in place, via ``triton_copy_next_input_ids_inplace``.
    """
    batch_size = len(all_input_ids)

    def grid(meta):
        return (triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]), batch_size)

    triton_copy_next_input_ids_inplace[grid](
        all_input_ids,
        cache_lengths,
        input_lengths,
        prompt_lengths,
        next_input_ids,
        cu_accepted_ids,
        all_input_ids.shape[1],
        BLOCK_SIZE=16,
    )


def prepare_position_slot_ids(
    max_input_length: int,
    cache_lengths: torch.Tensor,
    cu_seqlen: torch.Tensor,
    cu_slots: torch.Tensor,
    position_ids: torch.Tensor,
    slot_indices: torch.Tensor,
):
    """Fill ``position_ids`` and ``slot_indices`` for every token of every
    request, in place, via ``triton_prepare_position_slot_ids``.
    """
    batch_size = len(cache_lengths)

    def grid(meta):
        return (triton.cdiv(max_input_length, meta["BLOCK_SIZE"]), batch_size)

    triton_prepare_position_slot_ids[grid](
        cache_lengths,
        cu_seqlen,
        cu_slots,
        position_ids,
        slot_indices,
        BLOCK_SIZE=256,
    )


def slots_filtering(
    max_slots: int,
    slots: torch.Tensor,
    filtered_slots: torch.Tensor,
    cu_slots: torch.Tensor,
    slots_start: torch.Tensor,
):
    """Copy each request's surviving slot range from ``slots`` into
    ``filtered_slots``, in place, via ``triton_slots_filtering``.
    """
    batch_size = len(slots_start)

    def grid(meta):
        return (triton.cdiv(max_slots, meta["BLOCK_SIZE"]), batch_size)

    triton_slots_filtering[grid](
        slots,
        filtered_slots,
        slots_start,
        cu_slots,
        BLOCK_SIZE=256,
    )


@triton.jit
def triton_slots_filtering(
    # Inputs
    slots_ptr,
    # Outputs
    filtered_slots_ptr,
    # Inputs
    slots_start_ptr,
    cu_slots_ptr,
    # Const values
    BLOCK_SIZE: "tl.constexpr",
):
    """Per-request slot compaction.

    For batch entry ``bid``, copies ``slots[slots_start[bid] + k]`` to
    ``filtered_slots[cu_slots[bid] + k]`` for every in-range ``k``
    (i.e. while ``cu_slots[bid] + k < cu_slots[bid + 1]``).
    """
    # Chunk index along the slot dimension (max_slots / BLOCK_SIZE programs)
    pid = tl.program_id(axis=0)
    # Position in batch
    bid = tl.program_id(axis=1)

    block_start = pid * BLOCK_SIZE
    block_arange = block_start + tl.arange(0, BLOCK_SIZE)

    # Where this request's slots begin in the unfiltered `slots` tensor.
    filter_start = tl.load(slots_start_ptr + bid)

    # Destination range [slot_start, slot_end) in `filtered_slots`.
    slot_start = tl.load(cu_slots_ptr + bid)
    slot_end = tl.load(cu_slots_ptr + bid + 1)

    # Source and destination ranges have the same length, so one mask serves both.
    mask = (slot_start + block_arange) < slot_end

    slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
    tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)


@triton.jit
def triton_block_tables_to_padded(
    # Inputs
    cu_seqlen_ptr,
    # Outputs
    block_tables_ptr,
    # Inputs (read; despite the name, this is the source here)
    block_tables_ragged_ptr,
    # Stride
    stride_block_tables,
    # Const values
    BLOCK_SIZE: "tl.constexpr",
):
    """Gather one request's slice of the ragged block table into row ``bid``
    of the padded 2D block table.

    Reads ``block_tables_ragged[cu_seqlen[bid] + k]`` and writes it to
    ``block_tables[bid, k]`` while ``cu_seqlen[bid] + k < cu_seqlen[bid + 1]``.
    """
    # Position in block_tables_ragged.numel() / BLOCK_SIZE
    pid = tl.program_id(axis=0)
    # Position in batch
    bid = tl.program_id(axis=1)

    block_start = pid * BLOCK_SIZE
    block_arange = block_start + tl.arange(0, BLOCK_SIZE)

    # This request's range [seq_start, seq_end) within the ragged tensor.
    seq_start = tl.load(cu_seqlen_ptr + bid)
    seq_end = tl.load(cu_seqlen_ptr + bid + 1)

    mask = (seq_start + block_arange) < seq_end

    blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)
    tl.store(
        block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask
    )


@triton.jit
def triton_block_tables_to_ragged(
    # Inputs
    cu_seqlen_ptr,
    block_tables_ptr,
    # Outputs
    block_tables_ragged_ptr,
    # Stride
    stride_block_tables,
    # Const values
    BLOCK_SIZE: "tl.constexpr",
):
    """Scatter row ``bid`` of the padded 2D block table into its slice of the
    ragged block table (inverse of ``triton_block_tables_to_padded``).

    Reads ``block_tables[bid, k]`` and writes it to
    ``block_tables_ragged[cu_seqlen[bid] + k]`` while
    ``cu_seqlen[bid] + k < cu_seqlen[bid + 1]``.
    """
    # Position in block_tables_ragged.numel() / BLOCK_SIZE
    pid = tl.program_id(axis=0)
    # Position in batch
    bid = tl.program_id(axis=1)

    block_start = pid * BLOCK_SIZE
    block_arange = block_start + tl.arange(0, BLOCK_SIZE)

    # This request's range [seq_start, seq_end) within the ragged tensor.
    seq_start = tl.load(cu_seqlen_ptr + bid)
    seq_end = tl.load(cu_seqlen_ptr + bid + 1)

    mask = (seq_start + block_arange) < seq_end

    blocks = tl.load(
        block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask
    )
    tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)


@triton.jit
def triton_copy_next_input_ids_inplace(
    # Outputs (written in place)
    all_input_ids_ptr,
    # Inputs
    cache_lengths_ptr,
    input_lengths_ptr,
    prompt_lengths_ptr,
    next_input_ids_ptr,
    cu_accepted_ids_ptr,
    # Stride
    stride_all_input_ids,
    # Const values
    BLOCK_SIZE: "tl.constexpr",
):
    """Append request ``bid``'s accepted token ids to its row of
    ``all_input_ids``.

    Copies ``next_input_ids[cu_accepted_ids[bid] + k]`` to
    ``all_input_ids[bid, cache_length + input_length + k]``, skipping
    positions that would still fall inside the prompt (request still
    prefilling).
    """
    # Position in max_accepted_ids / BLOCK_SIZE
    pid = tl.program_id(axis=0)
    # Position in batch
    bid = tl.program_id(axis=1)

    block_start = pid * BLOCK_SIZE
    block_arange = block_start + tl.arange(0, BLOCK_SIZE)

    # Used for correctly indexing in all_input_ids
    cache_length = tl.load(cache_lengths_ptr + bid)
    input_length = tl.load(input_lengths_ptr + bid)
    prompt_length = tl.load(prompt_lengths_ptr + bid)

    # Start/End of next_input_ids for this request
    next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)
    next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)

    # Mask values out of range
    mask = (next_input_ids_start + block_arange) < next_input_ids_end

    # Mask values for request still prefilling: only write once the target
    # position is at or past the end of the prompt.
    decode_mask = (cache_length + input_length + block_arange) >= prompt_length

    mask = mask & decode_mask

    # Load this request next input ids
    next_input_ids = tl.load(
        next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask
    )

    # Store in all_input_ids, since it is a 2D tensor, apply stride * bid
    tl.store(
        all_input_ids_ptr
        + stride_all_input_ids * bid
        + cache_length
        + input_length
        + block_arange,
        next_input_ids,
        mask=mask,
    )


@triton.jit
def triton_prepare_position_slot_ids(
    # Inputs
    cache_lengths_ptr,
    cu_seqlen_ptr,
    cu_slots_ptr,
    # Outputs
    position_ids_ptr,
    slot_indices_ptr,
    # Const values
    BLOCK_SIZE: "tl.constexpr",
):
    """Compute per-token position ids and KV-cache slot indices.

    For token ``k`` of request ``bid`` (flat index ``cu_seqlen[bid] + k``):
      - ``position_ids``  gets ``cache_length + k`` (position within the
        full sequence, offset by what is already cached);
      - ``slot_indices``  gets ``cu_slots[bid] + cache_length + k`` (slot
        within this request's slot range, skipping already-filled slots).
    """
    # Position in max_input_length / BLOCK_SIZE
    pid = tl.program_id(axis=0)
    # Position in batch
    bid = tl.program_id(axis=1)

    block_start = pid * BLOCK_SIZE
    block_arange = block_start + tl.arange(0, BLOCK_SIZE)

    cache_length = tl.load(cache_lengths_ptr + bid)

    # This request's token range [seq_start, seq_end) in the flat tensors.
    seq_start = tl.load(cu_seqlen_ptr + bid)
    seq_end = tl.load(cu_seqlen_ptr + bid + 1)

    slot_start = tl.load(cu_slots_ptr + bid)

    mask = (seq_start + block_arange) < seq_end

    tl.store(
        position_ids_ptr + seq_start + block_arange,
        cache_length + block_arange,
        mask=mask,
    )
    tl.store(
        slot_indices_ptr + seq_start + block_arange,
        slot_start + cache_length + block_arange,
        mask=mask,
    )