_custom_ops.py 94.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Literal
5
6
7

import torch

8
import vllm.envs as envs
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
from vllm.scalar_type import ScalarType
12
13
14
from vllm.utils.flashinfer import (
    flashinfer_quant_nvfp4_8x4_sf_layout,
)
15
from vllm.utils.math_utils import cdiv
16
17
18

logger = init_logger(__name__)

# Load the platform's compiled kernel libraries so the torch.ops.* namespaces
# used throughout this module are populated.
current_platform.import_kernels()

if TYPE_CHECKING:
    # Static-analysis stand-in for ``torch.library.register_fake``. That API is
    # used as ``@register_fake("<namespace>::<op>")``: it takes the op name and
    # returns a decorator. This stub mirrors that shape with an identity
    # decorator, so type checkers keep the decorated function's own type.
    def register_fake(name):
        return lambda fn: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        # Fallback for PyTorch versions predating ``register_fake``, where the
        # same decorator was exposed as ``impl_abstract``.
        from torch.library import impl_abstract as register_fake

31
32
33
34
35
36
37
38
39
40

# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``_C.paged_attention_v1`` op.

    Every argument is forwarded positionally, in declaration order. The wrapper
    returns ``None``; NOTE(review): the op is presumably expected to write its
    result into ``out`` in place — confirm against the kernel.
    """
    kernel = torch.ops._C.paged_attention_v1
    attention_args = (
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    # Trailing knobs: tensor-parallel rank and the block-sparse settings.
    sparse_args = (
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    kernel(*attention_args, *sparse_args)
75
76
77
78
79
80
81
82
83
84
85
86
87


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``_C.paged_attention_v2`` op.

    ``exp_sum``, ``max_logits`` and ``tmp_out`` are caller-provided buffers that
    are handed straight to the op, right after ``out``. All arguments keep their
    declaration order. NOTE(review): results are presumably written in place
    into ``out`` — confirm against the kernel.
    """
    kernel = torch.ops._C.paged_attention_v2
    buffers = (out, exp_sum, max_logits, tmp_out)
    attention_args = (
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    sparse_args = (
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    kernel(*buffers, *attention_args, *sparse_args)
125
126


127
128
129
130
131
132
133
134
135
136
137
138
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    query_start_loc: torch.Tensor | None,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    fp8_out_scale: torch.Tensor | None = None,
    mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
) -> None:
    """Invoke the ROCm paged-attention op (``_rocm_C.paged_attention``).

    Arguments are forwarded positionally in declaration order.

    Note: the ``mfma_type`` default is a Python default expression, so it is
    evaluated once, when this function is defined (module import), from
    ``envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN``.
    """
    rocm_kernel = torch.ops._rocm_C.paged_attention
    buffers = (out, exp_sum, max_logits, tmp_out)
    inputs = (
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        query_start_loc,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    rocm_kernel(*buffers, *inputs, fp8_out_scale, mfma_type)
171
172


def mla_decode_kvcache_cpu(
    out: torch.Tensor,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    """Invoke the CPU MLA decode op (``_C_cpu.mla_decode_kvcache``).

    Arguments are forwarded unchanged and in order; the wrapper returns None.
    """
    cpu_kernel = torch.ops._C_cpu.mla_decode_kvcache
    cpu_kernel(out, query, kv_cache, scale, block_tables, seq_lens)
184
185


186
# merge attn states ops
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Merge prefix and suffix attention partial results via the compiled op.

    The op's positional order differs from this wrapper's signature:
    ``output_lse`` is passed second, right after ``output``.
    """
    destinations = (output, output_lse)
    prefix_part = (prefix_output, prefix_lse)
    suffix_part = (suffix_output, suffix_lse)
    torch.ops._C.merge_attn_states(*destinations, *prefix_part, *suffix_part)
198
199


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate zero-initialized index buffers and let the compiled op fill them.

    Returns ``(block_count, block_offset, column_count, column_index)`` with
    shapes ``[B, H, R]``, ``[B, H, R, NNZ_S]``, ``[B, H, R]`` and
    ``[B, H, R, NNZ_V]``, where ``R = ceil(context_size / block_size_M)``. All
    buffers share ``q_seqlens``' dtype and device.
    """
    dim = slash_indexes.size
    batch_size, num_heads, nnz_slash = dim(0), dim(1), dim(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    like_seqlens = {"dtype": q_seqlens.dtype, "device": q_seqlens.device}
    block_count = torch.zeros(batch_size, num_heads, num_rows, **like_seqlens)
    block_offset = torch.zeros(
        batch_size, num_heads, num_rows, nnz_slash, **like_seqlens
    )
    column_count = torch.zeros(batch_size, num_heads, num_rows, **like_seqlens)
    column_index = torch.zeros(
        batch_size, num_heads, num_rows, nnz_vertical, **like_seqlens
    )

    results = (block_count, block_offset, column_count, column_index)
    torch.ops._C.convert_vertical_slash_indexes(
        *results,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return results


def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS] : different head use different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Per-head-count variant of ``convert_vertical_slash_indexes``.

    Output buffers have the same shapes as the non-merged variant
    (``[B, H, R]``, ``[B, H, R, NNZ_S]``, ``[B, H, R]``, ``[B, H, R, NNZ_V]``),
    but are allocated with ``torch.empty`` (uninitialized). NOTE(review): this
    presumably relies on the op writing every entry — confirm against the kernel.
    """
    dim = slash_indexes.size
    batch_size, num_heads, nnz_slash = dim(0), dim(1), dim(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    like_seqlens = {"dtype": q_seqlens.dtype, "device": q_seqlens.device}
    block_count = torch.empty(batch_size, num_heads, num_rows, **like_seqlens)
    block_offset = torch.empty(
        batch_size, num_heads, num_rows, nnz_slash, **like_seqlens
    )
    column_count = torch.empty(batch_size, num_heads, num_rows, **like_seqlens)
    column_index = torch.empty(
        batch_size, num_heads, num_rows, nnz_vertical, **like_seqlens
    )

    results = (block_count, block_offset, column_count, column_index)
    torch.ops._C.convert_vertical_slash_indexes_mergehead(
        *results,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        vertical_indices_count,
        slash_indices_count,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return results


317
318
319
320
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Apply rotary position embedding through ``_C.rotary_embedding``.

    ``key`` may be None. The wrapper returns None; NOTE(review): query/key are
    presumably rotated in place — confirm against the kernel.
    """
    rope_kernel = torch.ops._C.rotary_embedding
    rope_kernel(positions, query, key, head_size, cos_sin_cache, is_neox)
329
330
331


# layer norm ops
def rms_norm(
    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
    """RMS-normalize ``input`` with ``weight`` via ``_C.rms_norm``, targeting ``out``."""
    norm_kernel = torch.ops._C.rms_norm
    norm_kernel(out, input, weight, epsilon)
336
337


338
339
340
def fused_add_rms_norm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
    """Fused residual-add + RMS norm through ``_C.fused_add_rms_norm``.

    Returns None; NOTE(review): ``input``/``residual`` are presumably updated in
    place — confirm against the kernel.
    """
    fused_kernel = torch.ops._C.fused_add_rms_norm
    fused_kernel(input, residual, weight, epsilon)
342
343


344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
def fused_qk_norm_rope(
    qkv: torch.Tensor,
    num_heads_q: int,
    num_heads_k: int,
    num_heads_v: int,
    head_dim: int,
    eps: float,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    position_ids: torch.Tensor,
) -> None:
    """Forward to ``_C.fused_qk_norm_rope`` (Q/K norm fused with RoPE).

    Arguments are passed positionally in declaration order; returns None.
    """
    head_layout = (num_heads_q, num_heads_k, num_heads_v, head_dim)
    norm_params = (eps, q_weight, k_weight)
    rope_params = (cos_sin_cache, is_neox, position_ids)
    torch.ops._C.fused_qk_norm_rope(qkv, *head_layout, *norm_params, *rope_params)


372
def apply_repetition_penalties_torch(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Apply repetition penalties to ``logits`` in place using plain PyTorch.

    Args:
        logits: ``[num_seqs, vocab_size]`` logits; modified in place.
        prompt_mask: Boolean mask of tokens that appear in the prompt.
        output_mask: Boolean mask of tokens that appear in the output.
        repetition_penalties: Per-sequence penalties of shape ``(num_seqs,)``.
    """
    # Broadcast the per-sequence penalty over the vocab axis instead of
    # materializing a [num_seqs, vocab_size] copy with .repeat().
    per_seq_penalty = repetition_penalties.unsqueeze(dim=1)
    # Tokens seen in prompt or output get their row's penalty; others 1.0 (no-op).
    penalties = torch.where(prompt_mask | output_mask, per_seq_penalty, 1.0)
    # Positive logits are divided by the penalty, non-positive ones multiplied.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling


def apply_repetition_penalties_cuda(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Apply repetition penalties via the compiled in-place op
    ``_C.apply_repetition_penalties_`` (trailing underscore: in-place naming)."""
    inplace_kernel = torch.ops._C.apply_repetition_penalties_
    inplace_kernel(logits, prompt_mask, output_mask, repetition_penalties)
397
398


399
400
401
402
403
404
def apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    """Apply repetition penalties to logits in-place.

    Uses the compiled kernel only for contiguous CUDA logits and falls back to
    the pure-PyTorch implementation otherwise.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    kernel_eligible = logits.is_cuda and logits.is_contiguous()
    implementation = (
        apply_repetition_penalties_cuda
        if kernel_eligible
        else apply_repetition_penalties_torch
    )
    implementation(logits, prompt_mask, output_mask, repetition_penalties)
421
422


423
424
425
426
427
428
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    scale_ub: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMS norm fused with dynamic per-token quantization.

    Returns ``(output, scales)``: ``output`` has ``input``'s shape with
    ``quant_dtype``; ``scales`` is float32 of shape ``[num_tokens, 1]`` where
    ``num_tokens = input.numel() // input.shape[-1]``.
    """
    quantized = torch.empty_like(input, dtype=quant_dtype)
    num_tokens = input.numel() // input.shape[-1]
    token_scales = torch.empty(
        (num_tokens, 1), device=input.device, dtype=torch.float32
    )
    torch.ops._C.rms_norm_dynamic_per_token_quant(
        quantized, input, weight, token_scales, epsilon, scale_ub, residual
    )
    return quantized, token_scales


443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
# fused quant layer norm ops blocked
def rms_norm_per_block_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_size: list[int],
    scale_ub: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
    is_scale_transposed: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMS norm fused with per-block quantization.

    ``group_size`` must have two entries; only ``group_size[1]`` (the group
    width along the last dim) is used. ``scales`` is float32 with logical shape
    ``[num_tokens, input.shape[-1] // group_size[1]]``. When
    ``is_scale_transposed`` is set, it is a transposed view of a
    ``[groups, num_tokens]`` allocation (i.e. column-major storage).
    """
    assert len(group_size) == 2
    output = torch.empty_like(input, dtype=quant_dtype)
    num_tokens = input.numel() // input.shape[-1]
    groups_per_row = input.shape[-1] // group_size[1]

    if not is_scale_transposed:
        scales = torch.empty(
            (num_tokens, groups_per_row), device=input.device, dtype=torch.float32
        )
    else:
        storage = torch.empty(
            (groups_per_row, num_tokens), device=input.device, dtype=torch.float32
        )
        scales = storage.transpose(0, 1)

    torch.ops._C.rms_norm_per_block_quant(
        output,
        input,
        weight,
        scales,
        epsilon,
        scale_ub,
        residual,
        group_size[1],
        is_scale_transposed,
    )
    return output, scales


483
484
# quantization ops
# awq
def awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    split_k_iters: int,
    thx: int,
    thy: int,
) -> torch.Tensor:
    """Dequantize AWQ weights, using Triton when ``VLLM_USE_TRITON_AWQ`` is set.

    The Triton path ignores ``split_k_iters``, ``thx`` and ``thy``.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_dequantize(
            qweight, scales, zeros, split_k_iters, thx, thy
        )

    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_dequantize_triton,
    )

    return awq_dequantize_triton(qweight, scales, zeros)
500
501


502
503
504
505
def awq_gemm(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
) -> torch.Tensor:
    """AWQ GEMM; routes to the Triton kernel when ``VLLM_USE_TRITON_AWQ`` is set."""
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)

    from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton

    return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
514
515
516


# gptq
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_exllama: bool,
    use_v2_format: bool,
    bit: int,
) -> torch.Tensor:
    """GPTQ GEMM via ``_C.gptq_gemm``; arguments are forwarded in order."""
    quant_state = (b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx)
    flags = (use_exllama, use_v2_format, bit)
    return torch.ops._C.gptq_gemm(a, *quant_state, *flags)
537
538


539
if hasattr(torch.ops._C, "gptq_gemm"):

    @register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_gptq_qzeros: torch.Tensor,
        b_gptq_scales: torch.Tensor,
        b_g_idx: torch.Tensor,
        use_exllama: bool,
        use_v2_format: bool,
        bit: int,
    ) -> torch.Tensor:
        """Fake (meta) implementation: ``[a.size(0), b_q_weight.size(1)]`` in
        ``a``'s dtype and device."""
        out_shape = (a.size(0), b_q_weight.size(1))
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)
555
556


557
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Forward to ``_C.gptq_shuffle``; returns None (NOTE(review): presumably
    reorders ``q_weight`` in place — confirm)."""
    shuffle_kernel = torch.ops._C.gptq_shuffle
    shuffle_kernel(q_weight, q_perm, bit)
559
560


561
# marlin_24
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Marlin 2:4 GEMM; the ``ScalarType`` is passed to the op by its ``id``."""
    problem_size = (size_m, size_n, size_k)
    return torch.ops._C.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, *problem_size
    )
576
577


578
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
    # Fake (meta) implementations: they only describe output shape/dtype/device.
    # NOTE(review): all of these are registered only when gptq_marlin_24_gemm is
    # present, which assumes the other ops were built alongside it — confirm.

    @register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_meta: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(
        a: torch.Tensor,
        c: torch.Tensor | None,
        b_q_weight: torch.Tensor,
        b_bias: torch.Tensor | None,
        b_scales: torch.Tensor,
        a_scales: torch.Tensor | None,
        global_scale: torch.Tensor | None,
        b_zeros: torch.Tensor | None,
        g_idx: torch.Tensor | None,
        perm: torch.Tensor | None,
        workspace: torch.Tensor,
        b_q_type_id: int,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool = True,
        use_atomic_add: bool = False,
        use_fp32_reduce: bool = False,
        is_zp_float: bool = False,
    ) -> torch.Tensor:
        # Half/bfloat16 activations keep their dtype; anything else reports
        # the dtype of b_scales.
        out_dtype = a.dtype if a.dtype in (torch.half, torch.bfloat16) else b_scales.dtype
        return torch.empty((size_m, size_n), device=a.device, dtype=out_dtype)

    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(
        qweight: torch.Tensor,
        scales: torch.Tensor,
        zeros: torch.Tensor,
        split_k_iters: torch.SymInt,
        thx: int,
        thy: int,
    ) -> torch.Tensor:
        # Output has 8x the packed column count (presumably 8 values per
        # packed element — TODO confirm).
        out_shape = (qweight.size(0), qweight.size(1) * 8)
        return torch.empty(out_shape, dtype=scales.dtype, device=scales.device)

    @register_fake("_C::awq_gemm")
    def _awq_gemm_fake(
        input: torch.Tensor,
        qweight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: torch.Tensor,
        split_k_iters: torch.SymInt,
    ) -> torch.Tensor:
        partials = torch.empty(
            (split_k_iters, input.size(0), qweight.size(1) * 8),
            dtype=input.dtype,
            device=input.device,
        )
        # Reducing the split-K axis yields [input.size(0), qweight.size(1) * 8].
        return partials.sum(0)

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: torch.dtype | None = None,
        b_group_scales: torch.Tensor | None = None,
        b_group_zeros: torch.Tensor | None = None,
        b_group_size: int | None = None,
        b_channel_scales: torch.Tensor | None = None,
        a_token_scales: torch.Tensor | None = None,
        schedule: str | None = None,
    ) -> torch.Tensor:
        out_shape = (a.size(0), b_q.size(1))
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(
        b_q_weight: torch.Tensor,
        a_type: torch.dtype,
        b_type: ScalarType,
        group_scales_type: torch.dtype | None,
    ) -> torch.Tensor:
        return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format)

    @register_fake("_C::cutlass_w4a8_mm")
    def cutlass_w4a8_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
        b_q: torch.Tensor,
        b_group_scales: torch.Tensor,
        b_group_size: int,
        b_channel_scales: torch.Tensor,
        a_token_scales: torch.Tensor,
        out_type: torch.dtype | None = None,
        maybe_schedule: str | None = None,
    ) -> torch.Tensor:
        # bfloat16 is the reported output dtype unless out_type overrides it.
        out_dtype = torch.bfloat16 if out_type is None else out_type
        out_shape = (a.size(0), b_q.size(1))
        return torch.empty(out_shape, device=a.device, dtype=out_dtype)

    @register_fake("_C::cutlass_pack_scale_fp8")
    def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(scales, memory_format=torch.contiguous_format)

    @register_fake("_C::cutlass_encode_and_reorder_int4b")
    def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(b, memory_format=torch.contiguous_format)

    @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
    def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(b, memory_format=torch.contiguous_format)

