utils.py 44.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Kernel test utils"""

5
6
import itertools
import random
7
import unittest
8
from collections.abc import Sequence
9
from numbers import Number
10
from typing import Any, NamedTuple
11

12
import pytest
13
import torch
14
from torch._prims_common import TensorLikeType
15

bnellnm's avatar
bnellnm committed
16
from tests.kernels.quant_utils import native_w8a8_block_matmul
17
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
18
from vllm.attention.backends.registry import _Backend
19
from vllm.model_executor.layers.activation import SiluAndMul
20
21
22
23
24
25
26
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import (
    STR_BACKEND_ENV_VAR,
    STR_FLASH_ATTN_VAL,
    STR_XFORMERS_ATTN_VAL,
    make_tensor_with_pad,
)
27

28
29
# "test_aot_dispatch_dynamic" is excluded from the default opcheck utilities
# for now, because that check has known bugs in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# Every opcheck utility, including the one excluded from the default set.
ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
    *DEFAULT_OPCHECK_TEST_UTILS,
    "test_aot_dispatch_dynamic",
)

43

44
class QKVInputs(NamedTuple):
    """
    Unpacked (padded) attention inputs together with their sequence lengths.

    Attributes:

        * query, key, value: padded attention input tensors, each
          (batch_size x padded_seq_len x num_heads x head_size)
        * q_seq_lens: unpadded query length for each batch index
        * kv_seq_lens: unpadded key/value length for each batch index
          (shared by key and value)
    """

    # Padded attention input tensors.
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    # Per-batch-index unpadded lengths.
    q_seq_lens: list[int]
    kv_seq_lens: list[int]
62
63
64


class QKVO(NamedTuple):
    """
    Unpacked attention inputs bundled with the expected (reference) output.

    Attributes:

        * qkv: unpacked (batch_size x padded_seq_len x num_heads x head_size)
               attention inputs
        * ideal_output: unpacked (batch_size x padded_seq_len x num_heads x
                        head_size) known-correct attention output
    """

    # Padded Q/K/V inputs plus their sequence lengths.
    qkv: QKVInputs
    # Reference attention result corresponding to `qkv`.
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
82
    """
83
84
85
86
    Data structure for representing packed attention inputs

    Attributes:

87
        * {query,key,value}: packed (number_of_tokens x num_heads
88
89
90
91
92
93
                             x head_size) attention inputs
        * q_start_loc_list: list of query start locations within packed tensor
        * kv_start_loc_list: shared list of key/value start locations within
                             packed tensor
        * q_seq_lens: query sequence lengths list
        * kv_seq_lens: shared key/value sequence lengths list
94
    """
95
96
97
98

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
99
100
101
102
    q_start_loc_list: list[int] | None
    kv_start_loc_list: list[int] | None
    q_seq_lens: list[int] | None
    kv_seq_lens: list[int] | None
103
104
105


class PackedQKVO(NamedTuple):
    """
    Packed attention inputs bundled with the packed expected output.

    Attributes:

        * packed_qkv: packed (number_of_tokens x num_heads x head_size)
                      attention inputs, or None
        * ideal_output: packed (number_of_tokens x num_heads x head_size)
                        known-correct attention output
    """

    # Packed Q/K/V inputs with their start locations and lengths.
    packed_qkv: PackedQKVInputs | None
    # Reference attention result in packed layout.
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    """
    KV cache memory mapping for a set of sequences.

    Attributes:

        * block_tables: KV cache block tables
        * slot_mapping: mapping from sequence offset to physical cache slot
    """

    # Per-sequence lists of KV cache block indices.
    block_tables: torch.Tensor
    # Physical slot address for each token.
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    """
    Test parameters for one test "phase" (prefill or decode) of one attention
    scenario (encoder, decoder self-attention, or encoder/decoder cross).

    Attributes:

        * packed_qkvo: packed (number_of_tokens x num_heads x head_size)
                       attention inputs together with the known-correct
                       output
        * kv_mmap: KV cache memory mapping specific to this phase and
                   attention scenario, or None
    """

    # Packed inputs + reference output for this phase.
    packed_qkvo: PackedQKVO
    # Block tables / slot mapping for this phase (may be absent).
    kv_mmap: KVMemoryMap | None
153
154
155


def maybe_make_int_tensor(
156
157
    _list: list[int] | None,
    device: torch.device | str,
158
) -> torch.Tensor:
159
    """
