utils.py 34.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Kernel test utils"""

5
6
import itertools
import random
7
from collections.abc import Sequence
8
from numbers import Number
9
from typing import Any, NamedTuple
10
from unittest.mock import patch
11
12

import torch
13
from torch._prims_common import TensorLikeType
14

bnellnm's avatar
bnellnm committed
15
from tests.kernels.quant_utils import native_w8a8_block_matmul
16
from vllm.model_executor.custom_op import op_registry
17
from vllm.model_executor.layers.activation import SiluAndMul
18
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
19
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
20
from vllm.utils.torch_utils import make_tensor_with_pad
21
from vllm.v1.attention.backend import AttentionType
22

23
24
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# Full opcheck suite: the default subset plus the AOT-dispatch check that is
# excluded from the default subset for the reason given in the note above.
ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
    *DEFAULT_OPCHECK_TEST_UTILS,
    "test_aot_dispatch_dynamic",
)


class QKVInputs(NamedTuple):
    """
    Unpacked (padded) attention inputs together with their sequence lengths.

    Attributes:

        * query, key, value: padded tensors shaped
          (batch_size x padded_seq_len x num_heads x head_size)
        * q_seq_lens: unpadded query length for each batch element
        * kv_seq_lens: unpadded key/value length for each batch element
          (shared by key and value)
    """

    # Padded per-sequence tensors.
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    # Unpadded lengths, one entry per batch element.
    q_seq_lens: list[int]
    kv_seq_lens: list[int]


class QKVO(NamedTuple):
    """
    Unpacked attention inputs bundled with the known-correct ("ideal")
    attention output for those inputs.

    Attributes:

        * qkv: unpacked (batch_size x padded_seq_len x num_heads x head_size)
          attention inputs
        * ideal_output: unpacked (batch_size x padded_seq_len x num_heads x
          head_size) known-correct attention output
    """

    qkv: QKVInputs  # padded inputs plus their sequence lengths
    ideal_output: torch.Tensor  # reference output for `qkv`


class PackedQKVInputs(NamedTuple):
77
    """
78
79
80
81
    Data structure for representing packed attention inputs

    Attributes:

82
        * {query,key,value}: packed (number_of_tokens x num_heads
83
84
85
86
87
88
                             x head_size) attention inputs
        * q_start_loc_list: list of query start locations within packed tensor
        * kv_start_loc_list: shared list of key/value start locations within
                             packed tensor
        * q_seq_lens: query sequence lengths list
        * kv_seq_lens: shared key/value sequence lengths list
89
    """
90
91
92
93

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
94
95
96
97
    q_start_loc_list: list[int] | None
    kv_start_loc_list: list[int] | None
    q_seq_lens: list[int] | None
    kv_seq_lens: list[int] | None
98
99
100


class PackedQKVO(NamedTuple):
    """
    Packed attention inputs bundled with the packed known-correct output.

    Attributes:

        * packed_qkv: packed (number_of_tokens x num_heads x head_size)
          attention inputs; may be None
        * ideal_output: packed (number_of_tokens x num_heads x head_size)
          known-correct attention output
    """

    packed_qkv: PackedQKVInputs | None
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    """
    Describes where a test's tokens live in the KV cache.

    Attributes:

        * block_tables: KV cache block tables
        * slot_mapping: maps each sequence offset to a physical cache slot
    """

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    """
    Everything needed to exercise one test phase (prefill or decode) for one
    attention scenario (encoder, decoder self-attention, or encoder/decoder
    cross-attention).

    Attributes:

        * packed_qkvo: packed (number_of_tokens x num_heads x head_size)
          attention inputs together with the known-correct output
        * kv_mmap: KV cache memory mapping for this phase & scenario;
          may be None
    """

    packed_qkvo: PackedQKVO
    kv_mmap: KVMemoryMap | None


def maybe_make_int_tensor(
151
152
    _list: list[int] | None,
    device: torch.device | str,
153
) -> torch.Tensor:
154
    """