706

707
708
709
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):

    @register_fake("_C::allspark_w8a16_gemm")
    def _allspark_w8a16_gemm_fake(
        a: torch.Tensor,
        b_qweight: torch.Tensor,
        b_scales: torch.Tensor,
        b_qzeros: torch.Tensor | None,
        n: torch.SymInt,
        group_size: torch.SymInt,
        sm_count: torch.SymInt,
        sm_version: torch.SymInt,
        CUBLAS_M_THRESHOLD: torch.SymInt,
        has_zp: bool,
        n32k16_reorder: bool,
    ) -> torch.Tensor:
        """Fake implementation: ``[a.size(0), n]`` in ``a``'s dtype and device."""
        out_shape = (a.size(0), n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)


727
728
729
if hasattr(torch.ops._C, "ggml_dequantize"):
    # Fake (meta) implementations for the GGUF ops: shape/dtype/device only.

    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(
        W: torch.Tensor,
        quant_type: int,
        m: torch.SymInt,
        n: torch.SymInt,
        dtype: torch.dtype | None = None,
    ) -> torch.Tensor:
        # Honor the requested output dtype instead of always reporting float16.
        # NOTE(review): assumes the compiled kernel uses `dtype` when given and
        # defaults to float16 otherwise — confirm against the GGUF kernel.
        out_dtype = torch.float16 if dtype is None else dtype
        return torch.empty((m, n), dtype=out_dtype, device=W.device)

    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)

    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        batch = X.size(0)
        return torch.empty((batch, row), dtype=X.dtype, device=W.device)

    @register_fake("_C::ggml_moe_a8")
    def _ggml_moe_a8_fake(
        X: torch.Tensor,
        W: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
        top_k: torch.SymInt,
        tokens: torch.SymInt,
    ) -> torch.Tensor:
        # The row count is derived from X rather than the `tokens` argument.
        num_tokens = X.size(0)
        return torch.empty(
            (num_tokens * top_k, row), dtype=torch.float16, device=W.device
        )
772

773

774
775
776
777
778
779
780
781
782
783
784
785
786
if hasattr(torch.ops._C, "ggml_moe_a8_vec"):

    @register_fake("_C::ggml_moe_a8_vec")
    def _ggml_moe_a8_vec_fake(
        X: torch.Tensor,
        W: torch.Tensor,
        topk_ids: torch.Tensor,
        top_k: int,
        quant_type: int,
        row: torch.SymInt,
        tokens: torch.SymInt,
    ) -> torch.Tensor:
        """Fake implementation: ``[X.size(0) * top_k, row]`` in ``X``'s dtype.

        The ``tokens`` argument is ignored; the token count comes from ``X``.
        """
        out_rows = X.size(0) * top_k
        return torch.empty((out_rows, row), dtype=X.dtype, device=W.device)
788
789


790
# cutlass
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether CUTLASS FP4 scaled-mm supports the
    given CUDA device capability."""
    capability_probe = torch.ops._C.cutlass_scaled_mm_supports_fp4
    return capability_probe(cuda_device_capability)


795
796
797
798
799
800
801
802
def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Block-scaled FP4 GEMM through CUTLASS.

    Both operands must be 2-D. The result has shape ``[a.shape[0], b.shape[0]]``
    (``b`` is indexed by output column along its first dim) and ``out_dtype``.
    """
    assert a.ndim == 2 and b.ndim == 2
    result = torch.empty(
        (a.shape[0], b.shape[0]), dtype=out_dtype, device=a.device
    )
    scales = (block_scale_a, block_scale_b, alpha)
    torch.ops._C.cutlass_scaled_fp4_mm(result, a, b, *scales)
    return result


810
811
812
813
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether CUTLASS FP8 scaled-mm supports the
    given CUDA device capability."""
    capability_probe = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return capability_probe(cuda_device_capability)


814
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether block-scaled FP8 CUTLASS scaled-mm
    supports the given CUDA device capability."""
    capability_probe = torch.ops._C.cutlass_scaled_mm_supports_block_fp8
    return capability_probe(cuda_device_capability)
816
817


818
819
820
821
822
823
def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and
        the corresponding extent in the target shape, we repeat each element
        along that dimension target_shape[dim] // src_shape[dim] times
        consecutively"
    example if we have:
          a = [[1, 2], and target_shape = (2, 4)
               [3, 4]]
    then we would expand a to (each element repeated 4 // 2 = 2 times):
          a = [[1, 1, 2, 2],
               [3, 3, 4, 4]]
    currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape

    `a` may have extra leading dims; it is flattened to 2-D for the GEMM and the
    result is reshaped to `(*a.shape[:-1], b.shape[1])`. ROCm, or a `b` whose
    dims are not both multiples of 16, uses the Triton fallback; otherwise the
    CUTLASS kernel runs.
    """
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.numel() == b.shape[1] and bias.dtype == out_dtype)

    # Massage the input to be 2D
    target_shape = (*a.shape[:-1], b.shape[1])
    a = a.view(-1, a.shape[-1])

    # The CUTLASS path requires both dims of b to be multiples of 16.
    cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    if current_platform.is_rocm() or not cutlass_compatible_b:
        from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
            triton_scaled_mm,
        )

        out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    else:
        out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out.view(*target_shape)
867
868


869
870
871
872
873
874
875
def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    :param azp_adj: In the per-tensor case, this should include the azp.
    Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.

    `a` may carry leading batch dims; it is flattened to 2-D for the kernel and
    the result is reshaped to `(*a.shape[:-1], b.shape[1])`.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype

    final_shape = (*a.shape[:-1], b.shape[1])
    a_2d = a.view(-1, a.shape[-1])
    num_rows = a_2d.shape[0]
    # azp, when provided, must have one entry per (flattened) row of a.
    assert azp is None or azp.numel() == num_rows

    result = torch.empty((num_rows, b.shape[1]), dtype=out_dtype, device=a_2d.device)
    torch.ops._C.cutlass_scaled_mm_azp(
        result, a_2d, b, scale_a, scale_b, azp_adj, azp, bias
    )
    return result.view(*final_shape)
896
897


898
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether sparse CUTLASS scaled-mm supports the
    given CUDA device capability."""
    capability_probe = torch.ops._C.cutlass_sparse_scaled_mm_supported
    return capability_probe(cuda_device_capability)
900
901


902
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
    """Return whether CUTLASS grouped GEMM supports the device capability.

    Any ``AttributeError`` from the lookup or call (e.g. the op is absent on
    non-CUDA builds) is reported as ``False``.
    """
    supported = False
    try:
        supported = torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)
    except AttributeError:
        pass
    return supported
908

909