160
161
162
163
164
165
    Convert Python int list to a 1D int torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
166
167
168
169
    """
    return (
        None if _list is None else torch.tensor(_list, dtype=torch.int, device=device)
    )
170
171
172


def maybe_make_long_tensor(
173
174
    _list: list[int] | None,
    device: torch.device | str,
175
) -> torch.Tensor:
176
    """
177
178
179
180
181
182
    Convert Python int list to a 1D long torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
183
184
185
186
    """
    return (
        None if _list is None else torch.tensor(_list, dtype=torch.long, device=device)
    )
187
188


189
def maybe_max(_list: list | None) -> Number | None:
190
    """
191
192
193
194
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
195
    """
196
197
198
199
200
201
202
    return None if _list is None else max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    """
    Create a q_max_seq_len x kv_max_seq_len additive causal mask.

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D float tensor (default dtype), q_max_seq_len x kv_max_seq_len, where
      entry (i, j) is 0.0 when j <= i (key j visible to query i) and -inf when
      j > i (future key, masked out)
    """
    # Start fully masked (-inf), then triu(diagonal=1) zeroes the main
    # diagonal and everything below it, leaving -inf only where j > i.
    return torch.triu(
        torch.full((q_max_seq_len, kv_max_seq_len), float("-inf")), diagonal=1
    )


223
224
225
226
def override_backend_env_variable(
    mpatch: pytest.MonkeyPatch, backend_name: str
) -> None:
    """
    Temporarily force the vLLM attention backend through its env variable.

    The variable named by STR_BACKEND_ENV_VAR is set via pytest's
    monkeypatch, so it is restored automatically when the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    """
    env_var_name = STR_BACKEND_ENV_VAR
    mpatch.setenv(env_var_name, backend_name)
237
238


239
240
241
242
243
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
244
245
246
    custom_mask: torch.Tensor | None = None,
    q_seq_lens: list | None = None,
    kv_seq_lens: list | None = None,
247
248
) -> torch.Tensor:
    """
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, i.e.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
270
    """
271
272
273
274
275

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
276
277
    assert len(q_seq_lens) == batch_size
    assert len(kv_seq_lens) == batch_size
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: int | None,
    num_heads: int,
    head_size: int,
    device: torch.device | str,
    force_kv_seq_lens: list[int] | None = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
    """
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len (padded length of the query tensors)
    * max_kv_seq_len: max key/value seq len (padded length of the key/value
      tensors). Required when cross-attention K/V seqlens are drawn at random.
      If None in the other cases, K/V are padded to max_q_seq_len
      (self-attention) or to max(force_kv_seq_lens) (forced K/V seqlens)
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention. For any
      type other than AttentionType.ENCODER_DECODER, key/value seqlens match
      the query seqlens at each batch index
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2,max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len in the enc/dec cross-attention case, unless
      force_kv_seq_lens is given

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    """

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [random.randint(2, max_q_seq_len) for _ in range(batch_size)]

    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [random.randint(2, max_kv_seq_len) for _ in range(batch_size)]

    if max_kv_seq_len is None:
        # Only reachable for self-attention or forced K/V lengths (the random
        # cross-attention branch asserts above). Pick a padded K/V length that
        # fits every K/V sequence instead of allocating with a None dimension.
        if force_kv_seq_lens is None:
            max_kv_seq_len = max_q_seq_len
        else:
            max_kv_seq_len = max(kv_seq_lens, default=0)

    q_shape = (batch_size, max_q_seq_len, num_heads, head_size)
    kv_shape = (batch_size, max_kv_seq_len, num_heads, head_size)
    step_shape = (batch_size, 1, num_heads, head_size)

    # Random tensors are drawn on the CPU generator and then moved, keeping
    # the values reproducible across devices for a given seed.
    query = torch.rand(q_shape).to(device)
    key = torch.rand(kv_shape).to(device)
    value = torch.rand(kv_shape).to(device)

    prefill_query = torch.zeros(q_shape, device=device)
    prefill_key = torch.zeros(kv_shape, device=device)
    prefill_value = torch.zeros(kv_shape, device=device)

    decode_query = torch.zeros(step_shape, device=device)
    decode_key = torch.zeros(step_shape, device=device)
    decode_value = torch.zeros(step_shape, device=device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)):
        # Zero the padding region of the baseline tensors.
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill gets every offset except the last one...
        prefill_query[bdx, 0 : (q_seq_len - 1), :, :] = query[
            bdx, 0 : (q_seq_len - 1), :, :
        ]
        prefill_key[bdx, 0 : (kv_seq_len - 1), :, :] = key[
            bdx, 0 : (kv_seq_len - 1), :, :
        ]
        prefill_value[bdx, 0 : (kv_seq_len - 1), :, :] = value[
            bdx, 0 : (kv_seq_len - 1), :, :
        ]

        # ...and decode gets only that last offset.
        decode_query[bdx, :, :, :] = query[bdx, (q_seq_len - 1) : q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1) : kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx, (kv_seq_len - 1) : kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens,
        ),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens,
        ),
        QKVInputs(
            decode_query,  # Decode subset of QKV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens,
        ),
    )
436
437
438


def pack_tensor(
439
    unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str
440
441
) -> tuple[torch.Tensor, list[int]]:
    """
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
457
    """
458
459
460
461
462
463
464
465

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
466
467
468
        packed_tensor[start_loc : (start_loc + seq_len), :, :] = unpacked_tensor[
            bdx, :seq_len, :, :
        ]
469
470
471
472

    return packed_tensor, start_loc_list


473
def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
    """
    Pack each of Q, K and V from padded (batch_size x padded_seq_len x
    num_heads x head_size) form into number_of_tokens x num_heads x head_size
    form.

    Q is packed using q_seq_lens (number_of_tokens = sum(q_seq_lens)); K and V
    are packed using the shared kv_seq_lens (number_of_tokens =
    sum(kv_seq_lens)). If qkv.query is None, the packed query, its start
    locations and its seq lens are all None.

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           attention inputs
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    """
    if qkv.query is not None:
        packed_query, q_start_loc_list = pack_tensor(
            qkv.query, qkv.q_seq_lens, device=device
        )
    else:
        packed_query, q_start_loc_list = None, None

    kv_lens = qkv.kv_seq_lens
    packed_key, kv_start_loc_list = pack_tensor(qkv.key, kv_lens, device=device)
    # K and V share seq lens, so V's start locations duplicate K's.
    packed_value, _ = pack_tensor(qkv.value, kv_lens, device=device)

    return PackedQKVInputs(
        query=packed_query,
        key=packed_key,
        value=packed_value,
        q_start_loc_list=q_start_loc_list,
        kv_start_loc_list=kv_start_loc_list,
        q_seq_lens=None if q_start_loc_list is None else qkv.q_seq_lens,
        kv_seq_lens=kv_lens,
    )
513
514
515


def make_backend(backend_name: str) -> AttentionBackend:
    """
    Instantiate the attention backend named by backend_name.

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance

    Raises:

    * AssertionError if backend_name is not a recognized backend
    """
    # Each branch imports lazily so only the selected backend module loads.
    if backend_name == STR_XFORMERS_ATTN_VAL:
        from vllm.v1.attention.backends.xformers import (
            XFormersAttentionBackend as backend_cls,
        )
    elif backend_name == STR_FLASH_ATTN_VAL:
        from vllm.v1.attention.backends.flash_attn import (
            FlashAttentionBackend as backend_cls,
        )
    elif backend_name == "TRITON_ATTN":
        from vllm.v1.attention.backends.triton_attn import (
            TritonAttentionBackend as backend_cls,
        )
    elif backend_name == "FLEX_ATTENTION":
        from vllm.v1.attention.backends.flex_attention import (
            FlexAttentionBackend as backend_cls,
        )
    elif backend_name == "TORCH_SDPA":
        from vllm.v1.attention.backends.cpu_attn import (
            TorchSDPABackend as backend_cls,
        )
    elif backend_name == "FLASHINFER":
        from vllm.v1.attention.backends.flashinfer import (
            FlashInferBackend as backend_cls,
        )
    else:
        raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test")

    return backend_cls()
557
558


559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
def make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: list[int],
) -> list[Any]:
    """Create ALiBi biases compatible with xFormers attention tests."""
    from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias

    if alibi_slopes is None:
        return [None for _ in seq_lens]

    attn_biases: list[Any] = []
    num_heads = alibi_slopes.shape[0]
    assert num_heads >= num_kv_heads, (
574
575
        "ALiBi slopes expect at least as many heads as KV heads"
    )
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595

    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
        bias = bias[None, :] - bias[:, None]

        padded_len = (seq_len + 7) // 8 * 8
        bias_tensor = torch.empty(
            1,
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias_tensor.mul_(alibi_slopes[:, None, None])
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))

    return attn_biases


596
def _make_metadata_tensors(
    seq_lens: list[int] | None,
    context_lens: list[int] | None,
    encoder_seq_lens: list[int] | None,
    device: torch.device | str,
) -> tuple[
    torch.Tensor | None,
    torch.Tensor | None,
    Any,
    Any,
    torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    int | None,
]:
    """
    Build scalar & tensor values required to build attention metadata structure.

    Any argument list that is None yields None for every value derived from
    it. This supports the decoder-only scenario, where encoder_seq_lens is
    None.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as torch.int tensor
    * context_lens_tensor: context_lens list, as torch.int tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: int32 tensor of len(seq_lens) + 1 entries; start idx of
      each sequence (0 followed by the cumulative sum of seq_lens)
    * encoder_seq_lens_tensor: encoder seq_lens list, as torch.int tensor
    * encoder_seq_start_loc: start idx of each encoder sequence (same layout
      as seq_start_loc)
    * max_encoder_seq_len: max(encoder_seq_lens)
    """

    def _start_loc(lens_tensor: torch.Tensor | None) -> torch.Tensor | None:
        # [0, l0, l0 + l1, ...] as int32, or None if there are no lengths.
        if lens_tensor is None:
            return None
        start_loc = torch.zeros(
            lens_tensor.shape[0] + 1,
            dtype=torch.int32,
            device=lens_tensor.device,
        )
        torch.cumsum(lens_tensor, dim=0, dtype=start_loc.dtype, out=start_loc[1:])
        return start_loc

    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = maybe_max(encoder_seq_lens)

    seq_start_loc = _start_loc(seq_lens_tensor)
    encoder_seq_start_loc = _start_loc(encoder_seq_lens_tensor)

    return (
        seq_lens_tensor,
        context_lens_tensor,
        max_context_len,
        max_seq_len,
        seq_start_loc,
        encoder_seq_lens_tensor,
        encoder_seq_start_loc,
        max_encoder_seq_len,
    )


def make_kv_cache(
    num_blocks: int,
    num_heads: int,
    head_size: int,
    block_size: int,
681
    device: torch.device | str,
682
683
684
685
    backend: str,
    default_val: float = 0.0,
) -> torch.Tensor:
    """