155
156
157
158
159
160
    Convert Python int list to a 1D int torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
161
162
163
164
    """
    return (
        None if _list is None else torch.tensor(_list, dtype=torch.int, device=device)
    )
165
166
167


def maybe_make_long_tensor(
168
169
    _list: list[int] | None,
    device: torch.device | str,
170
) -> torch.Tensor:
171
    """
172
173
174
175
176
177
    Convert Python int list to a 1D long torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
178
179
180
181
    """
    return (
        None if _list is None else torch.tensor(_list, dtype=torch.long, device=device)
    )
182
183


184
def maybe_max(_list: list | None) -> Number | None:
185
    """
186
187
188
189
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
190
    """
191
192
193
194
195
196
197
    return None if _list is None else max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    """
    Create a q_max_seq_len x kv_max_seq_len additive causal mask.

    Entry (i, j) is 0.0 when query position i may attend key position j
    (j <= i) and -inf when j lies in the future (j > i). The mask can
    therefore be added directly to pre-softmax attention scores.

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D float tensor, q_max_seq_len x kv_max_seq_len, created with torch's
      default dtype and device
    """
    # Strictly-upper-triangular entries (j > i) are the "future" positions.
    future = torch.ones(q_max_seq_len, kv_max_seq_len).triu(diagonal=1).bool()
    return torch.zeros(q_max_seq_len, kv_max_seq_len).masked_fill(
        future, float("-inf")
    )


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
223
224
225
    custom_mask: torch.Tensor | None = None,
    q_seq_lens: list | None = None,
    kv_seq_lens: list | None = None,
226
227
) -> torch.Tensor:
    """
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, i.e.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
249
    """
250
251
252
253
254

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
255
256
    assert len(q_seq_lens) == batch_size
    assert len(kv_seq_lens) == batch_size
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: int | None,
    num_heads: int,
    head_size: int,
    device: torch.device | str,
    force_kv_seq_lens: list[int] | None = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
    """
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens.

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len (also the padded query length)
    * max_kv_seq_len: max key/value seq len (also the padded key/value
      length of the generated tensors). Required for ENCODER_DECODER
      attention without force_kv_seq_lens. Otherwise it may be None, in
      which case K/V are padded to max_q_seq_len (when K/V lengths follow
      the query lengths) or to max(force_kv_seq_lens).
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths; must
      have one entry per batch element
    * attn_type: encoder, decoder self, or enc/dec cross attention. For
      ENCODER_DECODER the key/value seqlens are drawn independently of the
      query seqlens. For any other type they equal the query seqlens at
      each batch index.
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2, max_q_seq_len]. The same applies to key/value
      seqlens and max_kv_seq_len when they are drawn independently
      (ENCODER_DECODER).

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    """
    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [random.randint(2, max_q_seq_len) for _ in range(batch_size)]

    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [random.randint(2, max_kv_seq_len) for _ in range(batch_size)]

    if max_kv_seq_len is None:
        # Only reachable when the K/V lengths were not drawn from
        # max_kv_seq_len; pick a padding length instead of passing None
        # into torch.rand below.
        if force_kv_seq_lens is None:
            max_kv_seq_len = max_q_seq_len
        else:
            max_kv_seq_len = max(force_kv_seq_lens, default=0)

    # Fail loudly instead of silently truncating in zip() or hitting an
    # opaque shape-mismatch error in the slicing below.
    assert len(kv_seq_lens) == batch_size
    assert all(kv_len <= max_kv_seq_len for kv_len in kv_seq_lens)

    # Random draws happen on the default device and are then moved, as before.
    query = torch.rand((batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    key = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    value = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    # zeros_like keeps shape/dtype/device in lockstep with the baseline tensors.
    prefill_query = torch.zeros_like(query)
    prefill_key = torch.zeros_like(key)
    prefill_value = torch.zeros_like(value)

    decode_shape = (batch_size, 1, num_heads, head_size)
    decode_query = torch.zeros(decode_shape, device=device)
    decode_key = torch.zeros(decode_shape, device=device)
    decode_value = torch.zeros(decode_shape, device=device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)):
        # Zero the padding beyond each sequence's true length.
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill sees every token except the last one...
        prefill_query[bdx, : q_seq_len - 1] = query[bdx, : q_seq_len - 1]
        prefill_key[bdx, : kv_seq_len - 1] = key[bdx, : kv_seq_len - 1]
        prefill_value[bdx, : kv_seq_len - 1] = value[bdx, : kv_seq_len - 1]

        # ...and decode sees only that last token.
        decode_query[bdx] = query[bdx, q_seq_len - 1 : q_seq_len]
        decode_key[bdx] = key[bdx, kv_seq_len - 1 : kv_seq_len]
        decode_value[bdx] = value[bdx, kv_seq_len - 1 : kv_seq_len]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens,
        ),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens,
        ),
        QKVInputs(
            decode_query,  # Decode subset of QKV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens,
        ),
    )


def pack_tensor(
418
    unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str
419
420
) -> tuple[torch.Tensor, list[int]]:
    """
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
436
    """
437
438
439
440
441
442
443
444

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
445
446
447
        packed_tensor[start_loc : (start_loc + seq_len), :, :] = unpacked_tensor[
            bdx, :seq_len, :, :
        ]
448
449
450
451

    return packed_tensor, start_loc_list


452
def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
    """
    Remove padding from Q, K and V independently.

    Each input is batch_size x padded_seq_len x num_heads x head_size. Each
    output is number_of_tokens x num_heads x head_size, where
    number_of_tokens = sum(q_seq_lens) for Q and sum(kv_seq_lens) for K
    and V.

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           attention inputs; qkv.query may be None
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs derived
      from the unpacked inputs. When qkv.query is None, the packed query,
      q_start_loc_list and q_seq_lens are all None.
    """
    has_query = qkv.query is not None
    if has_query:
        packed_query, q_start_loc_list = pack_tensor(
            qkv.query, qkv.q_seq_lens, device=device
        )
    else:
        packed_query, q_start_loc_list = None, None

    # Key and value share kv_seq_lens, hence identical start locations.
    packed_key, kv_start_loc_list = pack_tensor(qkv.key, qkv.kv_seq_lens, device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)

    return PackedQKVInputs(
        query=packed_query,
        key=packed_key,
        value=packed_value,
        q_start_loc_list=q_start_loc_list,
        kv_start_loc_list=kv_start_loc_list,
        q_seq_lens=qkv.q_seq_lens if has_query else None,
        kv_seq_lens=qkv.kv_seq_lens,
    )


def _make_metadata_tensors(
    seq_lens: list[int] | None,
    context_lens: list[int] | None,
    encoder_seq_lens: list[int] | None,
    device: torch.device | str,
) -> tuple[
    torch.Tensor | None,
    torch.Tensor | None,
    Any,
    Any,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    int | None,
]:
    """
    Build scalar & tensor values required to build attention metadata structure.

    Any input passed as None yields None for every output derived from it.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as int tensor
    * context_lens_tensor: context_lens list, as int tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each decoder sequence, int32, of length
      len(seq_lens) + 1 (last entry is the total token count)
    * encoder_seq_lens_tensor: encoder seq_lens list, as int tensor
    * encoder_seq_start_loc: start idx of each encoder sequence, int32, of
      length len(encoder_seq_lens) + 1
    * max_encoder_seq_len: max(encoder_seq_lens)
    """

    def _start_loc(lens_tensor: torch.Tensor | None) -> torch.Tensor | None:
        # [0, l0, l0 + l1, ...] as int32 on the same device as lens_tensor.
        if lens_tensor is None:
            return None
        start_loc = torch.zeros(
            lens_tensor.shape[0] + 1,
            dtype=torch.int32,
            device=lens_tensor.device,
        )
        torch.cumsum(lens_tensor, dim=0, dtype=start_loc.dtype, out=start_loc[1:])
        return start_loc

    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = maybe_max(encoder_seq_lens)

    return (
        seq_lens_tensor,
        context_lens_tensor,
        max_context_len,
        max_seq_len,
        _start_loc(seq_lens_tensor),
        encoder_seq_lens_tensor,
        # Previously crashed with AttributeError when encoder_seq_lens was None.
        _start_loc(encoder_seq_lens_tensor),
        max_encoder_seq_len,
    )


def make_kv_cache(
    num_blocks: int,
    num_heads: int,
    head_size: int,
    block_size: int,
579
    device: torch.device | str,
580
581
582
583
    backend: str,
    default_val: float = 0.0,
) -> torch.Tensor:
    """