910
def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compress a 2:4-sparse matrix for use with CUTLASS sparse kernels.

    Args:
        a (torch.Tensor):
            The contiguous input tensor to compress. Its dtype must be one of
            `torch.int8`, `torch.float8_e4m3fn`, `torch.bfloat16` or
            `torch.float16`, and `a.shape[1]` must be a multiple of 8.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            `(a_nzs, a_meta)` as produced by `_C.cutlass_sparse_compress`:
            the retained non-zero elements and the sparsity metadata.

    Notes:
        - Metadata is byte-sized with 2 bits per non-zero, so each metadata
          element covers 4 non-zeros (`elemsPerMetaElem = 4`).
        - Expected shapes (per the kernel's contract; not checked here):
          `a_nzs` is `(m, k // 2)` and `a_meta` is `(m, k // 2 // 4)`.
    """
    accepted_dtypes = (torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16)
    assert a.dtype in accepted_dtypes
    assert a.is_contiguous()

    # 8-bit metadata / 2 bits per non-zero = 4 non-zeros per metadata element.
    nz_per_meta_elem = 4
    assert a.shape[1] % (2 * nz_per_meta_elem) == 0

    return torch.ops._C.cutlass_sparse_compress(a)
949
950
951


def cutlass_scaled_sparse_mm(
952
953
954
955
956
957
    a: torch.Tensor,
    bt_nzs: torch.Tensor,
    bt_meta: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
958
    bias: torch.Tensor | None = None,
959
) -> torch.Tensor:
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
    """
    Performs a scaled sparse matrix multiplication using Cutlass.

    Steps:
    1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
    `a = torch.randn((m, k), device='cuda')`.

    2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
    `b = torch.randn((k, n), device='cuda')`.

    3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
    `b = prune_to_2_4(b, dim=0)`.

    4. Compress the transposed sparse matrix `b.t()`:
    `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.

    5. Perform sparse matrix multiplication using the compressed matrix,
    applying scaling factors for `a` and `b`, and the output data type:
    `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.

    Returns:
    - The result of the scaled sparse matrix multiplication.
    """
983
984
985
    assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype
986
987
988
989
990

    m = a.shape[0]
    n = bt_nzs.shape[0]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

991
992
993
    torch.ops._C.cutlass_scaled_sparse_mm(
        out, a, bt_nzs, bt_meta, scale_a, scale_b, bias
    )
994
995
996
997

    return out


998
999
1000
1001
1002
1003
1004
1005
1006
1007
def get_cutlass_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
1008
    blockscale_offsets: torch.Tensor | None = None,
1009
):
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                          after executing the MMs.
1027
1028
1029
1030
1031
    - blockscale_offsets: Optional argument passed for fp4 moe. Indices that
                          mark at which block scale index each expert begins
                          its computation. The number of block scale rows
                          computed with expert E is blockscale_offsets[E + 1] -
                          blockscale_offsets[E]
1032
    """
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
    return torch.ops._C.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
        blockscale_offsets,
    )
1045
1046


1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
def get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    expert_first_token_offset: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    n: int,
    k: int,
    swap_ab: bool,
):
    """
    Derive the per-expert (M, N, K) sizes of both grouped MMs from
    `expert_first_token_offset`.

    The two preallocated problem-size tensors are handed to the
    `_C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets` op together with
    `n`, `k` and `swap_ab`; whatever the op returns is passed straight back.
    """
    kernel = torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets
    return kernel(
        expert_first_token_offset,
        problem_sizes1,
        problem_sizes2,
        n,
        k,
        swap_ab,
    )


1066
1067
1068
1069
1070
1071
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor) -> torch.Tensor:
    """
    Shuffle and expand the rows of `input_tensor` according to `dst2src_map`.

    This is used in MoE to permute the input tensor before performing grouped
    matrix multiplications.

    Args:
        input_tensor: 2D tensor whose rows are gathered.
        dst2src_map: Row map; its length sets the number of output rows, which
            may exceed the number of input rows (rows can be repeated).

    Returns:
        A new tensor of shape (dst2src_map.shape[0], input_tensor.shape[1]) with
        the dtype and device of `input_tensor`, filled by the
        `_moe_C.shuffle_rows` op.
    """
    num_tokens_permuted = dst2src_map.shape[0]
    output_tensor = torch.empty(
        (num_tokens_permuted, input_tensor.shape[1]),
        device=input_tensor.device,
        dtype=input_tensor.dtype,
    )
    torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
    return output_tensor
1079
1080