686
687
688
689
690
691
692
693
694
695
696
697
698
699
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
700
    *     for backend 'XFORMERS'
701
    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
702
    *     for backend 'FLASH_ATTN'
703
704
705
706
707
708
709
710
711
    """
    if backend == "XFORMERS":
        kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
            device
        )
    elif backend == "FLASH_ATTN":
        kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
            device
        )
712
713
    else:
        raise ValueError(
714
715
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
        )
716
717
718
719
720
721
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
722
    """
723
724
    Compute the minimum number of blocks required to hold num_tokens tokens,
    given block_size
725
    """
726
727
728
    return (num_tokens + block_size) // block_size


729
def make_empty_slot_mapping_tensor(device: torch.device | str):
730
731
732
    return maybe_make_long_tensor([], device)


733
def make_empty_block_tables_tensor(device: torch.device | str):
734
735
736
    return torch.tensor([], device=device)


737
738
739
def split_slot_mapping(
    slot_mapping_list: Sequence[int],
    seq_lens: list[int],
    device: torch.device | str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as a list or other
      sliceable sequence of ints) reflecting all N post-decode sequences
    * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as torch.long Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as torch.long Tensor)
      reflecting all N decoded tokens
    """
    prefill_slot_mapping: list[int] = []
    decode_slot_mapping: list[int] = []

    base_idx = 0
    for seq_len in seq_lens:
        # The final slot of each sequence belongs to its decoded token.
        decode_idx = base_idx + seq_len - 1
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:decode_idx])
        decode_slot_mapping.append(slot_mapping_list[decode_idx])
        base_idx += seq_len

    return (
        maybe_make_long_tensor(prefill_slot_mapping, device),
        maybe_make_long_tensor(decode_slot_mapping, device),
    )