584
585
586
587
588
589
590
591
592
593
594
595
596
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

597
    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
598
    *     for backend 'FLASH_ATTN'
599
    """
600
601
602
    if backend != "FLASH_ATTN":
        raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
    kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)
603
604
605
606
607
608
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
609
    """
610
611
    Compute the minimum number of blocks required to hold num_tokens tokens,
    given block_size
612
    """
613
    return (num_tokens + block_size - 1) // block_size
614
615


616
def make_empty_slot_mapping_tensor(device: torch.device | str):
617
618
619
    return maybe_make_long_tensor([], device)


620
def make_empty_block_tables_tensor(device: torch.device | str):
621
622
623
    return torch.tensor([], device=device)


624
625
626
def split_slot_mapping(
    slot_mapping_list: list[int] | torch.Tensor,
    seq_lens: list[int],
    device: torch.device | str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      K_i for i in [0, N), followed by (2) decoding of a single token for
      all N prompts (N tokens total). The resulting sequence lengths after
      decode are K_i + 1 for i in [0, N).
    * The test requires (1) the prefill slot mapping for all tokens present
      during prefill, of which there are M = sum_i(K_i), and (2) the decode
      slot mapping for all N decoded tokens.

    This function consumes a single 1D slot mapping. It is the concatenation
    of N slot mappings, each of length K_i + 1 (the post-decode sequence
    lengths), for a total length of P = sum_i(K_i + 1) = M + N.

    The prefill-phase slot mapping is obtained by excising the (K_i + 1)-th
    entry from each of the N subsequences, i.e. omitting the decoded token's
    mapping.

    The N excised entries, in order, form the decode-phase slot mapping.

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
      post-decode sequences
    * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
      description above); each must be >= 1
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as long Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as long Tensor)
      reflecting all N decoded tokens
    """
    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        # A length of 0 would make `last_idx` point into the previous
        # sequence (or wrap to the end of the list via index -1).
        assert seq_len >= 1, f"post-decode seq_len must be >= 1, got {seq_len}"
        last_idx = base_idx + seq_len - 1
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:last_idx])
        decode_slot_mapping.append(slot_mapping_list[last_idx])
        base_idx += seq_len

    return (
        maybe_make_long_tensor(prefill_slot_mapping, device),
        maybe_make_long_tensor(decode_slot_mapping, device),
    )