1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
def get_cutlass_pplx_moe_mm_data(
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    expert_num_tokens: torch.Tensor,
    num_local_experts: int,
    padded_m: int,
    n: int,
    k: int,
):
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in expert_num_tokens (token count per expert) together
    with num_local_experts, padded_m, n and k, and uses them to compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation.
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.

    The output tensors are passed in preallocated. `padded_m` is forwarded to
    the kernel as-is (presumably the padded per-expert token capacity —
    confirm against the kernel). The kernel's return value is returned as-is.
    """
    return torch.ops._C.get_cutlass_pplx_moe_mm_data(
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        expert_num_tokens,
        num_local_experts,
        padded_m,
        n,
        k,
    )
1114
1115


1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
def cutlass_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    c_strides: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
):
    """
    A single grouped matrix multiplication used in CUTLASS-based fused MoE.
    The function executes fp8-quantized OUT = AB matrix multiplication.

    - out_tensors: Output buffer handed to the kernel (OUT).
    - a_tensors, b_tensors: The quantized A and B operands.
    - a_scales, b_scales: Scales for A and B.
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
                     MMs used in the fused MoE operation.
    - a/b/c_strides: The data strides passed to grouped matrix multiplication.
    - per_act_token, per_out_ch: Flags forwarded to the kernel; presumably they
                     select per-token A scales and per-output-channel B scales
                     instead of per-tensor ones (confirm against the kernel).

    The kernel's return value is returned as-is.
    """
    return torch.ops._C.cutlass_moe_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        per_act_token,
        per_out_ch,
    )
1155
1156


1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
def cutlass_fp4_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    alphas: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    sf_offsets: torch.Tensor,
):
    """
    An FP4 block-scaled grouped GEMM that takes in a_tensors and b_tensors and
    runs the GEMMs for each combination based on the specified problem sizes.

    This is used as the MoE GEMM during the NVFP4-quantized FusedMoE forward.
    - out_tensors: Output buffer handed to the kernel.
    - a/b_tensors: The NVFP4 a_ptrs and b_ptrs tensors, which are the quantized
                   input and expert weights.
    - a_/b_scales: The blockscales in FP8-E4M3 precision.
    - alphas: Scaling factors forwarded to the kernel (presumably one per
              expert — confirm against the kernel).
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
                     MMs used in the fused MoE operation.
    - expert_offsets/sf_offsets: Indices that mark at which token index each
                    expert begins its computation. The number of tokens
                    computed with expert E is expert_offsets[E + 1] -
                    expert_offsets[E], and the scale-factor size per expert is
                    sf_offsets[E + 1] - sf_offsets[E].

    Dispatches to the `_C.cutlass_fp4_group_mm` op and returns its result.
    """
    return torch.ops._C.cutlass_fp4_group_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        alphas,
        problem_sizes,
        expert_offsets,
        sf_offsets,
    )
1195
1196


1197
# gptq_marlin
1198
1199
1200
1201
1202
1203
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """
    Repack a GPTQ quantized weight via the `_C.gptq_marlin_repack` op.

    Args:
        b_q_weight: Packed quantized weight.
        perm: Permutation forwarded to the op (presumably the act-order
            permutation — confirm against the kernel).
        size_k, size_n: Logical K and N of the weight.
        num_bits: Quantization bit width.
        is_a_8bit: Forwarded to the op unchanged.

    Returns:
        The repacked weight. The fake implementation registered for this op
        reports shape (size_k // 16, size_n * 16 // (32 // num_bits)) with the
        dtype and device of `b_q_weight`.
    """
    return torch.ops._C.gptq_marlin_repack(
        b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
    )
1209
1210


1211
1212
1213
1214
1215
1216
1217
1218
1219
if hasattr(torch.ops._C, "gptq_marlin_repack"):

    @register_fake("_C::gptq_marlin_repack")
    def _gptq_marlin_repack_fake(
        b_q_weight: torch.Tensor,
        perm: torch.Tensor,
        size_k: torch.SymInt,
        size_n: torch.SymInt,
        num_bits: int,
        is_a_8bit: bool = False,
    ) -> torch.Tensor:
        """Fake (meta) implementation: returns an uninitialized tensor with the
        output shape, dtype and device of `_C::gptq_marlin_repack`."""
        # Each 32-bit word holds `pack_factor` quantized values. K is divided
        # by the tile size while N grows by tile_size / pack_factor.
        pack_factor = 32 // num_bits
        marlin_tile_size = 16
        return torch.empty(
            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
            dtype=b_q_weight.dtype,
            device=b_q_weight.device,
        )


# awq_marlin
1232
def awq_marlin_repack(
    b_q_weight: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """
    Repack an AWQ quantized weight via the `_C.awq_marlin_repack` op.

    Args:
        b_q_weight: Packed quantized weight.
        size_k, size_n: Logical K and N of the weight.
        num_bits: Quantization bit width.
        is_a_8bit: Forwarded to the op unchanged.

    Returns:
        The repacked weight. The fake implementation registered for this op
        reports shape (size_k // 16, size_n * 16 // (32 // num_bits)) with the
        dtype and device of `b_q_weight`.
    """
    return torch.ops._C.awq_marlin_repack(
        b_q_weight, size_k, size_n, num_bits, is_a_8bit
    )
1242
1243


1244
1245
1246
1247
1248
1249
1250
1251
if hasattr(torch.ops._C, "awq_marlin_repack"):

    @register_fake("_C::awq_marlin_repack")
    def _awq_marlin_repack_fake(
        b_q_weight: torch.Tensor,
        size_k: torch.SymInt,
        size_n: torch.SymInt,
        num_bits: int,
        is_a_8bit: bool = False,
    ) -> torch.Tensor:
        """Fake (meta) implementation: returns an uninitialized tensor with the
        output shape, dtype and device of `_C::awq_marlin_repack`."""
        # Each 32-bit word holds `pack_factor` quantized values. K is divided
        # by the tile size while N grows by tile_size / pack_factor.
        pack_factor = 32 // num_bits
        marlin_tile_size = 16
        return torch.empty(
            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
            dtype=b_q_weight.dtype,
            device=b_q_weight.device,
        )


1263
1264
1265
1266
1267
1268
def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """
    Repack stacked per-expert GPTQ weights, one expert at a time.

    Args:
        b_q_weight: Stacked quantized weights; dim 0 indexes experts.
        perm: Per-expert permutations; `perm[e]` is used for expert `e`.
        size_k, size_n: Per-expert K and N; `size_k` must be a multiple of 16.
        num_bits: Quantization bit width.
        is_a_8bit: Forwarded to `_C.gptq_marlin_repack` unchanged.

    Returns:
        Tensor of shape (num_experts, size_k // 16, size_n * (num_bits // 2))
        with the dtype and device of `b_q_weight`.

    Raises:
        AssertionError: If `size_k` is not a multiple of 16.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(
            b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
        )
    return output


1285
1286
1287
1288
1289
1290
def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """
    Repack stacked per-expert AWQ weights, one expert at a time.

    Args:
        b_q_weight: Stacked quantized weights; dim 0 indexes experts.
        perm: Accepted for signature parity with `gptq_marlin_moe_repack`;
            it is not used here.
        size_k, size_n: Per-expert K and N; `size_k` must be a multiple of 16.
        num_bits: Quantization bit width.
        is_a_8bit: Forwarded to `_C.awq_marlin_repack` unchanged.

    Returns:
        Tensor of shape (num_experts, size_k // 16, size_n * (num_bits // 2))
        with the dtype and device of `b_q_weight`.

    Raises:
        AssertionError: If `size_k` is not a multiple of 16.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for e in range(num_experts):
        output[e] = torch.ops._C.awq_marlin_repack(
            b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
        )
    return output


1307
1308
1309
1310
1311
1312
1313
1314
def marlin_int4_fp8_preprocess(
    qweight: torch.Tensor,
    qzeros_or_none: torch.Tensor | None = None,
    inplace: bool = False,
):
    return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)


1315
1316
def gptq_marlin_gemm(
    a: torch.Tensor,
    c: torch.Tensor | None,
    b_q_weight: torch.Tensor,
    b_bias: torch.Tensor | None,
    b_scales: torch.Tensor,
    a_scales: torch.Tensor | None,
    global_scale: torch.Tensor | None,
    b_zeros: torch.Tensor | None,
    g_idx: torch.Tensor | None,
    perm: torch.Tensor | None,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool = True,
    use_atomic_add: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    """
    Run the Marlin GEMM on quantized weights via the `_C.gptq_marlin_gemm` op.

    Every argument is forwarded to the op in declaration order, except
    `b_q_type`, which is passed as its integer id (`b_q_type.id`). The optional
    tensors (`c`, `b_bias`, `a_scales`, `global_scale`, `b_zeros`, `g_idx`,
    `perm`) may be None. Their precise semantics are defined by the kernel.
    From the names, `c` is presumably an optional output buffer and `g_idx` /
    `perm` are act-order metadata; confirm against the kernel.

    Returns:
        The tensor returned by the op.
    """
    return torch.ops._C.gptq_marlin_gemm(
        a,
        c,
        b_q_weight,
        b_bias,
        b_scales,
        a_scales,
        global_scale,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )
1357
1358


1359
# machete
1360
def machete_supported_schedules(
    a_type: torch.dtype,
    b_type: ScalarType,
    group_scales_type: torch.dtype | None,
    group_zeros_type: torch.dtype | None = None,
    channel_scales_type: torch.dtype | None = None,
    token_scales_type: torch.dtype | None = None,
    out_type: torch.dtype | None = None,
) -> list[str]:
    """
    Query the `_C.machete_supported_schedules` op for schedule names.

    All arguments are forwarded to the op, with `b_type` converted to its
    integer id. The returned schedule names are intended for the `schedule`
    argument of `machete_mm` (inferred from naming; confirm).
    """
    return torch.ops._C.machete_supported_schedules(
        a_type,
        b_type.id,
        group_scales_type,
        group_zeros_type,
        channel_scales_type,
        token_scales_type,
        out_type,
    )
1378
1379
1380


def machete_mm(
    a: torch.Tensor,
    # b_q Should be the tensor returned by machete_prepack_B
    b_q: torch.Tensor,
    b_type: ScalarType,
    out_type: torch.dtype | None = None,
    b_group_scales: torch.Tensor | None = None,
    b_group_zeros: torch.Tensor | None = None,
    b_group_size: int | None = None,
    b_channel_scales: torch.Tensor | None = None,
    a_token_scales: torch.Tensor | None = None,
    schedule: str | None = None,
) -> torch.Tensor:
    """
    Run the Machete mixed-input GEMM via the `_C.machete_mm` op.

    `b_q` should be the tensor returned by `machete_prepack_B`. `b_type` is
    passed to the op as its integer id; every other argument, including the
    optional scales/zeros and `schedule`, is forwarded unchanged.

    Returns:
        The tensor returned by the op.
    """
    return torch.ops._C.machete_mm(
        a,
        b_q,
        b_type.id,
        out_type,
        b_group_scales,
        b_group_zeros,
        b_group_size,
        b_channel_scales,
        a_token_scales,
        schedule,
    )
1405
1406
1407


def machete_prepack_B(
    b_q_weight: torch.Tensor,
    a_type: torch.dtype,
    b_type: ScalarType,
    group_scales_type: torch.dtype | None,
) -> torch.Tensor:
    """
    Prepack a quantized weight for use as `b_q` in `machete_mm`.

    Forwards to the `_C.machete_prepack_B` op, passing `b_type` as its integer
    id.
    """
    return torch.ops._C.machete_prepack_B(
        b_q_weight, a_type, b_type.id, group_scales_type
    )
1416
1417


1418
1419
# CUTLASS W4A8
def cutlass_w4a8_mm(
1420
1421
1422
1423
1424
1425
1426
    a: torch.Tensor,
    # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
    b_q: torch.Tensor,
    b_group_scales: torch.Tensor,
    b_group_size: int,
    b_channel_scales: torch.Tensor,
    a_token_scales: torch.Tensor,
1427
1428
    out_type: torch.dtype | None = None,
    maybe_schedule: str | None = None,
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
) -> torch.Tensor:
    return torch.ops._C.cutlass_w4a8_mm(
        a,
        b_q,
        b_group_scales,
        b_group_size,
        b_channel_scales,
        a_token_scales,
        out_type,
        maybe_schedule,
    )
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449


def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
    """Pack `scales` with the `_C.cutlass_pack_scale_fp8` op and return the result."""
    pack = torch.ops._C.cutlass_pack_scale_fp8
    return pack(scales)


def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
    """Encode and reorder an int4 weight via the matching `_C` op.

    The result is meant to be passed as `b_q` to `cutlass_w4a8_mm`.
    """
    encode = torch.ops._C.cutlass_encode_and_reorder_int4b
    return encode(b)


1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
def cutlass_w4a8_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    b_group_scales: torch.Tensor,
    b_group_size: int,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    c_strides: torch.Tensor,
    group_scale_strides: torch.Tensor,
    maybe_schedule: str | None = None,
):
    """
    Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
    W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
    and both per-channel + per-token scaling in the epilogue.

    Args:
        out_tensors:
            Output buffer for all experts (updated in-place).
        a_tensors:
            FP8 (E4M3FN) activations for all experts.
        b_tensors:
            INT4-packed weight matrix for all experts, packed to INT32
        a_scales:
            Per-token FP8 activation scales, applied in the epilogue.
        b_scales:
            Per-channel FP8 weight scales for each expert, applied in the epilogue.
        b_group_scales:
            FP8 scale values for group-wise INT4 weight blocks.
        b_group_size:
            Number of elements grouped under each entry of b_group_scales.
        expert_offsets:
            Cumulative token offsets
        problem_sizes:
            Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
        a/b/c/group_scale_strides:
            Strides describing the memory layout of the input tensors.
        maybe_schedule:
            Optional override to choose a specific kernel or epilogue schedule.

    Returns:
        out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
    """
    return torch.ops._C.cutlass_w4a8_moe_mm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        b_group_scales,
        b_group_size,
        expert_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        c_strides,
        group_scale_strides,
        maybe_schedule,
    )


def cutlass_encode_and_reorder_int4b_grouped(
    b_tensors: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped (per-expert) variant of int4 encode-and-reorder.

    Delegates to `_C.cutlass_encode_and_reorder_int4b_grouped` and returns the
    pair of tensors it produces.
    """
    encode_grouped = torch.ops._C.cutlass_encode_and_reorder_int4b_grouped
    return encode_grouped(b_tensors)


1522
if hasattr(torch.ops._C, "permute_cols"):

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
        """Fake (meta) implementation of `_C::permute_cols`.

        Permuting columns leaves shape, dtype and device unchanged, so an
        uninitialized tensor like `a` is returned.
        """
        return torch.empty_like(a)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Reorder the columns of `a` according to `perm` using `_C.permute_cols`."""
    op = torch.ops._C.permute_cols
    return op(a, perm)


1533
1534
# fp4
def scaled_fp4_quant(
    input: torch.Tensor,
    input_global_scale: torch.Tensor,
    backend: str = "none",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    All leading dimensions of `input` are flattened, so both outputs are 2D.

    Args:
        input: The input tensor to be quantized to FP4. Must be fp16 or bf16,
            and its last dimension must be a multiple of 16.
        input_global_scale: A scalar scaling factor for the entire tensor.
        backend: Name of the consuming backend. If it contains "trtllm" and
            the flattened input has at most 32 rows, the scales use the 8x4
            layout (via FlashInfer); otherwise they use the 128x4 layout.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4, with every
            two values packed into a uint8, and the float8_e4m3 scaling factors
            in the swizzled layout.
    """
    assert not current_platform.is_rocm()
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
    )

    # The 8x4 scale-factor layout is only used for small batches on TRT-LLM.
    use_8x4_sf_layout = "trtllm" in backend and m <= 32

    if use_8x4_sf_layout:
        output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout(
            input, input_global_scale
        )
    else:
        # Two fp4 values will be packed into an uint8.
        output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

        # We use the rounded values to store the swizzled values. Due to the
        # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
        # So, we first pad the scales to multiples of 128 and 4. Then, the scales
        # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
        # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
        rounded_m = cdiv(m, 128) * 128
        scale_n = n // block_size
        rounded_n = cdiv(scale_n, 4) * 4
        output_scale = torch.empty(
            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
        )

        torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)

    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale


1600
1601
1602
1603
1604
1605
1606
1607
def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to NVFP4 and return quantized tensor and scale, for
    packed MoE inputs.

    Args:
        input_tensor: The 2D input tensor of shape (m_numtopk, k) to be
            quantized to NVFP4.
        input_global_scale: Global scaling factor(s) forwarded to the kernel
            (likely one per expert — confirm).
        expert_offsets: The expert offsets tensor.
        blockscale_offsets: The blockscale offsets tensor.
        topk: Number of experts selected per token. Together with
            VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE it bounds m_numtopk and sizes
            the scale buffer.

    Returns:
        output: The quantized tensor in NVFP4, shape (m_numtopk, k // 2), with
            two values packed per uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with
            MAX_TOKENS_PER_EXPERT * topk rows.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    )

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k = input_tensor.shape

    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )
    # One FP8 scale per 16 values; four scales are packed into each int32.
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops._C.scaled_fp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales


1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
def silu_and_mul_scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply SiLU(gate) * up and quantize the result to NVFP4 in a single fused
    op, for MoE intermediate activations.

    Args:
        input_tensor: Activations laid out as gate || up, shape [m_topk, k*2].
        input_global_scale: Per-expert scaling factors, shape [n_experts].
        expert_offsets: Expert offsets, shape [n_experts+1].
        blockscale_offsets: Blockscale offsets, shape [n_experts+1].
        topk: Number of experts selected per token.

    Returns:
        A pair (packed, scales): the NVFP4 values with two packed per uint8,
        shape [m_topk, k/2], and the FP8-E4M3 blockscales.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    )

    # Per-expert token cap (VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE). It bounds
    # kernel memory use and sizes the scale buffer; raise it for larger models.
    max_tokens = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    num_rows, gate_up_width = input_tensor.shape
    assert gate_up_width % 2 == 0, "input width must be even (gate || up layout)"
    hidden = gate_up_width // 2

    assert num_rows <= max_tokens * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{max_tokens})"
        f" for cutlass_moe_fp4, observed m_numtopk = {num_rows}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )

    # One FP8 scale per 16 outputs, four scales packed into each int32 word.
    scale_words = (hidden // 16 + 3) // 4
    device = input_tensor.device

    packed = torch.empty(num_rows, hidden // 2, device=device, dtype=torch.uint8)
    scales = torch.empty(
        max_tokens * topk, scale_words, dtype=torch.int32, device=device
    )
    torch.ops._C.silu_and_mul_scaled_fp4_experts_quant(
        packed,
        scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    return packed, scales.view(torch.float8_e4m3fn)


1727
# fp8
1728
1729
def scaled_fp8_quant(
    input: torch.Tensor,
1730
1731
1732
    scale: torch.Tensor | None = None,
    num_token_padding: int | None = None,
    scale_ub: torch.Tensor | None = None,
1733
    use_per_token_if_dynamic: bool = False,
1734
    output: torch.Tensor | None = None,
1735
    group_shape: tuple[int, int] | None = None,
1736
) -> tuple[torch.Tensor, torch.Tensor]:
1737
1738
1739
1740
1741
1742
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
1743
    optional padding of the output tensors for downstream kernels that
1744
1745
1746
    will benefit from padding.

    Args:
1747
1748
1749
1750
1751
1752
1753
        input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
        scale: Optional scaling factor for the FP8 quantization. Supports:
            - 0D or [1]: per-tensor scaling
            - 1D: requires explicit group_shape to disambiguate per-channel
              vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
            - 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
              DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
1754
        scale_ub: Optional upper bound for scaling factor in dynamic
1755
            per token case
1756
        num_token_padding: If specified, pad the first dimension
1757
            of the output to at least this value.
1758
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
1759
            in the dynamic quantization case.
1760
1761
1762
1763
        group_shape: Optional tuple (group_m, group_n) specifying the group
            shape for static quantization. Use -1 for "full extent" (e.g.,
            (-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
            Required for 1D scales; optional for 2D scales.
1764
1765

    Returns:
1766
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
1767
1768
            scaling factor.
    """
1769
    # This code assumes batch_dim and num_tokens are flattened
1770
    assert input.ndim == 2
1771
    shape: tuple[int, int] | torch.Size = input.shape
1772
1773
    # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
    out_dtype: torch.dtype = current_platform.fp8_dtype()
1774
1775
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
1776
1777
1778
    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)
    else:
1779
        assert num_token_padding is None, "padding not supported if output passed in"
1780
        assert output.dtype == out_dtype
1781

1782
    if scale is None:
1783
        if use_per_token_if_dynamic:
1784
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
1785
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
1786
1787
                output, input, scale, scale_ub
            )
1788
        else:
1789
            scale = torch.empty(1, device=input.device, dtype=torch.float32)
1790
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
1791
    else:
1792
        torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
1793

1794
    return output, scale
1795
1796


1797
1798
# gptq allspark
def allspark_repack_weight(
1799
1800
    qweight: torch.Tensor,
    scale: torch.Tensor,
1801
    zero_point: torch.Tensor | None = None,
1802
    has_zp: bool = False,
1803
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1804
    """
1805
    Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
1806
1807
1808
1809
1810
1811
1812
1813
    for Ampere W8A16 Fused Gemm kernel

    Args:
        qweight: uint8 weight tensor, original k x n format.
        scale: fp16/bf16 weight scale tensor, 1 x n format.
        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
            Must be provided for asymmetric quantization.
        has_zp: if use symmetric quantization, has_zp = False.
1814
1815
            if use asymmetric quantization, has_zp = True.

1816
    Returns:
1817
        tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] :
1818
1819
1820
1821
1822
1823
            rearranged weight, scale, and optionally zero_point.
    """
    K = qweight.shape[0]
    N = qweight.shape[1]
    N_32align = (N + 32 - 1) // 32 * 32

1824
1825
1826
1827
    qweight_reorder = torch.empty(
        (N_32align, K), device=qweight.device, dtype=qweight.dtype
    )
    scale_reorder = torch.empty((1, N_32align), device=scale.device, dtype=scale.dtype)
1828
1829
1830
    zero_point_reorder = None
    if has_zp:
        assert zero_point is not None, (
1831
1832
1833
1834
1835
            "zero_point must be provided for asymmetric quantization."
        )
        zero_point_reorder = torch.empty(
            (1, N_32align), device=zero_point.device, dtype=zero_point.dtype
        )
1836
1837

    torch.ops._C.rearrange_kn_weight_as_n32k16_order(
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
        qweight,
        scale,
        zero_point,
        has_zp,
        qweight_reorder,
        scale_reorder,
        zero_point_reorder,
        K,
        N,
        N_32align,
    )
1849
1850
1851
1852

    return qweight_reorder, scale_reorder, zero_point_reorder


1853
1854
1855
1856
def allspark_w8a16_gemm(
    a: torch.Tensor,
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
1857
    b_qzeros: torch.Tensor | None,
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
    n: int,
    group_size: int,
    sm_count: int,
    sm_version: int,
    CUBLAS_M_THRESHOLD: int,
    has_zp: bool,
    n32k16_reorder: bool,
) -> torch.Tensor:
    return torch.ops._C.allspark_w8a16_gemm(
        a,
        b_qweight,
        b_scales,
        b_qzeros,
        n,
        group_size,
        sm_count,
        sm_version,
        CUBLAS_M_THRESHOLD,
        has_zp,
        n32k16_reorder,
    )
1879
1880


1881
# int8
1882
def scaled_int8_quant(
1883
    input: torch.Tensor,
1884
1885
    scale: torch.Tensor | None = None,
    azp: torch.Tensor | None = None,
1886
    symmetric: bool = True,
1887
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
1888
    """
1889
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
1890
1891
1892

    Args:
        input: The input tensor to be quantized to int8.
1893
1894
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
1895
1896
1897
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
1898
1899

    Returns:
1900
      tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
1901
    """
1902
1903
1904
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
1905
1906
1907
        assert symmetric == (azp is None), (
            "azp must only be provided for asymmetric quantization."
        )
1908
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
1909
        return output, scale, azp
1910
1911

    # dynamic-per-token quantization.
1912
1913
1914
1915
1916
1917
1918
    input_scales = torch.empty(
        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
    )
    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(
        output, input.contiguous(), input_scales, input_azp
    )
1919
    return output, input_scales, input_azp
1920
1921


1922
# gguf
1923
def ggml_dequantize(
1924
    W: torch.Tensor, quant_type: int, m: int, n: int, dtype: torch.dtype | None
1925
) -> torch.Tensor:
1926
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
1927
1928
1929
1930
1931
1932
1933


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Matrix-vector product of GGUF weights ``W`` and activations ``X``.

    Delegates to the ``_C.ggml_mul_mat_vec_a8`` custom op.
    """
    matvec_op = torch.ops._C.ggml_mul_mat_vec_a8
    return matvec_op(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Matrix-matrix product of GGUF weights ``W`` and activations ``X``.

    Delegates to the ``_C.ggml_mul_mat_a8`` custom op.
    """
    matmul_op = torch.ops._C.ggml_mul_mat_a8
    return matmul_op(W, X, quant_type, row)


1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
def ggml_moe_a8(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    quant_type: int,
    row: int,
    top_k: int,
    tokens: int,
) -> torch.Tensor:
    """GGUF mixture-of-experts GEMM via the ``_C.ggml_moe_a8`` custom op.

    The routing tensors and sizes are forwarded positionally, in the same
    order as this function's parameters.
    """
    routing = (sorted_token_ids, expert_ids, num_tokens_post_padded)
    sizes = (quant_type, row, top_k, tokens)
    return torch.ops._C.ggml_moe_a8(X, W, *routing, *sizes)
1969
1970


1971
1972
1973
1974
1975
1976
1977
1978
1979
def ggml_moe_a8_vec(
    X: torch.Tensor,
    W: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    quant_type: int,
    row: torch.SymInt,
    tokens: torch.SymInt,
) -> torch.Tensor:
    """Vector variant of the GGUF MoE kernel (``_C.ggml_moe_a8_vec``).

    All arguments are forwarded unchanged and in order.
    """
    moe_vec_op = torch.ops._C.ggml_moe_a8_vec
    return moe_vec_op(X, W, topk_ids, top_k, quant_type, row, tokens)
1981
1982


1983
1984
1985
1986
def ggml_moe_get_block_size(quant_type: int) -> int:
    """Ask the ``_C`` extension for the MoE block size used for ``quant_type``."""
    query = torch.ops._C.ggml_moe_get_block_size
    return query(quant_type)


1987
# mamba
1988
1989
1990
1991
1992
1993
def selective_scan_fwd(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
1994
1995
1996
    D_: torch.Tensor | None,
    z_: torch.Tensor | None,
    delta_bias_: torch.Tensor | None,
1997
    delta_softplus: bool,
1998
1999
2000
    query_start_loc: torch.Tensor | None,
    cache_indices: torch.Tensor | None,
    has_initial_state: torch.Tensor | None,
2001
2002
    ssm_states: torch.Tensor,
    pad_slot_id: int,
2003
2004
2005
2006
    block_size: int = 1024,
    block_idx_first_scheduled_token: torch.Tensor | None = None,
    block_idx_last_scheduled_token: torch.Tensor | None = None,
    initial_state_idx: torch.Tensor | None = None,
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
):
    torch.ops._C.selective_scan_fwd(
        u,
        delta,
        A,
        B,
        C,
        D_,
        z_,
        delta_bias_,
        delta_softplus,
        query_start_loc,
        cache_indices,
        has_initial_state,
        ssm_states,
        pad_slot_id,
2023
2024
2025
2026
        block_size,
        block_idx_first_scheduled_token,
        block_idx_last_scheduled_token,
        initial_state_idx,
2027
    )
2028
2029


2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# cannot properly handle non-contiguous tensors.
# TODO(rasmith): augment these kernels to handle non-contiguous tensors
# directly, which would avoid the copy below and improve performance.
def rocm_enforce_contiguous_skinny_gemm_inputs(
    a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(a, b)`` with both operands made contiguous.

    ``Tensor.contiguous()`` hands back the same tensor object when it is
    already contiguous and a contiguous copy otherwise, so this is free for
    inputs that are already laid out correctly.
    """
    return tuple(operand.contiguous() for operand in (a, b))


2042
# ROCm skinny gemms
2043
def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
    """ROCm skinny GEMM ``_rocm_C.LLMM1`` on contiguous copies of ``a`` and ``b``."""
    contiguous_a, contiguous_b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.LLMM1(contiguous_a, contiguous_b, rows_per_block)


2048
2049
2050
def wvSplitK(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
2051
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
2052
2053
2054
    return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)


2055
2056
2057
def wvSplitKrc(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
2058
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
2059
2060
2061
    return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count)


2062
2063
2064
2065
2066
2067
2068
2069
2070
def wvSplitKQ(
    a: torch.Tensor,
    b: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    cu_count: int,
    bias: torch.Tensor = None,
) -> torch.Tensor:
2071
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
2072
    out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
2073
    torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
2074
2075
2076
    return out


2077
# moe
2078
2079
2080
2081
def moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Invoke ``_moe_C.moe_sum`` on ``input`` and ``output``; returns None."""
    sum_op = torch.ops._moe_C.moe_sum
    sum_op(input, output)


2082
2083
2084
2085
2086
2087
2088
def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
2089
    expert_map: torch.Tensor | None = None,
2090
2091
2092
2093
2094
2095
2096
2097
) -> None:
    torch.ops._moe_C.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
2098
        expert_map,
2099
    )
2100
2101


2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
def batched_moe_align_block_size(
    max_tokens_per_batch: int,
    block_size: int,
    expert_num_tokens: torch.Tensor,
    sorted_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Run ``_moe_C.batched_moe_align_block_size``; returns None.

    Arguments are forwarded unchanged and in order.
    """
    align_op = torch.ops._moe_C.batched_moe_align_block_size
    outputs = (sorted_ids, expert_ids, num_tokens_post_pad)
    align_op(max_tokens_per_batch, block_size, expert_num_tokens, *outputs)


2120
2121
2122
2123
2124
2125
def moe_lora_align_block_size(
    topk_ids: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    num_experts: int,
    block_size: int,
    max_loras: int,
2126
2127
    max_num_tokens_padded: int,
    max_num_m_blocks: int,
2128
2129
2130
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
2131
2132
    adapter_enabled: torch.Tensor,
    lora_ids: torch.Tensor,
gnovack's avatar
gnovack committed
2133
    expert_map: torch.Tensor | None = None,
2134
2135
2136
2137
2138
2139
2140
) -> None:
    torch.ops._moe_C.moe_lora_align_block_size(
        topk_ids,
        token_lora_mapping,
        num_experts,
        block_size,
        max_loras,
2141
2142
        max_num_tokens_padded,
        max_num_m_blocks,
2143
2144
2145
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
2146
2147
        adapter_enabled,
        lora_ids,
gnovack's avatar
gnovack committed
2148
        expert_map,
2149
2150
2151
    )


2152
2153
2154
2155
2156
def moe_wna16_gemm(
    input: torch.Tensor,
    output: torch.Tensor,
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
2157
2158
    b_qzeros: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
2159
2160
2161
2162
2163
2164
2165
2166
2167
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    top_k: int,
    BLOCK_SIZE_M: int,
    BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int,
    bit: int,
) -> torch.Tensor:
2168
2169
    if not current_platform.is_cuda():
        raise NotImplementedError(
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
            "The optimized moe_wna16_gemm kernel is only available on CUDA platforms"
        )
    torch.ops._moe_C.moe_wna16_gemm(
        input,
        output,
        b_qweight,
        b_scales,
        b_qzeros,
        topk_weights,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        top_k,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
        bit,
    )
2188
2189


2190
2191
2192
2193
2194
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
2195
    renormalize: bool = False,
2196
    e_score_correction_bias: torch.Tensor | None = None,
2197
2198
) -> None:
    torch.ops._moe_C.topk_softmax(
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )


def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    torch.ops._moe_C.topk_sigmoid(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
2223
    )
2224
2225


2226
2227
2228
2229
2230
2231
2232
def grouped_topk(
    scores: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """
    Fused grouped top-k routing for mixture-of-experts (CUDA only).

    Args:
        scores: Raw inputs (logits if scoring_func=1, scores if scoring_func=0).
        num_expert_group: Number of expert groups.
        topk_group: Number of groups to select.
        topk: Number of experts to select per token.
        renormalize: Whether to renormalize the output weights.
        routed_scaling_factor: Scaling factor for routing weights.
        bias: Bias tensor (e_score_correction_bias), always fused in the kernel.
        scoring_func: 0 = none (no activation), 1 = sigmoid.

    Returns:
        Whatever ``_moe_C.grouped_topk`` returns.

    Raises:
        NotImplementedError: if the current platform is not CUDA.
    """
    if current_platform.is_cuda():
        routing_op = torch.ops._moe_C.grouped_topk
        return routing_op(
            scores,
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            bias,
            scoring_func,
        )
    raise NotImplementedError(
        "The fused grouped_topk kernel is only available on CUDA platforms"
    )


def moe_wna16_marlin_gemm(
    input: torch.Tensor,
    output: torch.Tensor | None,
    b_qweight: torch.Tensor,
    b_bias: torch.Tensor | None,
    b_scales: torch.Tensor,
    a_scales: torch.Tensor | None,
    global_scale: torch.Tensor | None,
    b_qzeros: torch.Tensor | None,
    g_idx: torch.Tensor | None,
    perm: torch.Tensor | None,
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
    thread_k: int = -1,
    thread_n: int = -1,
    blocks_per_sm: int = -1,
) -> torch.Tensor:
    """Marlin MoE GEMM for weight-only quantized experts.

    Thin wrapper over ``_moe_C.moe_wna16_marlin_gemm``. The only argument
    transformed before forwarding is ``b_q_type``, which is passed as its
    integer ``id``. All other arguments keep their order.
    """
    quant_type_id = b_q_type.id
    tensors = (
        input,
        output,
        b_qweight,
        b_bias,
        b_scales,
        a_scales,
        global_scale,
        b_qzeros,
        g_idx,
        perm,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_past_padded,
        topk_weights,
    )
    moe_config = (moe_block_size, top_k, mul_topk_weights, quant_type_id)
    problem = (size_m, size_n, size_k)
    flags = (is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float)
    tuning = (thread_k, thread_n, blocks_per_sm)
    return torch.ops._moe_C.moe_wna16_marlin_gemm(
        *tensors, *moe_config, *problem, *flags, *tuning
    )
2327
2328


2329
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(
        a: torch.Tensor,
        b_q_weights: torch.Tensor,
        sorted_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        b_scales: torch.Tensor,
        b_zero_points: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        num_experts: int,
        topk: int,
        moe_block_size: int,
        replicate_input: bool,
        apply_weights: bool,
    ) -> torch.Tensor:
        # Shape-only fake kernel: output is (size_m, topk, size_n) in a's
        # dtype/device. Unused parameters mirror the op schema.
        return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device)

    # NOTE(review): this registration is gated on marlin_gemm_moe existing,
    # not on moe_wna16_marlin_gemm itself. Confirm that is intended.
    @register_fake("_moe_C::moe_wna16_marlin_gemm")
    def moe_wna16_marlin_gemm_fake(
        input: torch.Tensor,
        output: torch.Tensor | None,
        b_qweight: torch.Tensor,
        b_bias: torch.Tensor | None,
        b_scales: torch.Tensor,
        a_scales: torch.Tensor | None,
        global_scale: torch.Tensor | None,
        b_qzeros: torch.Tensor | None,
        g_idx: torch.Tensor | None,
        perm: torch.Tensor | None,
        workspace: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_past_padded: torch.Tensor,
        topk_weights: torch.Tensor,
        moe_block_size: int,
        top_k: int,
        mul_topk_weights: bool,
        b_q_type: ScalarType,
        size_m: int,
        size_n: int,
        size_k: int,
        is_k_full: bool,
        use_atomic_add: bool,
        use_fp32_reduce: bool,
        is_zp_float: bool,
        thread_k: int = -1,
        thread_n: int = -1,
        blocks_per_sm: int = -1,
    ) -> torch.Tensor:
        """Shape-only fake kernel for ``_moe_C::moe_wna16_marlin_gemm``.

        Returns an uninitialized ``(size_m * top_k, size_n)`` tensor in
        ``input``'s dtype/device. ``thread_k``, ``thread_n`` and
        ``blocks_per_sm`` are accepted because the real wrapper forwards
        them to the op. Without these parameters, fake dispatch with
        non-default tuning values raised TypeError. Unused parameters
        mirror the op schema.
        """
        return torch.empty(
            (size_m * top_k, size_n), dtype=input.dtype, device=input.device
        )
2388

2389

2390
2391
2392
2393
2394
2395
2396
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Write ``key``/``value`` into the paged KV cache at ``slot_mapping``.

    Delegates to ``_C_cache_ops.reshape_and_cache``; returns None.
    """
    new_kv = (key, value)
    caches = (key_cache, value_cache)
    scales = (k_scale, v_scale)
    torch.ops._C_cache_ops.reshape_and_cache(
        *new_kv, *caches, slot_mapping, kv_cache_dtype, *scales
    )
2410
2411


2412
2413
2414
2415
2416
2417
2418
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Flash-layout variant of :func:`reshape_and_cache`.

    Delegates to ``_C_cache_ops.reshape_and_cache_flash``; returns None.
    """
    new_kv = (key, value)
    caches = (key_cache, value_cache)
    scales = (k_scale, v_scale)
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        *new_kv, *caches, slot_mapping, kv_cache_dtype, *scales
    )
2432
2433


2434
2435
2436
2437
2438
2439
2440
2441
def concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Store MLA ``kv_c`` and ``k_pe`` into ``kv_cache`` via ``_C_cache_ops``.

    Returns None; arguments are forwarded unchanged and in order.
    """
    cache_op = torch.ops._C_cache_ops.concat_and_cache_mla
    cache_op(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
2445
2446


2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
def concat_and_cache_mla_rope_fused(
    positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    kv_cache_scale: torch.Tensor,
) -> None:
    """Fused rotary embedding + MLA cache write (``_C_cache_ops`` op).

    Returns None. Arguments are grouped only for readability and are
    forwarded in the order of this function's parameters.
    """
    rope_args = (positions, q_pe, k_pe, kv_c, cos_sin_cache, is_neox)
    cache_args = (slot_mapping, kv_cache, kv_cache_dtype, kv_cache_scale)
    torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(*rope_args, *cache_args)


2473
def swap_blocks(
    src: torch.Tensor,
    dst: torch.Tensor,
    block_size_in_bytes: int,
    block_mapping: torch.Tensor,
) -> None:
    """
    Copy selected blocks from ``src`` into ``dst``.

    Both tensors are treated as a sequence of contiguous blocks of
    ``block_size_in_bytes`` bytes each, i.e. their memory layout is
    ``[block0] [block1] ... [blockN]``.

    ``block_mapping`` selects which source blocks to copy and where they go:
    it is expected to have shape ``(num_blocks_to_copy, 2)`` and dtype int64.
    Row ``i`` requests a copy of source block ``block_mapping[i][0]`` into
    destination block ``block_mapping[i][1]``.

    ``src`` and ``dst`` may each live on CPU or GPU, but not both on CPU.
    ``block_mapping`` must be on CPU.
    """
    copy_op = torch.ops._C_cache_ops.swap_blocks
    copy_op(src, dst, block_size_in_bytes, block_mapping)
2500
2501


2502
2503
2504
def convert_fp8(
    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:
    """Convert ``input`` into ``output`` via ``_C_cache_ops.convert_fp8``.

    ``scale`` (default 1.0) and ``kv_dtype`` (default ``"fp8"``) are
    forwarded as-is; returns None.
    """
    convert_op = torch.ops._C_cache_ops.convert_fp8
    convert_op(output, input, scale, kv_dtype)


2508
def gather_and_maybe_dequant_cache(
2509
2510
2511
2512
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
2513
2514
    token_to_seq: torch.Tensor,
    num_tokens: int,
2515
2516
    kv_cache_dtype: str,
    scale: torch.Tensor,
2517
    seq_starts: torch.Tensor | None = None,
2518
) -> None:
2519
    torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
2520
2521
2522
2523
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
2524
2525
        token_to_seq,
        num_tokens,
2526
2527
2528
2529
        kv_cache_dtype,
        scale,
        seq_starts,
    )
2530
2531


2532
2533
2534
2535
2536
2537
def cp_gather_cache(
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
    batch_size: int,
2538
    seq_starts: torch.Tensor | None = None,
2539
2540
2541
2542
) -> None:
    torch.ops._C_cache_ops.cp_gather_cache(
        src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts
    )
2543
2544


2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
def cp_gather_and_upconvert_fp8_kv_cache(
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    workspace_starts: torch.Tensor,
    batch_size: int,
) -> None:
    """Gather an FP8 KV cache and upconvert it into a BF16 workspace.

    Args:
        src_cache: FP8 KV cache [num_blocks, block_size, 656].
        dst: BF16 output workspace [total_tokens, 576].
        block_table: Block indices [num_reqs, max_blocks].
        seq_lens: Sequence lengths [num_reqs].
        workspace_starts: Workspace start offsets [num_reqs].
        batch_size: Number of requests.
    """
    upconvert_op = torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache
    upconvert_op(src_cache, dst, block_table, seq_lens, workspace_starts, batch_size)


2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    kv_cache_dtype: str,
) -> None:
    """Quantize ``k`` and write it to ``kv_cache`` (``_C_cache_ops`` op).

    Returns None; arguments are forwarded unchanged and in order.
    """
    quant_op = torch.ops._C_cache_ops.indexer_k_quant_and_cache
    quant_op(k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype)
2578
2579


2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
def cp_gather_indexer_k_quant_cache(
    kv_cache: torch.Tensor,
    dst_k: torch.Tensor,
    dst_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
) -> None:
    """Gather quantized indexer keys and scales into ``dst_k``/``dst_scale``.

    Delegates to ``_C_cache_ops.cp_gather_indexer_k_quant_cache``; returns None.
    """
    destinations = (dst_k, dst_scale)
    torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
        kv_cache, *destinations, block_table, cu_seq_lens
    )


2592
2593
2594
2595
2596
2597
2598
def get_device_attribute(attribute: int, device: int) -> int:
    """Query device attribute ``attribute`` for device index ``device``."""
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Query the maximum shared memory per block for device index ``device``."""
    # ruff: noqa: E501
    # NOTE: ruff applies a `# ruff: noqa` directive to the whole file,
    # regardless of where it appears.
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)
2601
2602
2603


# custom ar
2604
2605
2606
2607
2608
2609
2610
2611
2612
def init_custom_ar(
    ipc_tensors: list[torch.Tensor],
    rank_data: torch.Tensor,
    rank: int,
    fully_connected: bool,
) -> int:
    """Create a custom all-reduce instance and return its integer handle."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, fully_connected)
2613
2614


2615
2616
2617
2618
2619
2620
2621
2622
def all_reduce(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    reg_buffer: int,
    reg_buffer_sz_bytes: int,
) -> None:
    """All-reduce ``inp`` into ``out`` using custom all-reduce handle ``fa``."""
    reduce_op = torch.ops._C_custom_ar.all_reduce
    reduce_op(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
2623

2624
2625
2626
2627
2628
2629
2630
2631
2632

def dispose(fa: int) -> None:
    """Release the custom all-reduce instance identified by ``fa``."""
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)


def meta_size() -> int:
    """Return the value reported by ``_C_custom_ar.meta_size``."""
    size_op = torch.ops._C_custom_ar.meta_size
    return size_op()


2633
def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
    """Register IPC buffers with custom all-reduce handle ``fa``.

    The custom op's return value is passed through unchanged.
    """
    register_op = torch.ops._C_custom_ar.register_buffer
    return register_op(fa, ipc_tensors)
2635
2636


2637
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
    """Fetch graph-buffer IPC metadata for custom all-reduce handle ``fa``."""
    meta_op = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    return meta_op(fa)


2641
2642
2643
def register_graph_buffers(
    fa: int, handles: list[list[int]], offsets: list[list[int]]
) -> None:
    """Register graph buffers (``handles``/``offsets``) with handle ``fa``."""
    register_op = torch.ops._C_custom_ar.register_graph_buffers
    register_op(fa, handles, offsets)
2645
2646


2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
    """Allocate a ``size``-byte shared buffer; returns the op's (ptr, handle) pair."""
    alloc_op = torch.ops._C_custom_ar.allocate_shared_buffer_and_handle
    return alloc_op(size)


def open_mem_handle(mem_handle: torch.Tensor):
    """Open an IPC memory handle; returns whatever the custom op returns."""
    open_op = torch.ops._C_custom_ar.open_mem_handle
    return open_op(mem_handle)


def free_shared_buffer(ptr: int) -> None:
    """Free a shared buffer previously allocated for custom all-reduce."""
    free_op = torch.ops._C_custom_ar.free_shared_buffer
    free_op(ptr)


2659
# quick all reduce
2660
def init_custom_qr(rank: int, world_size: int, qr_max_size: int | None = None) -> int:
2661
2662
2663
2664
2665
2666
2667
    return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size)


def qr_destroy(fa: int) -> None:
    """Destroy the quick all-reduce instance identified by ``fa``."""
    destroy_op = torch.ops._C_custom_ar.qr_destroy
    destroy_op(fa)


2668
2669
2670
2671
2672
2673
2674
2675
def qr_all_reduce(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    quant_level: int,
    cast_bf2half: bool = False,
) -> None:
    """Quick all-reduce of ``inp`` into ``out`` using handle ``fa``.

    ``quant_level`` and ``cast_bf2half`` (default False) are forwarded as-is.
    """
    reduce_op = torch.ops._C_custom_ar.qr_all_reduce
    reduce_op(fa, inp, out, quant_level, cast_bf2half)
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689


def qr_get_handle(fa: int) -> torch.Tensor:
    """Return the IPC handle tensor of quick all-reduce instance ``fa``."""
    handle_op = torch.ops._C_custom_ar.qr_get_handle
    return handle_op(fa)


def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
    """Open peer IPC ``handles`` for quick all-reduce instance ``fa``.

    The custom op's return value is passed through unchanged.
    """
    open_op = torch.ops._C_custom_ar.qr_open_handles
    return open_op(fa, handles)


def qr_max_size() -> int:
    """Return the value reported by ``_C_custom_ar.qr_max_size``."""
    size_op = torch.ops._C_custom_ar.qr_max_size
    return size_op()


2690
2691
2692
2693
def get_flash_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute FlashMLA scheduling metadata.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: Number of KV heads.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    metadata_op = torch.ops._C.get_flash_mla_metadata
    return metadata_op(cache_seqlens, num_heads_per_head_k, num_heads_k)
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
2718
    softmax_scale: float | None = None,
2719
    causal: bool = False,
2720
) -> tuple[torch.Tensor, torch.Tensor]:
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
2738
        softmax_scale = q.shape[-1] ** (-0.5)
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
    out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