797
798
799


def make_block_tables_slot_mapping(
    block_size: int,
    seq_lens: list[int],
    device: torch.device | str,
    block_base_addr: int = 0,
) -> tuple[torch.Tensor, list[int], int]:
    """
    Construct fake block tables & slot mappings.

    Each sequence with num_tokens tokens is given

        num_blocks = (num_tokens + block_size) // block_size

    KV cache blocks (via _num_tokens_to_min_blocks), so the KV cache size in
    blocks is

        total_cache_blocks = sum(num_blocks for all seqs)

    Block addresses are handed out in descending order, starting at
    block_base_addr + total_cache_blocks and ending at block_base_addr + 1
    (block_base_addr itself is never assigned). The first sequence receives
    the highest addresses.

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: one block table per sequence, built with
      dtype=torch.int and padded with 0 up to (longest table + 10) entries
    * slot_mapping_list: flat list of physical slots
      (block * block_size + offset) for every token of every sequence, in
      sequence order
    * max_block_idx: the highest block address within this block table
    """
    # Blocks provisioned for each sequence.
    blocks_per_seq = [
        _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens
    ]
    # Extra zero entries appended to each block-table row.
    pad_entries = 10

    top_block = block_base_addr + sum(blocks_per_seq)
    next_block = top_block
    block_tables: list[list[int]] = []
    slot_mapping_list: list[int] = []
    for num_tokens, num_blocks in zip(seq_lens, blocks_per_seq):
        table = list(range(next_block, next_block - num_blocks, -1))
        slot_mapping_list.extend(
            table[pos // block_size] * block_size + pos % block_size
            for pos in range(num_tokens)
        )
        block_tables.append(table)
        next_block -= num_blocks

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max(blocks_per_seq) + pad_entries,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, top_block)