def make_block_tables_slot_mapping(
    block_size: int,
    seq_lens: list[int],
    device: torch.device | str,
    block_base_addr: int = 0,
) -> tuple[torch.Tensor, list[int], int]:
    """
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens the minimum number
    of required KV cache blocks is

    num_blocks = (num_tokens + block_size - 1) // block_size

    Then the minimum KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Block addresses are assigned counting downward, starting at

    block_base_addr + total_cache_blocks

    and ending at

    block_base_addr + 1

    so block_base_addr itself is never assigned. Block tables are padded
    with 0, so with the default block_base_addr=0, block 0 appears only as
    padding. A KV cache used with these tables therefore needs at least
    block_base_addr + total_cache_blocks + 1 blocks.

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: int block table per sequence, right-padded with 0
      to (longest table + 10) columns
    * slot_mapping_list: slot mapping for all sequences, concatenated
    * max_block_idx: the highest block address within this block table
    """
    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    # Extra all-zero columns appended to every block-table row.
    # NOTE(review): purpose is not visible here (presumably headroom for
    # kernels that read past the live entries) -- confirm before changing.
    block_table_pad_tokens = 10

    # Uppermost block address; allocation counts down from here.
    total_cache_blocks = sum(num_blocks_list)
    max_block_idx = block_base_addr + total_cache_blocks
    block_base_idx = max_block_idx

    block_tables = []
    slot_mapping_list = []
    for num_tokens, num_blocks in zip(seq_lens, num_blocks_list):
        block_table = list(range(block_base_idx, block_base_idx - num_blocks, -1))
        # Physical slot = block address * block_size + offset within block.
        slot_mapping_list.extend(
            block_table[idx // block_size] * block_size + idx % block_size
            for idx in range(num_tokens)
        )
        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)