2752
2753


2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
def sm100_cutlass_mla_decode(
    out: torch.Tensor,
    lse: torch.Tensor,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    scale: float,
    num_kv_splits: int,
) -> torch.Tensor:
    """CUTLASS MLA decode for SM100 (``_C.sm100_cutlass_mla_decode``).

    The op is given ``out`` and ``lse`` as its first arguments; ``out`` is
    returned afterwards.
    """
    decode_op = torch.ops._C.sm100_cutlass_mla_decode
    destinations = (out, lse)
    queries = (q_nope, q_pe)
    cache = (kv_c_and_k_pe_cache, seq_lens, page_table)
    decode_op(*destinations, *queries, *cache, workspace, scale, num_kv_splits)
    return out


2781
2782
2783
def sm100_cutlass_mla_get_workspace_size(
    max_seq_len: int, num_batches: int, sm_count: int, num_kv_splits: int
) -> int:
    """Return the workspace size reported by the ``_C`` extension.

    Thin wrapper over ``torch.ops._C.sm100_cutlass_mla_get_workspace_size``.
    The value is presumably the size of the ``workspace`` buffer expected by
    ``sm100_cutlass_mla_decode`` (units not visible here -- confirm).
    """
    size_query = torch.ops._C.sm100_cutlass_mla_get_workspace_size
    workspace_size = size_query(max_seq_len, num_batches, sm_count, num_kv_splits)
    return workspace_size