def make_test_metadata(
    attn_backend: _Backend,
    is_prompt: bool,
    seq_lens: list[int] | None,
    decoder_test_params: PhaseTestParameters | None,
    device: torch.device | str,
    encoder_test_params: PhaseTestParameters | None = None,
    cross_test_params: PhaseTestParameters | None = None,
) -> AttentionMetadata:
    """
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support the decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: backend enum; its ``.name`` is passed to make_backend to
                    obtain the backend object that builds the metadata
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence; None signals the
                encoder-only scenario
    * decoder_test_params: decoder self-attention test params;
                           this function requires
                           kv_mmap (memory mapping) field.
                           None signals the encoder-only scenario.
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure

    Raises:

    * AssertionError: in the decode phase, if there is no decoder kv_mmap or
      seq_lens is None; or if encoder_test_params has no packed_qkv
    """

    # Decoder self-attention memory mapping
    # decoder_test_params is None signals encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #   (one new token per sequence)
    #
    # seq_lens is None signals encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = None if seq_lens is None else len(seq_lens)
    num_prefill_or_decode_tokens = (
        None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens))
    )

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        num_encoder_tokens = (
            None if encoder_seq_lens is None else sum(encoder_seq_lens)
        )

    # For encoder/decoder or encoder-only models only, extract *cross-attention*
    # slot_mapping and block table (kv_mmap)
    cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap

    attn_backend_obj = make_backend(attn_backend.name)

    if not is_prompt:
        # Decode-phase scenario requires a decoder KV cache memory map and
        # explicit per-sequence lengths; validate before building tensors.
        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

    (
        seq_lens_tensor,
        context_lens_tensor,
        _,
        _,
        seq_start_loc,
        encoder_seq_lens_tensor,
        encoder_seq_start_loc,
        max_encoder_seq_len,
    ) = _make_metadata_tensors(seq_lens, context_lens, encoder_seq_lens, device=device)

    # Metadata fields that are identical for the prefill and decode phases.
    # (In the decode phase kv_mmap was asserted non-None above.)
    shared_fields: dict[str, Any] = {
        "enable_kv_scales_calculation": True,
        "seq_lens": seq_lens,
        "seq_lens_tensor": seq_lens_tensor,
        "seq_start_loc": seq_start_loc,
        "context_lens_tensor": context_lens_tensor,
        "slot_mapping": None if kv_mmap is None else kv_mmap.slot_mapping,
        "block_tables": None if kv_mmap is None else kv_mmap.block_tables,
        "use_cuda_graph": False,
        "num_encoder_tokens": num_encoder_tokens,
        "encoder_seq_lens": encoder_seq_lens,
        "encoder_seq_lens_tensor": encoder_seq_lens_tensor,
        "encoder_seq_start_loc": encoder_seq_start_loc,
        "max_encoder_seq_len": max_encoder_seq_len,
        "cross_slot_mapping": (
            None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping
        ),
        "cross_block_tables": (
            None if cross_kv_mmap is None else cross_kv_mmap.block_tables
        ),
    }

    if is_prompt:
        # Prefill-phase scenario: every token counted in seq_lens is a
        # prefill token and there are no decode tokens.
        return attn_backend_obj.make_metadata(
            num_prefills=num_prefills_or_decodes,
            num_prefill_tokens=num_prefill_or_decode_tokens,
            num_decode_tokens=0,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            **shared_fields,
        )

    # Decode-phase scenario: no prefills, one query token per sequence.
    return attn_backend_obj.make_metadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=num_prefill_or_decode_tokens,
        max_prefill_seq_len=0,
        max_decode_seq_len=max(seq_lens),
        max_decode_query_len=1,
        **shared_fields,
    )