def assert_actual_matches_ideal(
    test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str
) -> None:
    """
    Check the kernel's observed output against the ideal output stored in
    the test parameters.

    Arguments:

    * test_params: phase test parameters; its packed_qkvo.ideal_output is
      the reference
    * output_under_test: observed output; reshaped with ``view_as`` to the
      ideal output's shape before comparison
    * backend: attention backend name; only 'FLASH_ATTN' is supported

    Raises:

    * ValueError: if `backend` is not 'FLASH_ATTN'
    * AssertionError: (from torch.testing.assert_close) on mismatch
    """
    expected = test_params.packed_qkvo.ideal_output
    if backend != "FLASH_ATTN":
        raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")

    # FlashAttention deviates from the reference by more than the default
    # tolerances allow, so looser atol/rtol values are used.
    actual = output_under_test.view_as(expected)
    torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.016)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose.

    Both inputs are upcast to float64 before the element-wise ``isclose``
    comparison, so low-precision dtypes can be compared safely.
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    close = torch.isclose(
        a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
    )
    return bool(close.all().item())


# Marlin MoE test utils


def stack_and_dev(tensors: list[torch.Tensor]):
    """
    Stack `tensors` along a new leading dimension and return the result on
    the device of the first tensor.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)


def compute_max_diff(output, output_ref):
    """
    Relative mean absolute error of `output` with respect to `output_ref`.

    Despite the name, this is not a maximum. It returns
    mean(|output - output_ref|) / mean(|output_ref|) as a 0-d tensor.
    """
    mean_abs_err = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_err / mean_abs_ref


def torch_experts(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
833
834
835
836
837
838
839
840
    b_bias1: torch.Tensor | None = None,
    b_bias2: torch.Tensor | None = None,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
bnellnm's avatar
bnellnm committed
841
    per_act_token_quant=False,
842
    block_shape: list[int] | None = None,
843
    apply_router_weights_on_input: bool = False,
844
    activation: MoEActivation = MoEActivation.SILU,
bnellnm's avatar
bnellnm committed
845
) -> torch.Tensor:
846
847
848
849
850
    assert (
        global_num_experts == -1
        or (global_num_experts == w1.shape[0] and expert_map is None)
        or (expert_map is not None and global_num_experts == expert_map.shape[0])
    )
bnellnm's avatar
bnellnm committed
851

852
853
854
855
856
857
858
    if quant_dtype in [torch.float16, torch.bfloat16]:
        quant_dtype = None
    quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
    if quant_input_only:
        assert a1_scale is None and a2_scale is None
        assert per_act_token_quant

bnellnm's avatar
bnellnm committed
859
    M, K = a.shape
860
    topk = topk_ids.shape[1]
bnellnm's avatar
bnellnm committed
861

862
863
864
865
    if apply_router_weights_on_input:
        assert topk == 1
        a = a * topk_weight.to(a.dtype)

bnellnm's avatar
bnellnm committed
866
867
868
869
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)

870
871
    if a1_scale:
        assert not per_act_token_quant and block_shape is None
872
873
874
    a, a_scale = moe_kernel_quantize_input(
        a, a1_scale, quant_dtype, per_act_token_quant, block_shape
    )
bnellnm's avatar
bnellnm committed
875

876
877
878
    if quant_input_only:
        a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)

bnellnm's avatar
bnellnm committed
879
880
    num_experts = w1.shape[0]

881
    topk_ids = topk_ids.view(-1)
882
883
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
bnellnm's avatar
bnellnm committed
884

885
886
    f32 = torch.float32

887
    act = op_registry[activation.custom_op_name]
888

bnellnm's avatar
bnellnm committed
889
    for i in range(num_experts):
890
891
        mask = topk_ids == i
        if mask.sum():
bnellnm's avatar
bnellnm committed
892
893
            if quant_dtype is None:
                tmp1 = a[mask] @ w1[i].transpose(0, 1)
894
895
                if b_bias1 is not None:
                    tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
896
                tmp2 = act()(tmp1)
bnellnm's avatar
bnellnm committed
897
                out[mask] = tmp2 @ w2[i].transpose(0, 1)
898
                if b_bias2 is not None:
899
                    out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
900
901
902
903
904
905
906
907
            elif quant_input_only:
                tmp1 = a[mask] @ w1[i].transpose(0, 1)
                tmp2 = SiluAndMul()(tmp1)
                tmp2, tmp2_scale = moe_kernel_quantize_input(
                    tmp2, None, quant_dtype, per_act_token_quant
                )
                tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
                out[mask] = tmp2 @ w2[i].transpose(0, 1)
bnellnm's avatar
bnellnm committed
908
            elif block_shape is not None:
909
                # block quantized
910
911
912
913
914
915
916
917
                assert (
                    a_scale is not None
                    and w1_scale is not None
                    and w2_scale is not None
                )
                tmp1 = native_w8a8_block_matmul(
                    a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype
                )
918
919
                if b_bias1 is not None:
                    tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
bnellnm's avatar
bnellnm committed
920
921
                tmp2 = SiluAndMul()(tmp1)
                tmp2, b_scale = moe_kernel_quantize_input(
922
923
                    tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
                )
bnellnm's avatar
bnellnm committed
924

925
926
927
                out[mask] = native_w8a8_block_matmul(
                    tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype
                )
928
                if b_bias2 is not None:
929
                    out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
bnellnm's avatar
bnellnm committed
930
            else:
931
932
933
934
935
                assert (
                    a_scale is not None
                    and w1_scale is not None
                    and w2_scale is not None
                )
bnellnm's avatar
bnellnm committed
936
                scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
937

bnellnm's avatar
bnellnm committed
938
939
                tmp1 = a[mask].to(f32) * scales
                w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
940
                tmp1 = (tmp1 @ w1_dq).to(out.dtype)
941
942
                if b_bias1 is not None:
                    tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
943
944
945
946

                tmp2 = SiluAndMul()(tmp1).to(out.dtype)

                tmp2, b_scale = moe_kernel_quantize_input(
947
948
                    tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
                )
949
950
951
                assert b_scale is not None

                tmp2 = tmp2.to(f32) * b_scale
bnellnm's avatar
bnellnm committed
952
953
                w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
                out[mask] = (tmp2 @ w2_dq).to(out.dtype)
954
                if b_bias2 is not None:
955
                    out[mask] = out[mask] + b_bias2[i].view(1, -1).to(out.dtype)
bnellnm's avatar
bnellnm committed
956

957
958
959
    if apply_router_weights_on_input:
        return out
    else:
960
961
962
963
964
965
966
967
968
969
970
971
972
        return (
            (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
            .sum(dim=1)
            .to(out.dtype)
        )


def torch_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    b_bias1: torch.Tensor | None = None,
    b_bias2: torch.Tensor | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    activation: MoEActivation = MoEActivation.SILU,
) -> torch.Tensor:
    """
    Reference MoE forward pass that performs its own routing.

    ``score`` is softmaxed in float32 over the last dimension. The ``topk``
    largest probabilities (not renormalized) and their expert ids are then
    handed to ``torch_experts`` along with the remaining arguments.
    """
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_weights, routing_ids = torch.topk(probs, topk)
    return torch_experts(
        a,
        w1,
        w2,
        topk_weight=routing_weights,
        topk_ids=routing_ids,
        global_num_experts=global_num_experts,
        b_bias1=b_bias1,
        b_bias2=b_bias2,
        expert_map=expert_map,
        activation=activation,
    )


def torch_moe_single(a, w, score, topk):
    """
    Reference single-GEMM MoE.

    Each token (row of the 2-D ``a``) is routed to its ``topk`` experts,
    chosen by softmax over ``score``. For each chosen expert ``e`` the
    product ``token @ w[e].T`` is computed, and the products are summed per
    token. The routing weights only select experts; they do not scale the
    result.

    Returns a tensor of shape (num_tokens, w.shape[1]).
    """
    num_tokens, hidden = a.shape
    routed = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    result = torch.zeros(
        num_tokens * topk, w.shape[1], dtype=routed.dtype, device=routed.device
    )
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    expert_ids = torch.topk(probs, topk)[1].view(-1)
    for expert in range(w.shape[0]):
        selected = expert_ids == expert
        if selected.sum():
            result[selected] = routed[selected] @ w[expert].transpose(0, 1)
    return result.view(num_tokens, -1, w.shape[1]).sum(dim=1)


# A wrapper around torch.library.opcheck that patches torch.allclose with a
# version that supports fp8 types. Note that the default test_utils is the
# full ALL_OPCHECK_TEST_UTILS set; pass DEFAULT_OPCHECK_TEST_UTILS to skip
# "test_aot_dispatch_dynamic".
def opcheck(
    op: torch._ops.OpOverload
    | torch._ops.OpOverloadPacket
    | torch._library.custom_ops.CustomOpDef,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    *,
    test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> dict[str, str]:
    """
    Run ``torch.library.opcheck`` on ``op`` while ``torch.allclose`` is
    replaced by ``fp8_allclose``.

    When ``cond`` is false, nothing is checked and an empty dict is returned.
    Otherwise the result of ``torch.library.opcheck`` is returned unchanged.
    """
    with patch("torch.allclose", new=fp8_allclose):
        if not cond:
            return {}
        return torch.library.opcheck(
            op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
        )


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
1035
1036
1037
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )
1038
1039
1040
1041
1042
1043


def to_int8(tensor: torch.Tensor):
    """Saturate ``tensor`` to [-128, 127], round to integers, and cast to int8."""
    saturated = tensor.clamp(min=-128, max=127)
    return saturated.round().to(dtype=torch.int8)


1044
1045
1046
1047
1048
1049
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: type[torch.dtype],
1050
    bias: torch.Tensor | None = None,
1051
) -> torch.Tensor:
1052
    # We treat N-dimensional group scaling as extended numpy-style broadcasting
1053
    # in numpy simply stretches dimensions with an extent of 1 to match
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
    # the target shape by repeating the data along that dimension (broadcasting)
    # , we extend these semantics to say if the extent of a dimension in the
    # source shape is not 1 and does not match the target shape we repeat each
    # element along that dimension src_shape[dim] // target_shape[dim] times
    # example if we have:
    #       a = [[1, 2], and target_shape = (2, 4)
    #            [3, 4]]
    # then we would expand a to:
    #       a = [[1, 1, 2, 2],
    #            [3, 3, 4, 4]]
1064
    # NOTE this function does not explicitly broadcast dimensions
1065
1066
1067
1068
1069
    # with an extent of 1, since this can be done implicitly by pytorch
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
1070
1071
1072
1073
1074
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
1075
1076
1077
1078
1079
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

1080
1081
1082
    output = torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    ).to(out_dtype)
1083

1084
1085
1086
1087
    if bias is not None:
        output = output + bias

    return output