2787
2788


2789
2790
2791
if hasattr(torch.ops._C, "weight_packed_linear"):

    @register_fake("_C::weight_packed_linear")
    def weight_packed_linear_fake(
        mat1: torch.Tensor,
        mat2: torch.Tensor,
        bias: torch.Tensor | None,
        is_vnni: bool,
    ) -> torch.Tensor:
        # Fake (shape-only) implementation used during tracing: the result is
        # (mat1 rows, mat2 rows) with mat1's dtype, placed on mat2's device.
        out_shape = (mat1.size(0), mat2.size(0))
        return torch.empty(out_shape, dtype=mat1.dtype, device=mat2.device)
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814


if hasattr(torch.ops._C, "fused_experts_cpu"):

    @register_fake("_C::fused_experts_cpu")
    def fused_experts_cpu_fake(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool,
        use_int8_w8a8: bool,
        use_fp8_w8a16: bool,
        w1_scale: torch.Tensor | None,
        w2_scale: torch.Tensor | None,
        block_size: list[int] | None,
        a1_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        is_vnni: bool,
    ) -> torch.Tensor:
        # Fake (shape-only) implementation: the result mirrors hidden_states'
        # shape, dtype, device and layout; all other arguments are unused here.
        result = torch.empty_like(hidden_states)
        return result


if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):

    @register_fake("_C::int8_scaled_mm_with_quant")
    def int8_scaled_mm_with_quant_fake(
        mat1: torch.Tensor,
        mat2: torch.Tensor,
        scales2: torch.Tensor,
        bias: torch.Tensor | None,
        out_dtype: torch.dtype,
        is_vnni: bool,
    ) -> torch.Tensor:
        # Fake (shape-only) implementation: (mat1 rows, mat2 rows) in out_dtype.
        # Allocate on the input's device (as weight_packed_linear_fake does)
        # instead of torch's current default device, so tracing with a
        # non-default default device or meta inputs yields the right device.
        M = mat1.size(0)
        N = mat2.size(0)
        return torch.empty((M, N), dtype=out_dtype, device=mat1.device)
2839
2840
2841
2842


class CPUDNNLGEMMHandler:
    """Holder for an opaque oneDNN matmul handle created by the ``_C`` extension.

    ``create_onednn_mm`` / ``create_onednn_scaled_mm`` fill in ``k``, ``n`` and
    ``handler``. When the object is finalized, any handle it holds is passed to
    ``torch.ops._C.release_dnnl_matmul_handler``.
    """

    def __init__(self) -> None:
        # Opaque native handle; None until a factory assigns one.
        self.handler: int | None = None
        # Weight dimensions; -1 marks "not yet set".
        self.n = -1
        self.k = -1

    def __del__(self):
        native_handle = self.handler
        if native_handle is None:
            return
        torch.ops._C.release_dnnl_matmul_handler(native_handle)


2852
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
2853
2854


2855
2856
2857
2858
def is_onednn_acl_supported():
    """Return the ``_C`` extension's ``is_onednn_acl_supported`` result.

    Presumably reports whether oneDNN's Arm Compute Library backend is
    available -- confirm against the extension.
    """
    acl_probe = torch.ops._C.is_onednn_acl_supported
    return acl_probe()