def assert_actual_matches_ideal(
    test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str
) -> None:
    """
1075
1076
1077
1078
1079
1080
1081
    Assert that observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
1082
    """
1083
    ideal_output = test_params.packed_qkvo.ideal_output
1084
1085
1086
1087
    if backend == "XFORMERS":
        torch.testing.assert_close(
            ideal_output, output_under_test.view_as(ideal_output)
        )
1088

1089
    elif backend == "FLASH_ATTN":
1090
1091
1092
        # For FlashAttention override the accuracy thresholds to non default
        # values since we notice a higher difference between the ideal and
        # actual output.
1093
1094
1095
        torch.testing.assert_close(
            ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
        )
1096
1097
    else:
        raise ValueError(
1098
1099
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
        )
1100
1101


1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
1113
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
1114
1115
1116

    return bool(
        torch.all(
1117
1118
1119
1120
1121
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )
1122
1123


# Marlin MoE test utils


def stack_and_dev(tensors: list[torch.Tensor]):
    """
    Stack ``tensors`` along a new leading dimension (dim 0) and place the
    result on the device of the first tensor.

    Raises IndexError when ``tensors`` is empty.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)


def compute_max_diff(output, output_ref):
    """
    Return the mean absolute error between ``output`` and ``output_ref``,
    normalized by the mean absolute value of ``output_ref``.

    Despite the name, this is a mean-based relative error, not a maximum.
    For float tensors the result is a 0-dim tensor; it is inf/nan when
    ``output_ref`` is all zeros.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_reference = output_ref.abs().mean()
    return mean_abs_error / mean_abs_reference


def torch_experts(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    b_bias1: torch.Tensor | None = None,
    b_bias2: torch.Tensor | None = None,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant=False,
    block_shape: list[int] | None = None,
    apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
    """
    Pure-torch reference implementation of a fused MoE layer.

    Each row of ``a`` is replicated ``topk`` times. Each copy is sent to the
    expert named in ``topk_ids`` and passed through
    ``x @ w1[e].T -> SiluAndMul -> @ w2[e].T``, with optional biases and
    quantization. The per-expert outputs are then combined with
    ``topk_weight``.

    Arguments:

    * a: (M, K) input activations
    * w1, w2: per-expert weights indexed by expert id on dim 0;
      ``w2.shape[1]`` is the output width
    * topk_weight: routing weights; viewed as (M, topk, 1) when combining
    * topk_ids: (M, topk) expert ids for each token
    * global_num_experts: only checked by the leading assertion against
      ``w1.shape[0]`` / ``expert_map.shape[0]``; -1 skips the check
    * b_bias1, b_bias2: optional per-expert biases added after each matmul
    * expert_map: optional lookup applied to ``topk_ids`` before dispatch;
      ids that land outside ``range(w1.shape[0])`` contribute zero rows
    * w1_scale, w2_scale: weight scales, required on the quantized paths
    * a1_scale, a2_scale: optional static activation scales forwarded to
      ``moe_kernel_quantize_input``; a static ``a1_scale`` is only allowed
      without per-token and block quantization
    * quant_dtype: None selects the unquantized path
    * per_act_token_quant: per-token activation quantization flag
    * block_shape: when set (and quant_dtype is set), selects the
      block-quantized path
    * apply_router_weights_on_input: fold ``topk_weight`` into ``a`` up front
      (requires topk == 1) and return the expert outputs unweighted

    Return:

    * (M, w2.shape[1]) tensor in ``a``'s original dtype; when
      ``apply_router_weights_on_input`` is set, the raw
      (M * topk, w2.shape[1]) expert outputs
    """
    assert (
        global_num_experts == -1
        or (global_num_experts == w1.shape[0] and expert_map is None)
        or (expert_map is not None and global_num_experts == expert_map.shape[0])
    )

    M, K = a.shape
    topk = topk_ids.shape[1]

    if apply_router_weights_on_input:
        assert topk == 1
        a = a * topk_weight.to(a.dtype)

    # Replicate each token topk times: row m * topk + j pairs with topk_ids[m, j]
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Compare against None explicitly: truth-testing a tensor raises for a
    # multi-element scale and would silently skip a zero-valued scale.
    if a1_scale is not None:
        assert not per_act_token_quant and block_shape is None
    a, a_scale = moe_kernel_quantize_input(
        a, a1_scale, quant_dtype, per_act_token_quant, block_shape
    )

    num_experts = w1.shape[0]

    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]

    f32 = torch.float32

    for i in range(num_experts):
        mask = topk_ids == i
        if not mask.any():
            continue

        if quant_dtype is None:
            # Unquantized path
            tmp1 = a[mask] @ w1[i].transpose(0, 1)
            if b_bias1 is not None:
                tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
            tmp2 = SiluAndMul()(tmp1)
            out[mask] = tmp2 @ w2[i].transpose(0, 1)
        elif block_shape is not None:
            # Block-quantized path
            assert (
                a_scale is not None
                and w1_scale is not None
                and w2_scale is not None
            )
            tmp1 = native_w8a8_block_matmul(
                a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype
            )
            if b_bias1 is not None:
                tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
            tmp2 = SiluAndMul()(tmp1)
            tmp2, b_scale = moe_kernel_quantize_input(
                tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
            )
            out[mask] = native_w8a8_block_matmul(
                tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype
            )
        else:
            # Per-tensor / per-token quantized path: dequantize to float32
            # and run the matmuls in full precision.
            assert (
                a_scale is not None
                and w1_scale is not None
                and w2_scale is not None
            )
            # A single-element scale applies to every row; otherwise take the
            # scales of the rows routed to this expert
            scales = a_scale if a_scale.numel() == 1 else a_scale[mask]

            tmp1 = a[mask].to(f32) * scales
            w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
            tmp1 = (tmp1 @ w1_dq).to(out.dtype)
            if b_bias1 is not None:
                tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)

            tmp2 = SiluAndMul()(tmp1).to(out.dtype)

            tmp2, b_scale = moe_kernel_quantize_input(
                tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape
            )
            assert b_scale is not None

            tmp2 = tmp2.to(f32) * b_scale
            w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
            out[mask] = (tmp2 @ w2_dq).to(out.dtype)

        # Every path leaves out[mask] in out.dtype, so the second bias is
        # added identically for all of them.
        if b_bias2 is not None:
            out[mask] = out[mask] + b_bias2[i].view(1, -1).to(out.dtype)

    if apply_router_weights_on_input:
        # Router weights were already folded into the input
        return out

    return (
        (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
        .sum(dim=1)
        .to(out.dtype)
    )


def torch_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    b_bias1: torch.Tensor | None = None,
    b_bias2: torch.Tensor | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Reference MoE forward pass that performs its own routing.

    ``score`` logits are turned into float32 probabilities by a softmax over
    the last dim. The ``topk`` largest probabilities (not renormalized) and
    their expert ids are then passed to ``torch_experts`` together with the
    weights, optional biases and expert map.
    """
    routing_probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(routing_probs, topk)
    return torch_experts(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        global_num_experts=global_num_experts,
        b_bias1=b_bias1,
        b_bias2=b_bias2,
        expert_map=expert_map,
    )


def torch_moe_single(a, w, score, topk):
    """
    Single-weight reference MoE without routing weights.

    Each row of ``a`` (B, D) is sent to its ``topk`` highest-scoring experts
    according to ``softmax(score)``. The result is ``row @ w[e].T`` summed
    over those experts, giving a (B, w.shape[1]) tensor.
    """
    num_tokens, hidden = a.shape
    expanded = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    result = torch.zeros(
        num_tokens * topk, w.shape[1], dtype=expanded.dtype, device=expanded.device
    )
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    expert_ids = torch.topk(probs, topk)[1].view(-1)
    for expert in range(w.shape[0]):
        selected = expert_ids == expert
        if selected.any():
            result[selected] = expanded[selected] @ w[expert].transpose(0, 1)
    return result.view(num_tokens, -1, w.shape[1]).sum(dim=1)


# A wrapper around torch.library.opcheck that runs the checks with
# torch.allclose patched to an fp8-capable implementation, and that can be
# turned off per call via ``cond``.
def opcheck(
    op: torch._ops.OpOverload
    | torch._ops.OpOverloadPacket
    | torch._library.custom_ops.CustomOpDef,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    *,
    test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> dict[str, str]:
    """
    Run ``torch.library.opcheck`` on ``op`` unless ``cond`` is False.

    Arguments:

    * op, args, kwargs: operator to check and the sample inputs to call it with
    * test_utils: opcheck tests to run; defaults to ALL_OPCHECK_TEST_UTILS,
      which includes test_aot_dispatch_dynamic
    * raise_exception: forwarded to ``torch.library.opcheck``
    * cond: when False, skip the check and return ``{}``

    Return:

    * the result dict of ``torch.library.opcheck``, or ``{}`` when skipped
    """
    if not cond:
        return {}

    # ``import unittest`` alone does not load the ``unittest.mock`` submodule,
    # so import it explicitly rather than relying on another module having
    # done so.
    from unittest import mock

    with mock.patch("torch.allclose", new=fp8_allclose):
        return torch.library.opcheck(
            op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
        )


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
1324
1325
1326
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )
1327
1328
1329
1330
1331
1332


def to_int8(tensor: torch.Tensor):
    """Saturate ``tensor`` to [-128, 127], round half-to-even, cast to int8."""
    saturated = torch.clamp(tensor, min=-128, max=127)
    return saturated.round().to(torch.int8)


1333
1334
1335
1336
1337
1338
def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: type[torch.dtype],
1339
    bias: torch.Tensor | None = None,
1340
) -> torch.Tensor:
1341
    # We treat N-dimensional group scaling as extended numpy-style broadcasting
1342
    # in numpy simply stretches dimensions with an extent of 1 to match
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
    # the target shape by repeating the data along that dimension (broadcasting)
    # , we extend these semantics to say if the extent of a dimension in the
    # source shape is not 1 and does not match the target shape we repeat each
    # element along that dimension src_shape[dim] // target_shape[dim] times
    # example if we have:
    #       a = [[1, 2], and target_shape = (2, 4)
    #            [3, 4]]
    # then we would expand a to:
    #       a = [[1, 1, 2, 2],
    #            [3, 3, 4, 4]]
1353
    # NOTE this function does not explicitly broadcast dimensions
1354
1355
1356
1357
1358
    # with an extent of 1, since this can be done implicitly by pytorch
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
1359
1360
1361
1362
1363
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
1364
1365
1366
1367
1368
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

1369
1370
1371
    output = torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    ).to(out_dtype)
1372

1373
1374
1375
1376
    if bias is not None:
        output = output + bias

    return output