2859
2860
2861
2862
2863
2864
2865
def create_onednn_mm(
    weight: torch.Tensor,  # [K, N]
    primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
    """Create a oneDNN GEMM handler bound to a fixed 2-D weight.

    Args:
        weight: Weight tensor laid out as ``[K, N]``; it must be 2-D, since
            its size is unpacked into exactly two values.
        primitive_cache_size: Forwarded to the extension's handler factory.

    Returns:
        A ``CPUDNNLGEMMHandler`` with ``k``/``n`` set from ``weight`` and
        ``handler`` set to the native handle.
    """
    mm_handler = CPUDNNLGEMMHandler()
    rows_k, cols_n = weight.size()
    mm_handler.k = rows_k
    mm_handler.n = cols_n
    factory = torch.ops._C.create_onednn_mm_handler
    mm_handler.handler = factory(weight, primitive_cache_size)
    return mm_handler


def onednn_mm(
    dnnl_handler: CPUDNNLGEMMHandler,
    x: torch.Tensor,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    """Multiply ``x`` by the weight bound to ``dnnl_handler`` via oneDNN.

    The output keeps all leading dimensions of ``x`` and replaces the last one
    with ``dnnl_handler.n``. The kernel sees ``x`` flattened to
    ``(-1, dnnl_handler.k)``.

    Returns:
        The newly allocated output tensor (same dtype as ``x``).
    """
    out_shape = (*x.shape[0:-1], dnnl_handler.n)
    output = torch.empty(out_shape, dtype=x.dtype)
    gemm = torch.ops._C.onednn_mm
    flat_x = x.reshape(-1, dnnl_handler.k)
    gemm(output, flat_x, bias, dnnl_handler.handler)
    return output


2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
def create_onednn_scaled_mm(
    weight: torch.Tensor,  # [K, N]
    weight_scales: torch.Tensor,
    output_type: torch.dtype,
    dynamic_quant: bool,
    use_azp: bool,
    primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
    """Create a oneDNN scaled-GEMM handler bound to a fixed 2-D weight.

    Args:
        weight: Weight tensor laid out as ``[K, N]`` (must be 2-D).
        weight_scales, output_type, dynamic_quant, use_azp,
        primitive_cache_size: Forwarded unchanged, in this order, to
            ``torch.ops._C.create_onednn_scaled_mm_handler``.

    Returns:
        A ``CPUDNNLGEMMHandler`` carrying ``k``, ``n`` and the native handle.
    """
    scaled_handler = CPUDNNLGEMMHandler()
    rows_k, cols_n = weight.size()
    scaled_handler.k = rows_k
    scaled_handler.n = cols_n
    factory_args = (
        weight,
        weight_scales,
        output_type,
        dynamic_quant,
        use_azp,
        primitive_cache_size,
    )
    scaled_handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
        *factory_args
    )
    return scaled_handler


2900
2901
def onednn_scaled_int8_quant(
    input: torch.Tensor,
2902
2903
    scale: torch.Tensor | None = None,
    azp: torch.Tensor | None = None,
2904
2905
    symmetric: bool = True,
):
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
2918
      tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
2919
2920
2921
2922
2923
2924
    """
    output = torch.empty_like(input, dtype=torch.int8)
    token_num = input.numel() // input.shape[-1]
    input = input.view((token_num, input.shape[-1]))
    if scale is not None:
        # static-per-tensor quantization.
2925
2926
2927
        assert symmetric == (azp is None), (
            "azp must only be provided for asymmetric quantization."
        )
2928
2929
2930
2931
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization.
2932
2933
2934
    input_scales = torch.empty((token_num, 1), device=input.device, dtype=torch.float32)
    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
2935
2936
2937
2938
2939
2940
2941
    return output, input_scales, input_azp


def onednn_scaled_mm(
    dnnl_handler: CPUDNNLGEMMHandler,
    x: torch.Tensor,
    output: torch.Tensor,
    input_scale: torch.Tensor | None,
    input_zp: torch.Tensor | None,
    input_zp_adj: torch.Tensor | None,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    """Run the oneDNN scaled GEMM bound to ``dnnl_handler``.

    ``output`` is caller-allocated and passed to the kernel. The same object
    is returned (presumably filled in place by the kernel -- confirm).
    """
    scaled_gemm = torch.ops._C.onednn_scaled_mm
    gemm_args = (
        output,
        x,
        input_scale,
        input_zp,
        input_zp_adj,
        bias,
        dnnl_handler.handler,
    )
    scaled_gemm(*gemm_args)
    return output
2952
2953


2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
def cpu_attn_get_scheduler_metadata(
    num_reqs: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    seq_lens: torch.Tensor,
    dtype: torch.dtype,
    query_start_loc: torch.Tensor,
    causal: bool,
    sliding_window_size: int,
    isa: str,
    enable_kv_split: bool,
) -> torch.Tensor:
    """Build scheduler metadata for the CPU attention kernels.

    All arguments are forwarded positionally to
    ``torch.ops._C.get_scheduler_metadata``, and its result is returned as-is.
    """
    get_metadata = torch.ops._C.get_scheduler_metadata
    return get_metadata(
        num_reqs,
        num_heads,
        num_kv_heads,
        head_dim,
        seq_lens,
        dtype,
        query_start_loc,
        causal,
        sliding_window_size,
        isa,
        enable_kv_split,
    )


def cpu_attn_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    isa: str,
) -> None:
    """Forward to ``torch.ops._C.cpu_attn_reshape_and_cache``.

    Presumably scatters ``key``/``value`` into ``key_cache``/``value_cache``
    at the slots in ``slot_mapping`` (inferred from names -- confirm). Returns
    None.
    """
    cache_writer = torch.ops._C.cpu_attn_reshape_and_cache
    cache_writer(key, value, key_cache, value_cache, slot_mapping, isa)


def cpu_attention_with_kv_cache(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    causal: bool,
    alibi_slopes: torch.Tensor | None,
    sliding_window: tuple[int, int],
    block_table: torch.Tensor,
    softcap: float,
    scheduler_metadata: torch.Tensor,
    s_aux: torch.Tensor | None,
) -> None:
    """Forward to ``torch.ops._C.cpu_attention_with_kv_cache``.

    ``sliding_window`` is split into its first and second elements, which the
    kernel receives as two separate arguments. ``output`` is passed through
    (presumably as the destination buffer); the wrapper returns None.
    """
    attention_kernel = torch.ops._C.cpu_attention_with_kv_cache
    window_first = sliding_window[0]
    window_second = sliding_window[1]
    attention_kernel(
        query,
        key_cache,
        value_cache,
        output,
        query_start_loc,
        seq_lens,
        scale,
        causal,
        alibi_slopes,
        window_first,
        window_second,
        block_table,
        softcap,
        scheduler_metadata,
        s_aux,
    )


Li, Jiang's avatar
Li, Jiang committed
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
def cpu_gemm_wna16(
    input: torch.Tensor,
    q_weight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor | None,
    g_idx: torch.Tensor | None,
    bias: torch.Tensor | None,
    pack_factor: int,
    isa_hint: str,
) -> torch.Tensor:
    """Run the CPU weight-quantized GEMM kernel and return a fresh output.

    The output has shape ``(input.size(0), scales.size(1))`` and ``input``'s
    dtype. It is allocated here and passed to ``torch.ops._C.cpu_gemm_wna16``.
    """
    num_rows = input.size(0)
    num_cols = scales.size(1)
    result = torch.empty((num_rows, num_cols), dtype=input.dtype)
    gemm_args = (
        input,
        q_weight,
        result,
        scales,
        zeros,
        g_idx,
        bias,
        pack_factor,
        isa_hint,
    )
    torch.ops._C.cpu_gemm_wna16(*gemm_args)
    return result


3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
3094
3095
3096
def cpu_prepack_moe_weight(
    weight: torch.Tensor,
    isa: str,
) -> torch.Tensor:
    """Prepack an MoE weight for the given ISA into a new tensor.

    The destination is allocated with ``torch.empty_like(weight)`` and handed
    to ``torch.ops._C.prepack_moe_weight``. ``weight`` itself is passed as the
    source.
    """
    packed = torch.empty_like(weight)
    prepack = torch.ops._C.prepack_moe_weight
    prepack(weight, packed, isa)
    return packed


def cpu_fused_moe(
    input: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    act: str,
    isa: str,
) -> torch.Tensor:
    """Run the CPU fused-MoE kernel and return a fresh output tensor.

    The output is allocated like ``input`` and passed as the kernel's first
    argument, followed by the remaining arguments in declaration order.
    """
    moe_out = torch.empty_like(input)
    moe_kernel = torch.ops._C.cpu_fused_moe
    moe_kernel(
        moe_out,
        input,
        w13,
        w2,
        w13_bias,
        w2_bias,
        topk_weights,
        topk_ids,
        act,
        isa,
    )
    return moe_out


3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):

    @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
    def _fake_matmul_mxf4_bf16_tn(
        a: torch.Tensor,
        b: torch.Tensor,
        a_sf: torch.Tensor,
        b_sf: torch.Tensor,
        alpha: torch.Tensor,
    ):
        # Shape-only fake: a's leading dims plus b's first dim, in bfloat16.
        result_shape = (*a.shape[:-1], b.shape[0])
        return a.new_empty(result_shape, dtype=torch.bfloat16)


def matmul_mxf4_bf16_tn(
    a: torch.Tensor,
    b: torch.Tensor,
    a_sf: torch.Tensor,
    b_sf: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Forward to ``torch.ops._qutlass_C.matmul_mxf4_bf16_tn``.

    The registered fake for this op produces a bfloat16 tensor shaped
    ``(*a.shape[:-1], b.shape[0])``.
    """
    mx_gemm = torch.ops._qutlass_C.matmul_mxf4_bf16_tn
    return mx_gemm(a, b, a_sf, b_sf, alpha)


if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"):

    @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn")
    def _fake_matmul_ada_mxf4_bf16_tn(
        a: torch.Tensor,
        b: torch.Tensor,
        a_sf: torch.Tensor,
        b_sf: torch.Tensor,
        alpha: torch.Tensor,
    ):
        # Shape-only fake: a's leading dims plus b's first dim, in bfloat16.
        result_shape = (*a.shape[:-1], b.shape[0])
        return a.new_empty(result_shape, dtype=torch.bfloat16)


def matmul_ada_mxf4_bf16_tn(
    a: torch.Tensor,
    b: torch.Tensor,
    a_sf: torch.Tensor,
    b_sf: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Forward to ``torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn``.

    The registered fake for this op produces a bfloat16 tensor shaped
    ``(*a.shape[:-1], b.shape[0])``.
    """
    ada_gemm = torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn
    return ada_gemm(a, b, a_sf, b_sf, alpha)


if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"):

    @register_fake("_qutlass_C::fusedQuantizeMxQuest")
    def _fake_fused_quantize_mx_quest(
        a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
    ):
        # The real op fills caller-provided buffers; the fake returns them as-is.
        quantized_outputs = (xh_e2m1, xh_e8m0)
        return quantized_outputs


if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"):

    @register_fake("_qutlass_C::fusedQuantizeMxAbsMax")
    def _fake_fused_quantize_mx_absmax(
        a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
    ):
        # The real op fills caller-provided buffers; the fake returns them as-is.
        quantized_outputs = (xh_e2m1, xh_e8m0)
        return quantized_outputs


def fusedQuantizeMx(
    a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest"
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize `a` to MX FP4 with the qutlass fused kernels.

    Args:
        a: Input tensor; its last dimension must be divisible by 32.
        b: Second operand forwarded to the kernel; must be on `a`'s device.
        method: "quest" dispatches to `fusedQuantizeMxQuest`, "abs_max" to
            `fusedQuantizeMxAbsMax`.

    Returns:
        (xh_e2m1, xh_e8m0): a uint8 tensor shaped like `a` but with last
        dimension `a.size(-1) // 2`, and a float8_e8m0fnu scale tensor of
        shape (rows, a.size(-1) // 32) padded up to multiples of (128, 4).

    Raises:
        ValueError: On a 0-dim `a`, a last dimension not divisible by 32,
            mismatched devices, or an unknown `method`.
        RuntimeError: If the selected `_qutlass_C` op is not registered.
    """
    if a.dim() == 0:
        raise ValueError("`a` must have at least 1 dimension.")
    if a.size(-1) % 32 != 0:
        raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.")
    if b.device != a.device:
        raise ValueError("`a` and `b` must be on the same device.")

    # Resolve the op before allocating any buffers.
    if method == "quest":
        op_name = "fusedQuantizeMxQuest"
    elif method == "abs_max":
        op_name = "fusedQuantizeMxAbsMax"
    else:
        raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'")

    # `hasattr(torch.ops, "_qutlass_C")` is always True: torch.ops creates
    # namespaces lazily on attribute access. Probe the concrete op instead so a
    # missing extension produces this error rather than an AttributeError.
    if not hasattr(torch.ops._qutlass_C, op_name):
        raise RuntimeError(
            "The `_qutlass_C` extension is not loaded. "
            "Make sure your custom op library is imported before calling fusedQuantizeMx."
        )

    xh_e2m1 = torch.empty(
        *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
    )

    # One e8m0 scale per 32 input elements, padded to (128, 4) tiles.
    rows, cols = a.numel() // a.size(-1), a.size(-1) // 32
    n_row_blocks = cdiv(rows, 128)
    n_col_blocks = cdiv(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    xh_e8m0 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device
    )

    return getattr(torch.ops._qutlass_C, op_name)(a, b, xh_e2m1, xh_e8m0)


if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"):

    @register_fake("_qutlass_C::fusedQuantizeNv")
    def _fake_fused_quantize_nv(
        a: torch.Tensor,
        b: torch.Tensor,
        xh_e2m1: torch.Tensor,
        xh_e4m3: torch.Tensor,
        global_scale: torch.Tensor,
    ):
        # The real op fills caller-provided buffers; the fake returns them as-is.
        quantized_outputs = (xh_e2m1, xh_e4m3)
        return quantized_outputs


def fusedQuantizeNv(
    a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize `a` to NV FP4 with the qutlass `fusedQuantizeNv` kernel.

    Args:
        a: Input tensor; its last dimension must be divisible by 16, because
            one float8_e4m3fn scale is allocated per 16 input elements.
        b: Second operand forwarded to the kernel; must be on `a`'s device.
        global_scale: Forwarded unchanged to the kernel.

    Returns:
        (xh_e2m1, xh_e4m3): a uint8 tensor shaped like `a` but with last
        dimension `a.size(-1) // 2`, and a float8_e4m3fn scale tensor of
        shape (rows, a.size(-1) // 16) padded up to multiples of (128, 4).

    Raises:
        ValueError: On a 0-dim `a`, a last dimension not divisible by 16, or
            mismatched devices.
        RuntimeError: If `_qutlass_C::fusedQuantizeNv` is not registered.
    """
    # Same validation as fusedQuantizeMx. Without it, a partial trailing group
    # would be silently dropped from the floor-divided buffer sizes below.
    if a.dim() == 0:
        raise ValueError("`a` must have at least 1 dimension.")
    if a.size(-1) % 16 != 0:
        raise ValueError(f"last dim of `a` must be divisible by 16, got {a.size(-1)}.")
    if b.device != a.device:
        raise ValueError("`a` and `b` must be on the same device.")
    if not hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"):
        raise RuntimeError(
            "The `_qutlass_C` extension is not loaded. "
            "Make sure your custom op library is imported before calling fusedQuantizeNv."
        )

    xh_e2m1 = torch.empty(
        *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
    )

    # One e4m3 scale per 16 input elements, padded to (128, 4) tiles.
    rows, cols = a.numel() // a.size(-1), a.size(-1) // 16
    n_row_blocks = cdiv(rows, 128)
    n_col_blocks = cdiv(cols, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4
    xh_e4m3 = torch.empty(
        padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device
    )

    return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale)


3231
3232
3233
3234
3235
3236
3237
3238
def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)
    kernels. Note that these kernels exploit the recursive properties of
    Sylvester Hadamards, and therefore do not require transform weight data.

    Note that Sylvester Hadamard transforms are also symmetric, which means that
    this function also applies the (transpose <=> inverse) transform.

    :param x: value to be transformed
    :param inplace: if True, transform `x` in place; otherwise the result is
        returned in a new tensor
    :return: value after transformation
    """
    return torch.ops._C.hadacore_transform(x, inplace)


if hasattr(torch.ops._C, "hadacore_transform"):

    @register_fake("_C::hadacore_transform")
3250
    def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor:
3251
        return torch.empty_like(x) if not inplace else x