_custom_ops.py 77.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
from typing import TYPE_CHECKING, Optional, Union
6
7
8

import torch

9
import vllm.envs as envs
10
from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12
from vllm.scalar_type import ScalarType
13
14
15

logger = init_logger(__name__)

16
17
if not current_platform.is_tpu() and not current_platform.is_hpu()\
        and not current_platform.is_xpu():
18
19
20
21
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)
22

23
supports_moe_ops = False
24
with contextlib.suppress(ImportError):
25
    import vllm._moe_C  # noqa: F401
26
    supports_moe_ops = True
27

28
if TYPE_CHECKING:
29
30
31
32
33
34
35
36
37

    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

38
39
40
41
42
43
44
45
46
47

# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in order, to ``torch.ops._C.paged_attention_v1``.

    Nothing is returned. NOTE(review): the result is presumably written into
    ``out`` by the kernel -- confirm against the op implementation.
    """
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)
67
68
69
70
71
72
73
74
75
76
77
78
79


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in order, to ``torch.ops._C.paged_attention_v2``.

    Compared with ``paged_attention_v1`` this variant additionally passes
    ``exp_sum``, ``max_logits`` and ``tmp_out``. Nothing is returned.
    NOTE(review): those tensors look like kernel scratch/output buffers --
    confirm against the op implementation.
    """
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)
99
100


101
102
103
104
105
106
107
108
109
110
111
112
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    query_start_loc: Optional[torch.Tensor],
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    fp8_out_scale: Optional[torch.Tensor] = None,
) -> None:
    """Forward every argument, in order, to ``torch.ops._rocm_C.paged_attention``.

    Nothing is returned. ``fp8_out_scale`` is optional and defaults to
    ``None``; it is passed through as the final op argument.
    """
    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                      key_cache, value_cache, num_kv_heads,
                                      scale, block_tables, seq_lens,
                                      query_start_loc, block_size, max_seq_len,
                                      alibi_slopes, kv_cache_dtype, k_scale,
                                      v_scale, fp8_out_scale)
128
129


Thien Tran's avatar
Thien Tran committed
130
131
132
133
134
135
136
137
138
139
140
141
def mla_decode_kvcache_cpu(
    out: torch.Tensor,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    """Invoke ``torch.ops._C_cpu.mla_decode_kvcache`` with the given arguments.

    The six arguments are forwarded positionally and unchanged; nothing is
    returned.
    """
    op_args = (out, query, kv_cache, scale, block_tables, seq_lens)
    torch.ops._C_cpu.mla_decode_kvcache(*op_args)


142
143
144
145
146
147
148
149
150
151
152
# merge attn states ops
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Call ``torch.ops._C.merge_attn_states`` for a prefix/suffix pair.

    The op takes ``output_lse`` as its second argument, so the Python-level
    parameter order (where ``output_lse`` is last and optional) is rearranged
    before the call. Nothing is returned.
    """
    op_args = (output, output_lse, prefix_output, prefix_lse, suffix_output,
               suffix_lse)
    torch.ops._C.merge_attn_states(*op_args)


153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate the four result tensors and run the conversion op.

    The outputs are zero-initialised with the dtype and device of
    ``q_seqlens`` and handed to
    ``torch.ops._C.convert_vertical_slash_indexes`` together with the inputs.

    Returns:
        ``(block_count, block_offset, column_count, column_index)`` with
        shapes ``[B, H, R]``, ``[B, H, R, NNZ_S]``, ``[B, H, R]`` and
        ``[B, H, R, NNZ_V]``, where ``R`` is ``context_size`` divided by
        ``block_size_M``, rounded up.
    """

    def _zeros(*shape: int) -> torch.Tensor:
        # Every output shares dtype/device with q_seqlens.
        return torch.zeros(*shape,
                           dtype=q_seqlens.dtype,
                           device=q_seqlens.device)

    n_batch = slash_indexes.size(0)
    n_heads = slash_indexes.size(1)
    n_slash = slash_indexes.size(2)
    n_vertical = vertical_indexes.size(2)
    n_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = _zeros(n_batch, n_heads, n_rows)
    block_offset = _zeros(n_batch, n_heads, n_rows, n_slash)
    column_count = _zeros(n_batch, n_heads, n_rows)
    column_index = _zeros(n_batch, n_heads, n_rows, n_vertical)

    torch.ops._C.convert_vertical_slash_indexes(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, context_size,
        block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index


def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS] : different head use different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate the four result tensors and run the merge-head conversion op.

    The outputs are created with ``torch.empty`` (uninitialised) using the
    dtype and device of ``q_seqlens`` and handed to
    ``torch.ops._C.convert_vertical_slash_indexes_mergehead`` together with
    the inputs and the per-head index counts.

    Returns:
        ``(block_count, block_offset, column_count, column_index)`` with
        shapes ``[B, H, R]``, ``[B, H, R, NNZ_S]``, ``[B, H, R]`` and
        ``[B, H, R, NNZ_V]``, where ``R`` is ``context_size`` divided by
        ``block_size_M``, rounded up.
    """

    def _empty(*shape: int) -> torch.Tensor:
        # Every output shares dtype/device with q_seqlens.
        return torch.empty(*shape,
                           dtype=q_seqlens.dtype,
                           device=q_seqlens.device)

    n_batch = slash_indexes.size(0)
    n_heads = slash_indexes.size(1)
    n_slash = slash_indexes.size(2)
    n_vertical = vertical_indexes.size(2)
    n_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = _empty(n_batch, n_heads, n_rows)
    block_offset = _empty(n_batch, n_heads, n_rows, n_slash)
    column_count = _empty(n_batch, n_heads, n_rows)
    column_index = _empty(n_batch, n_heads, n_rows, n_vertical)

    torch.ops._C.convert_vertical_slash_indexes_mergehead(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
        slash_indices_count, context_size, block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index


248
249
250
251
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Forward the arguments, in order, to ``torch.ops._C.rotary_embedding``.

    ``key`` may be ``None``. Nothing is returned. NOTE(review): ``query`` and
    ``key`` are presumably modified in place by the kernel -- confirm.
    """
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)
259
260
261


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: Optional[torch.Tensor], head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Forward the arguments, in order, to
    ``torch.ops._C.batched_rotary_embedding``.

    Takes the same leading arguments as ``rotary_embedding`` plus ``rot_dim``
    and ``cos_sin_cache_offsets``. ``key`` may be ``None``. Nothing is
    returned.
    """
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)
269
270
271
272
273


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Run ``torch.ops._C.rms_norm`` on a contiguous copy/view of ``input``.

    ``input.contiguous()`` is applied before the call, so a non-contiguous
    ``input`` is copied first. ``out``, ``weight`` and ``epsilon`` are passed
    through unchanged. Nothing is returned.
    """
    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
    input_contiguous = input.contiguous()
    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
277
278
279
280


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Forward the arguments to ``torch.ops._C.fused_add_rms_norm``.

    Nothing is returned. NOTE(review): with no output tensor in the
    signature, ``input`` and ``residual`` are presumably updated in place --
    confirm against the kernel.
    """
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
282
283


284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
def apply_repetition_penalties_torch(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
    """Pure-PyTorch repetition penalty, applied to ``logits`` in place.

    For every position flagged in ``prompt_mask`` or ``output_mask`` the
    per-row penalty is applied: positive logits are divided by it, all other
    logits are multiplied by it. Unflagged positions are left unchanged.
    """
    vocab_size = logits.size(1)
    # Expand the per-row penalty across the vocabulary dimension.
    per_token_penalty = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)
    seen = prompt_mask | output_mask
    # Tokens that were never seen get a neutral factor of 1.0.
    effective_penalty = torch.where(seen, per_token_penalty, 1.0)
    # Positive logits shrink via division, the rest via multiplication.
    is_positive = logits > 0
    factor = torch.where(is_positive, 1.0 / effective_penalty,
                         effective_penalty)
    logits *= factor


def apply_repetition_penalties_cuda(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
    """Call ``torch.ops._C.apply_repetition_penalties_`` with the arguments.

    All four tensors are forwarded positionally and unchanged; nothing is
    returned.
    """
    op_args = (logits, prompt_mask, output_mask, repetition_penalties)
    torch.ops._C.apply_repetition_penalties_(*op_args)


def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                               output_mask: torch.Tensor,
                               repetition_penalties: torch.Tensor) -> None:
    """Apply repetition penalties to logits in-place.

    The custom-op implementation is chosen when the current platform reports
    CUDA and ``logits`` is contiguous; in every other case the pure-PyTorch
    implementation is used.

    Args:
        logits: The logits tensor of shape [num_seqs, vocab_size].
        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
        output_mask: A boolean tensor indicating which tokens appear in the output.
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    use_custom_op = current_platform.is_cuda() and logits.is_contiguous()
    if use_custom_op:
        impl = apply_repetition_penalties_cuda
    else:
        impl = apply_repetition_penalties_torch
    impl(logits, prompt_mask, output_mask, repetition_penalties)


323
324
325
326
327
328
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner

    Forwards all arguments, in order, to
    ``torch.ops._C.advance_step_flashattn`` and returns whatever the op
    returns (annotated as ``None``).
    """
    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
                                               block_size, input_tokens,
                                               sampled_token_ids,
                                               input_positions, seq_lens,
                                               slot_mapping, block_tables)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    """Forward all arguments to ``torch.ops._C.advance_step_flashinfer``.

    Takes the same leading arguments as ``advance_step_flashattn`` followed
    by the four paged-KV tensors. Returns whatever the op returns (annotated
    as ``None``).
    """
    step_args = (num_seqs, num_queries, block_size, input_tokens,
                 sampled_token_ids, input_positions, seq_lens, slot_mapping,
                 block_tables)
    paged_kv_args = (paged_kv_indices, paged_kv_indptr,
                     paged_kv_last_page_len, block_table_bound)
    return torch.ops._C.advance_step_flashinfer(*step_args, *paged_kv_args)
353
354


355
356
357
358
359
360
361
362
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    scale_ub: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Allocate outputs and call
    ``torch.ops._C.rms_norm_dynamic_per_token_quant``.

    Returns:
        ``(output, scales)`` where ``output`` has the shape of ``input`` with
        dtype ``quant_dtype`` and ``scales`` is a float32 tensor of shape
        ``(input.numel() // input.shape[-1], 1)``, i.e. one entry per row of
        the flattened input. Both are allocated uninitialised on the device
        of ``input`` and passed to the op.
    """
    output = torch.empty_like(input, dtype=quant_dtype)
    scales = torch.empty((input.numel() // input.shape[-1], 1),
                         device=input.device,
                         dtype=torch.float32)

    torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight,
                                                  scales, epsilon, scale_ub,
                                                  residual)
    return output, scales


375
376
377
378
379
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Dequantize AWQ weights via Triton or the compiled op.

    When ``envs.VLLM_USE_TRITON_AWQ`` is truthy the Triton implementation is
    imported lazily and called with ``qweight``, ``scales`` and ``zeros``
    only (``split_k_iters``, ``thx`` and ``thy`` are not used on that path).
    Otherwise all six arguments go to ``torch.ops._C.awq_dequantize``.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton)
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)
386
387
388
389


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Run the AWQ GEMM via Triton or the compiled op.

    When ``envs.VLLM_USE_TRITON_AWQ`` is truthy the Triton implementation is
    imported lazily and used; otherwise ``torch.ops._C.awq_gemm`` is called.
    Both paths receive the same five arguments in the same order.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_gemm_triton)
        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
395
396
397
398
399
400
401


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Forward the arguments, in order, to ``torch.ops._C.gptq_gemm`` and
    return its result."""
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)
404
405


406
# Only register the fake implementation when the compiled op is present.
if hasattr(torch.ops._C, "gptq_gemm"):

    @register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_gptq_qzeros: torch.Tensor,
                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                        use_exllama: bool, bit: int) -> torch.Tensor:
        # Shape-only stand-in: (rows of a, columns of b_q_weight), with the
        # dtype and device of `a`.
        return torch.empty((a.size(0), b_q_weight.size(1)),
                           dtype=a.dtype,
                           device=a.device)


418
419
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Forward the arguments to ``torch.ops._C.gptq_shuffle``.

    Nothing is returned. NOTE(review): ``q_weight`` is presumably shuffled in
    place -- confirm against the kernel.
    """
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
421
422
423
424
425
426


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Forward the arguments, in order, to ``torch.ops._C.marlin_gemm`` and
    return its result."""
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)
429
430


431
432
433
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Forward the arguments to ``torch.ops._C.gptq_marlin_24_gemm``.

    ``b_q_type`` is passed to the op as ``b_q_type.id`` rather than as the
    ``ScalarType`` object; every other argument is forwarded unchanged.
    """
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type.id, size_m,
                                            size_n, size_k)
439
440


441
# Fake (shape-only) implementations for the quantization ops. They are only
# registered when the compiled "gptq_marlin_24_gemm" op is present. Each one
# allocates an uninitialised tensor of the expected output shape.
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

    @register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: torch.SymInt,
                                  size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        # Output is (size_m, size_n) with the dtype/device of `a`.
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               c: Optional[torch.Tensor],
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               global_scale: Optional[torch.Tensor],
                               b_zeros: Optional[torch.Tensor],
                               g_idx: Optional[torch.Tensor],
                               perm: Optional[torch.Tensor],
                               workspace: torch.Tensor,
                               b_q_type_id: int,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool = True,
                               use_atomic_add: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        # Output is (size_m, size_n) with the dtype/device of `a`.
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        # Output dtype is fixed to float16 regardless of the dtype of `a`.
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        # Output dtype is fixed to float16 regardless of the dtype of `a`.
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: torch.SymInt,
                             thx: int, thy: int) -> torch.Tensor:
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        # The output has 8x as many columns as qweight.
        # NOTE(review): presumably 8 packed values per element -- confirm.
        out_c = qout_c * 8
        return torch.empty((in_c, out_c),
                           dtype=scales.dtype,
                           device=scales.device)

    @register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: torch.SymInt) -> torch.Tensor:
        num_in_feats = input.size(0)
        # Allocate one slice per split-k iteration, then reduce over that
        # leading dimension to get (num_in_feats, qweight.size(1) * 8).
        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                           dtype=input.dtype,
                           device=input.device).sum(0)

    @register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: list[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        out_features = codes.size(0) * codebooks.size(2)
        # Collapse all leading dimensions so the product is 2-D.
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)

        # Restore the leading dimensions of `input`; the last dimension is
        # inferred (-1) from the flattened output.
        output_sizes = list(input.shape)
        output_sizes.pop()
        output_sizes.append(-1)
        return flat_output.reshape(tuple(output_sizes))

    @register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: list[int]) -> torch.Tensor:
        in_features = codes.size(1) * 8
        out_features = codes.size(0)
        return torch.empty((out_features, in_features),
                           dtype=codebooks.dtype,
                           device=codebooks.device)

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        m = a.size(0)
        n = b_q.size(1)
        # NOTE(review): `out_type` is not consulted here; the fake always
        # uses the dtype of `a` -- confirm this matches the real op.
        return torch.empty((m, n), device=a.device, dtype=a.dtype)

    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(
            b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
            group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
        # Same shape/dtype/device as the input weight, contiguous layout.
        return torch.empty_like(b_q_weight,
                                memory_format=torch.contiguous_format)
561
562


563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
# Register the shape-only fake only when the compiled op is available.
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):

    @register_fake("_C::allspark_w8a16_gemm")
    def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
                                  b_scales: torch.Tensor,
                                  b_qzeros: Optional[torch.Tensor],
                                  n: torch.SymInt, group_size: torch.SymInt,
                                  sm_count: torch.SymInt,
                                  sm_version: torch.SymInt,
                                  CUBLAS_M_THRESHOLD: torch.SymInt,
                                  has_zp: bool,
                                  n32k16_reorder: bool) -> torch.Tensor:
        # Output is (rows of a, n) with the dtype/device of `a`.
        out_shape = (a.size(0), n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)


579
580
581
# Fake (shape-only) implementations for the GGML ops, registered only when
# the compiled "ggml_dequantize" op is present.
if hasattr(torch.ops._C, "ggml_dequantize"):

    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(
            W: torch.Tensor,
            quant_type: int,
            m: torch.SymInt,
            n: torch.SymInt,
            dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # NOTE(review): `dtype` is accepted but not used; the fake always
        # returns float16 -- confirm this matches the real op.
        return torch.empty((m, n), dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        # Output is (X.shape[0], row): dtype from X, device from W.
        return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)

    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        batch = X.size(0)
        # Output is (batch, row): dtype from X, device from W.
        return torch.empty((batch, row), dtype=X.dtype, device=W.device)

    @register_fake("_C::ggml_moe_a8")
    def _ggml_moe_a8_fake(
        X: torch.Tensor,
        W: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
        top_k: torch.SymInt,
        tokens: torch.SymInt,
    ) -> torch.Tensor:
        # The `tokens` argument is overwritten with the leading size of X.
        tokens = X.size(0)
        return torch.empty((tokens * top_k, row),
                           dtype=torch.float16,
                           device=W.device)

626

627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
# Register the shape-only fake only when the compiled op is available.
if hasattr(torch.ops._C, "ggml_moe_a8_vec"):

    @register_fake("_C::ggml_moe_a8_vec")
    def _ggml_moe_a8_vec_fake(
        X: torch.Tensor,
        W: torch.Tensor,
        topk_ids: torch.Tensor,
        top_k: int,
        quant_type: int,
        row: torch.SymInt,
        tokens: torch.SymInt,
    ) -> torch.Tensor:
        # The `tokens` argument is ignored; the leading size of X is used.
        num_tokens = X.size(0)
        out_shape = (num_tokens * top_k, row)
        return torch.empty(out_shape, dtype=X.dtype, device=W.device)


645
# cutlass
646
647
648
649
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
    """Return the result of ``torch.ops._C.cutlass_scaled_mm_supports_fp4``
    for the given device capability."""
    supported = torch.ops._C.cutlass_scaled_mm_supports_fp4(
        cuda_device_capability)
    return supported


650
651
652
653
654
655
656
657
658
659
660
661
662
663
def cutlass_blockwise_scaled_grouped_mm(
    output: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
):
    """Call ``torch.ops._C.cutlass_blockwise_scaled_grouped_mm``.

    All seven tensors are forwarded positionally and unchanged. The function
    has no ``return`` statement, so it yields ``None``.
    """
    op_args = (output, a, b, scales_a, scales_b, problem_sizes,
               expert_offsets)
    torch.ops._C.cutlass_blockwise_scaled_grouped_mm(*op_args)


664
665
666
667
668
669
670
671
672
673
674
675
def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                          block_scale_a: torch.Tensor,
                          block_scale_b: torch.Tensor, alpha: torch.Tensor,
                          out_dtype: torch.dtype) -> torch.Tensor:
    """Allocate the output and call ``torch.ops._C.cutlass_scaled_fp4_mm``.

    Both ``a`` and ``b`` must be 2-D (checked with ``assert``). The output
    has shape ``(a.shape[0], b.shape[0])``, dtype ``out_dtype`` and lives on
    the device of ``a``; it is passed to the op and then returned.
    """
    assert a.ndim == 2 and b.ndim == 2
    out_shape = (a.shape[0], b.shape[0])
    result = torch.empty(out_shape, dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_fp4_mm(result, a, b, block_scale_a,
                                       block_scale_b, alpha)
    return result


676
677
678
679
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return the result of ``torch.ops._C.cutlass_scaled_mm_supports_fp8``
    for the given device capability."""
    supported = torch.ops._C.cutlass_scaled_mm_supports_fp8(
        cuda_device_capability)
    return supported


680
681
682
683
684
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
    """Return the result of
    ``torch.ops._C.cutlass_scaled_mm_supports_block_fp8`` for the given
    device capability."""
    supported = torch.ops._C.cutlass_scaled_mm_supports_block_fp8(
        cuda_device_capability)
    return supported


685
686
687
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and
        corresponding extent in the target shape we repeat each element along
        that dimension target_shape[dim] // src_shape[dim] times consecutively"
    example if we have:
          a = [[1, 2], and target_shape = (2, 4)
               [3, 4]]
    then we would expand a to:
          a = [[1, 1, 2, 2],
               [3, 3, 4, 4]]
    currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape

    ``out_dtype`` must be ``torch.bfloat16`` or ``torch.float16``. If
    ``bias`` is given, its first dimension must equal ``b.shape[1]`` and its
    dtype must equal ``out_dtype``. Both conditions are checked with
    ``assert``.

    On ROCm, or when either dimension of ``b`` is not a multiple of 16, the
    call is delegated to the Triton ``triton_scaled_mm`` implementation;
    otherwise an ``(a.shape[0], b.shape[1])`` output is allocated and filled
    by ``torch.ops._C.cutlass_scaled_mm``.
    """
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]

    # The CUTLASS path is only taken when both dimensions of b are multiples
    # of 16.
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    if current_platform.is_rocm() or not cutlass_compatible_b:
        # Imported lazily so the Triton module is only loaded when needed.
        from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
            triton_scaled_mm)
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


733
734
735
736
737
738
739
740
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Allocate an `(a.shape[0], b.shape[1])` output and fill it with the
    `_C.cutlass_scaled_mm_azp` custom op.

    :param a: Left operand; `a.shape[0]` is the number of output rows.
    :param b: Right operand; both of its dimensions must be multiples of 16.
    :param scale_a: Scale tensor forwarded to the op for `a`.
    :param scale_b: Scale tensor forwarded to the op for `b`.
    :param out_dtype: Output dtype; must be `torch.bfloat16` or
    `torch.float16`.
    :param azp_adj: In the per-tensor case, this should include the azp.
    Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set. When given
    it must hold `a.shape[0]` elements.
    :param bias: Optional bias with `b.shape[1]` elements and dtype
    `out_dtype`.
    :return: The newly allocated output tensor, on `a`'s device.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    # `and` binds tighter than `or`: the shape/dtype checks only apply when a
    # bias is actually supplied.
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)
    assert azp is None or azp.numel() == a.shape[0]

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out


761
762
763
764
765
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
    """Return what the `_C` op of the same name reports for this capability."""
    support_query = torch.ops._C.cutlass_sparse_scaled_mm_supported
    return support_query(cuda_device_capability)


766
767
768
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
    """Return what the `_C` op of the same name reports for this capability."""
    support_query = torch.ops._C.cutlass_group_gemm_supported
    return support_query(cuda_device_capability)

769
def cutlass_sparse_compress(a: torch.Tensor) \
770
    -> tuple[torch.Tensor, torch.Tensor]:
771
772
773
774
775
776
777
778
    """
    Compresses a sparse matrix for use with Cutlass sparse operations.

    This function takes a dense tensor and compresses it into two components:
    non-zero elements and metadata. The compressed representation is compatible
    with Cutlass sparse kernels.

    Args:
779
        a (torch.Tensor):
780
781
782
783
784
785
786
            The input tensor to be compressed. Must have one of the following data types:
            - `torch.int8`
            - `torch.float8_e4m3fn`
            - `torch.bfloat16`
            - `torch.float16`

    Returns:
787
        tuple[torch.Tensor, torch.Tensor]:
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
            A tuple containing:
            - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
            - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.

    Raises:
        ValueError: If the compression operation fails.

    Notes:
        - The `a_meta` tensor has a data type of `torch.uint8`.
        - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
        - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
        - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
    """
    assert (a.dtype in [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ])
    assert (a.is_contiguous())

    # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
    elemsPerMetaElem = 4
808
    assert (a.shape[1] % (2 * elemsPerMetaElem) == 0)
809

810
    return torch.ops._C.cutlass_sparse_compress(a)
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858


def cutlass_scaled_sparse_mm(
        a: torch.Tensor,
        bt_nzs: torch.Tensor,
        bt_meta: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        out_dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Scaled sparse matrix multiplication backed by Cutlass.

    Typical usage:
    1. Build a dense `a` of shape (m, k) on the CUDA device, e.g.
    `a = torch.randn((m, k), device='cuda')`.

    2. Build a dense `b` of shape (k, n) on the CUDA device, e.g.
    `b = torch.randn((k, n), device='cuda')`.

    3. Prune `b` to 2:4 sparsity along the chosen dimension:
    `b = prune_to_2_4(b, dim=0)`.

    4. Compress the transposed sparse matrix:
    `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.

    5. Multiply using the compressed operand, the scales for `a` and `b`,
    and the requested output dtype:
    `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.

    Preconditions (asserted): both dimensions of `bt_nzs` are multiples of
    16, `out_dtype` is bfloat16 or float16, and `bias`, when given, has
    `bt_nzs.shape[0]` entries along its first dimension and dtype `out_dtype`.

    Returns:
    - The result of the scaled sparse matrix multiplication, a newly
      allocated `(a.shape[0], bt_nzs.shape[0])` tensor on `a`'s device.
    """
    n_rows_bt, n_cols_bt = bt_nzs.shape[0], bt_nzs.shape[1]
    assert n_rows_bt % 16 == 0 and n_cols_bt % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    if bias is not None:
        assert bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype

    result = torch.empty((a.shape[0], bt_nzs.shape[0]),
                         dtype=out_dtype,
                         device=a.device)
    torch.ops._C.cutlass_scaled_sparse_mm(result, a, bt_nzs, bt_meta, scale_a,
                                          scale_b, bias)
    return result


859
860
861
862
863
864
865
866
867
868
def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
                            expert_offsets: torch.Tensor,
                            problem_sizes1: torch.Tensor,
                            problem_sizes2: torch.Tensor,
                            input_permutation: torch.Tensor,
                            output_permutation: torch.Tensor,
                            num_experts: int,
                            n: int,
                            k: int,
                            blockscale_offsets: Optional[torch.Tensor] = None):
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                          after executing the MMs.
    - blockscale_offsets: Optional argument passed for fp4 moe. Indices that
                          mark at which block scale index each expert begins
                          its computation. The number of block scale rows
                          computed with expert E is blockscale_offsets[E + 1] -
                          blockscale_offsets[E]

    All arguments are forwarded unchanged to `_C.get_cutlass_moe_mm_data`,
    and whatever that op returns is returned to the caller.
    """
    return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
                                                problem_sizes1, problem_sizes2,
                                                input_permutation,
                                                output_permutation,
                                                num_experts, n, k,
                                                blockscale_offsets)


def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
    """
    Build a new tensor whose rows are taken from `input_tensor` as directed
    by `dst2src_map`.

    The result has one row per entry of `dst2src_map` and the same column
    count, dtype and device as `input_tensor`; it is filled by the
    `_moe_C.shuffle_rows` op. MoE uses this to permute (and expand) the input
    ahead of the grouped matrix multiplications.
    """
    out_shape = (dst2src_map.shape[0], input_tensor.shape[1])
    shuffled = torch.empty(out_shape,
                           dtype=input_tensor.dtype,
                           device=input_tensor.device)
    torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, shuffled)
    return shuffled
911
912


913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor,
                                 problem_sizes1: torch.Tensor,
                                 problem_sizes2: torch.Tensor,
                                 expert_num_tokens: torch.Tensor,
                                 num_local_experts: int, padded_m: int, n: int,
                                 k: int):
    """
    Set up the inputs for the CUTLASS grouped matrix multiplications of the
    CUTLASS-based fused MoE.

    Starting from expert_num_tokens (token count per expert), the underlying
    op computes:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation.
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.

    Every argument is passed through to `_C.get_cutlass_pplx_moe_mm_data`
    and that op's return value is handed back unchanged.
    """
    prepare_op = torch.ops._C.get_cutlass_pplx_moe_mm_data
    return prepare_op(expert_offsets, problem_sizes1, problem_sizes2,
                      expert_num_tokens, num_local_experts, padded_m, n, k)


937
938
939
940
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                   b_tensors: torch.Tensor, a_scales: torch.Tensor,
                   b_scales: torch.Tensor, expert_offsets: torch.Tensor,
                   problem_sizes: torch.Tensor, a_strides: torch.Tensor,
                   b_strides: torch.Tensor, c_strides: torch.Tensor,
                   per_act_token: bool, per_out_ch: bool):
    """
    A single grouped matrix multiplication used in CUTLASS-based fused MoE.
    The function executes fp8-quantized OUT = AB matrix multiplication.

    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
                     MMs used in the fused MoE operation.
    - a/b/c_strides: The data strides passed to grouped matrix multiplication.

    All arguments (including `per_act_token` and `per_out_ch`) are forwarded
    unchanged to `_C.cutlass_moe_mm`; its return value is returned as-is.
    """
    return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
                                       a_scales, b_scales, expert_offsets,
                                       problem_sizes, a_strides, b_strides,
                                       c_strides, per_act_token, per_out_ch)
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988


def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
                       a_scales: torch.Tensor, b_scales: torch.Tensor,
                       alphas: torch.Tensor, problem_sizes: torch.Tensor,
                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
                       out_dtype: torch.dtype, device: torch.device):
    """
    FP4 block-scaled grouped GEMM: runs one GEMM per entry of
    `problem_sizes` over `a_tensors` and `b_tensors`.

    Used as the MoE GEMM in the NVFP4-quantized FusedMoE forward pass.
    - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors, i.e. the quantized
                     input and expert weights.
    - a_/b_scales: The block scales in FP8-E4M3 precision.
    - expert_offsets/sf_offsets: Indices that mark at which token index
                    each expert begins its computation. Expert E processes
                    expert_offsets[E + 1] - expert_offsets[E] tokens, and its
                    sf_size is sf_offset[E+1] - sf_offset[E].
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
                     MMs used in the fused MoE operation.

    Returns a new `(a_tensors.shape[0], b_tensors.shape[1])` tensor of dtype
    `out_dtype`, allocated on `device`.
    """
    out_shape = (a_tensors.shape[0], b_tensors.shape[1])
    result = torch.empty(out_shape, dtype=out_dtype, device=device)
    torch.ops._C.cutlass_fp4_group_mm(result, a_tensors, b_tensors, a_scales,
                                      b_scales, alphas, problem_sizes,
                                      expert_offsets, sf_offsets)
    return result.to(out_dtype)
989
990


991
992
993
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: list[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Forward all arguments to the `_C.aqlm_gemm` op and return its result."""
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)
998
999
1000


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: list[int]) -> torch.Tensor:
    """Forward all arguments to the `_C.aqlm_dequant` op and return its result."""
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)
1004
1005


1006
1007
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Forward all arguments to the `_C.gptq_marlin_repack` op and return its result."""
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)
1012
1013


1014
1015
1016
1017
1018
1019
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Forward all arguments to the `_C.awq_marlin_repack` op and return its result."""
    repack_op = torch.ops._C.awq_marlin_repack
    return repack_op(b_q_weight, size_k, size_n, num_bits)


1020
1021
1022
1023
1024
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """
    Repack every expert's weight with the `_C.gptq_marlin_repack` op.

    The first dimension of `b_q_weight` and `perm` is indexed per expert. The
    result is a new tensor of shape
    `(num_experts, size_k // 16, size_n * (num_bits // 2))` with the dtype and
    device of `b_q_weight`; slot `e` holds the repacked weight of expert `e`.

    Raises:
        AssertionError: If `size_k` is not a multiple of 16.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    # The repack op handles a single expert at a time, so loop over experts.
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
                                                    size_k, size_n, num_bits)
    return output


1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """
    Repack every expert's weight with the `_C.awq_marlin_repack` op.

    Returns a new tensor shaped
    `(num_experts, size_k // 16, size_n * (num_bits // 2))` that shares the
    dtype and device of `b_q_weight`. `perm` is accepted but not used here.
    `size_k` must be a multiple of 16 (asserted).
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0

    packed_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    repacked = torch.empty(packed_shape,
                           dtype=b_q_weight.dtype,
                           device=b_q_weight.device)
    # One op call per expert; each result lands in that expert's slot.
    for expert_idx in range(expert_count):
        repacked[expert_idx] = torch.ops._C.awq_marlin_repack(
            b_q_weight[expert_idx], size_k, size_n, num_bits)
    return repacked


1048
def gptq_marlin_gemm(a: torch.Tensor,
                     c: Optional[torch.Tensor],
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     global_scale: Optional[torch.Tensor],
                     b_zeros: Optional[torch.Tensor],
                     g_idx: Optional[torch.Tensor],
                     perm: Optional[torch.Tensor],
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool = True,
                     use_atomic_add: bool = False,
                     use_fp32_reduce: bool = False,
                     is_zp_float: bool = False) -> torch.Tensor:
    """
    Call the `_C.gptq_marlin_gemm` op and return its result.

    Every argument is forwarded positionally in signature order, except
    `b_q_type`, which is passed to the op as `b_q_type.id`.
    """
    return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
                                         global_scale, b_zeros, g_idx, perm,
                                         workspace, b_q_type.id, size_m,
                                         size_n, size_k, is_k_full,
                                         use_atomic_add, use_fp32_reduce,
                                         is_zp_float)
1071
1072


1073
# machete
1074
1075
1076
1077
1078
1079
1080
def machete_supported_schedules(
        a_type: torch.dtype,
        b_type: ScalarType,
        group_scales_type: Optional[torch.dtype],
        group_zeros_type: Optional[torch.dtype] = None,
        channel_scales_type: Optional[torch.dtype] = None,
        token_scales_type: Optional[torch.dtype] = None,
        out_type: Optional[torch.dtype] = None) -> list[str]:
    """
    Query the `_C.machete_supported_schedules` op for the given type
    combination and return its result.

    `b_type` is passed to the op as `b_type.id`; all other arguments are
    forwarded unchanged.
    """
    return torch.ops._C.machete_supported_schedules(
        a_type, b_type.id, group_scales_type, group_zeros_type,
        channel_scales_type, token_scales_type, out_type)


def machete_mm(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None) -> torch.Tensor:
    """
    Call the `_C.machete_mm` op and return its result.

    `b_type` reaches the op as `b_type.id`; the remaining arguments are
    forwarded as given.
    """
    mm_op = torch.ops._C.machete_mm
    return mm_op(a, b_q, b_type.id, out_type, b_group_scales, b_group_zeros,
                 b_group_size, b_channel_scales, a_token_scales, schedule)


def machete_prepack_B(
        b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
        group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
    """
    Call the `_C.machete_prepack_B` op and return its result.

    `b_type` reaches the op as `b_type.id`; the other arguments are forwarded
    as given.
    """
    prepack_op = torch.ops._C.machete_prepack_B
    return prepack_op(b_q_weight, a_type, b_type.id, group_scales_type)
1109
1110


1111
# Register the fake implementation only when the compiled op is present.
if hasattr(torch.ops._C, "permute_cols"):

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        # Fake implementation: returns an uninitialized tensor with the same
        # shape, dtype and device as `a`; `perm` is not consulted.
        return torch.empty_like(a)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Forward `a` and `perm` to the `_C.permute_cols` op and return its result."""
    permute_op = torch.ops._C.permute_cols
    return permute_op(a, perm)


1123
1124
1125
# fp4
def scaled_fp4_quant(
        input: torch.Tensor,
        input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4. Must have at least one
            dimension, dtype fp16 or bf16, and a last dimension that is a
            multiple of 16. All leading dimensions are flattened into one.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
            two values are packed into a uint8 and float8_e4m3 scaling factors
            in the swizzled layout.

    Raises:
        AssertionError: On ROCm, or if `input` violates the constraints above.
    """
    assert not current_platform.is_rocm()
    assert input.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {input.ndim}.')
    # Collapse to 2-D: a 1-D input becomes a single row.
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, (
        f'last dim has to be multiple of 16, but got {n}.')
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')

    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    # We use the rounded values to store the swizzled values. Due to the
    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
    # So, we first pad the scales to multiples of 128 and 4. Then, the scales
    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
    rounded_m = (m + 128 - 1) // 128 * 128
    scale_n = n // block_size
    rounded_n = (scale_n + 4 - 1) // 4 * 4
    output_scale = torch.empty((rounded_m, rounded_n // 4),
                               device=device,
                               dtype=torch.int32)

    torch.ops._C.scaled_fp4_quant(output, input, output_scale,
                                  input_global_scale)
    # Reinterpret the packed int32 storage as individual fp8 scale values.
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale


1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    packed MoE Inputs.
    Args:
        input_tensor: The input tensor to be quantized to FP4. Must be 2-D.
        input_global_scale: A scalar scaling factor for the entire tensor.
        expert_offsets: The expert offsets tensor
        blockscale_offsets: The blockscale offsets tensor
        topk: Multiplier applied to VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to
            bound the number of input rows and to size the scale buffer.
    Outputs:
        output: The quantized tensor in FP4
        output_scales: The blockscale tensor in FP8-E4M3
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k = input_tensor.shape

    assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
        f"m_numtopk must be less than or equal to MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT}) * topk({topk})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")
    # One scale per 16 input elements; four one-byte scales share an int32,
    # so round the per-row scale count up to whole int32 words.
    scales_k = k // 16
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(m_numtopk,
                         k // 2,
                         device=input_tensor.device,
                         dtype=torch.uint8)
    output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk,
                                padded_k,
                                dtype=torch.int32,
                                device=input_tensor.device)
    torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor,
                                          input_global_scale, expert_offsets,
                                          blockscale_offsets)
    # Reinterpret the packed int32 storage as individual fp8 scale values.
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales


1235
# fp8
1236
1237
1238
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D.
        scale: Optional scaling factor for the FP8 quantization. When given
            it must contain exactly one element.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
        output: Optional preallocated output tensor of the platform's FP8
            dtype. Cannot be combined with `num_token_padding`.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[tuple[int, int], torch.Size] = input.shape
    # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
    out_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)
    else:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None:
        # Dynamic quantization: the kernel writes the computed scale(s) into
        # the tensor allocated here.
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input.contiguous(), scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static quantization with a caller-supplied single-element scale.
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
1296
1297


1298
1299
1300
1301
1302
1303
# gptq allspark
def allspark_repack_weight(
        qweight: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor] = None,
        has_zp: bool = False
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
    for Ampere W8A16 Fused Gemm kernel

    Args:
        qweight: uint8 weight tensor, original k x n format.
        scale: fp16/bf16 weight scale tensor, 1 x n format.
        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
            Must be provided for asymmetric quantization.
        has_zp: if use symmetric quantization, has_zp = False.
            if use asymmetric quantization, has_zp = True.

    Returns:
        tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
            rearranged weight, scale, and optionally zero_point. The third
            element is None when `has_zp` is False.

    Raises:
        AssertionError: If `has_zp` is True but `zero_point` is None.
    """
    K = qweight.shape[0]
    N = qweight.shape[1]
    # Round N up to the next multiple of 32 for the reordered buffers.
    N_32align = (N + 32 - 1) // 32 * 32

    qweight_reorder = torch.empty((N_32align, K),
                                  device=qweight.device,
                                  dtype=qweight.dtype)
    scale_reorder = torch.empty((1, N_32align),
                                device=scale.device,
                                dtype=scale.dtype)
    zero_point_reorder = None
    if has_zp:
        assert zero_point is not None, (
            "zero_point must be provided for asymmetric quantization.")
        zero_point_reorder = torch.empty((1, N_32align),
                                         device=zero_point.device,
                                         dtype=zero_point.dtype)

    torch.ops._C.rearrange_kn_weight_as_n32k16_order(
        qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder,
        zero_point_reorder, K, N, N_32align)

    return qweight_reorder, scale_reorder, zero_point_reorder


def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor,
                        b_scales: torch.Tensor,
                        b_qzeros: Optional[torch.Tensor], n: int,
                        group_size: int, sm_count: int, sm_version: int,
                        CUBLAS_M_THRESHOLD: int, has_zp: bool,
                        n32k16_reorder: bool) -> torch.Tensor:
    """Forward all arguments to the `_C.allspark_w8a16_gemm` op and return its result."""
    gemm_op = torch.ops._C.allspark_w8a16_gemm
    return gemm_op(a, b_qweight, b_scales, b_qzeros, n, group_size, sm_count,
                   sm_version, CUBLAS_M_THRESHOLD, has_zp, n32k16_reorder)


1359
# int8
1360
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
      tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.

    Raises:
        AssertionError: In the static case (`scale` given), if `azp` is
            supplied with `symmetric=True` or omitted with `symmetric=False`.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (
            azp
            is None), "azp must only be provided for asymmetric quantization."
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization.
    # One scale per row once all leading dimensions are flattened.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    input_azp = None if symmetric else torch.empty_like(input_scales,
                                                        dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(),
                                           input_scales, input_azp)
    return output, input_scales, input_azp
1398
1399


1400
1401
1402
1403
1404
1405
1406
1407
1408
# qqq ops
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Thin wrapper over the `_C.marlin_qqq_gemm` custom op.

    All arguments are forwarded positionally, in declaration order, and the
    op's result is returned unchanged.
    """
    gemm_op = torch.ops._C.marlin_qqq_gemm
    return gemm_op(a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                   size_n, size_k)


1409
# gguf
1410
1411
1412
def ggml_dequantize(
    W: torch.Tensor,
    quant_type: int,
    m: int,
    n: int,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    """Thin wrapper over the `_C.ggml_dequantize` custom op.

    Forwards every argument unchanged and returns the op's result.
    """
    dequant_op = torch.ops._C.ggml_dequantize
    return dequant_op(W, quant_type, m, n, dtype)
1413
1414
1415
1416
1417
1418
1419


def ggml_mul_mat_vec_a8(W: torch.Tensor, X: torch.Tensor, quant_type: int,
                        row: int) -> torch.Tensor:
    """Thin wrapper over the `_C.ggml_mul_mat_vec_a8` custom op.

    Forwards `(W, X, quant_type, row)` and returns the op's result.
    """
    mat_vec_op = torch.ops._C.ggml_mul_mat_vec_a8
    return mat_vec_op(W, X, quant_type, row)


def ggml_mul_mat_a8(W: torch.Tensor, X: torch.Tensor, quant_type: int,
                    row: int) -> torch.Tensor:
    """Thin wrapper over the `_C.ggml_mul_mat_a8` custom op.

    Forwards `(W, X, quant_type, row)` and returns the op's result.
    """
    mat_mul_op = torch.ops._C.ggml_mul_mat_a8
    return mat_mul_op(W, X, quant_type, row)


1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
def ggml_moe_a8(X: torch.Tensor, W: torch.Tensor,
                sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor,
                num_tokens_post_padded: torch.Tensor, quant_type: int,
                row: int, top_k: int, tokens: int) -> torch.Tensor:
    """Thin wrapper over the `_C.ggml_moe_a8` custom op.

    All arguments are forwarded positionally, in declaration order, and the
    op's result is returned unchanged.
    """
    moe_op = torch.ops._C.ggml_moe_a8
    return moe_op(X, W, sorted_token_ids, expert_ids, num_tokens_post_padded,
                  quant_type, row, top_k, tokens)


1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
def ggml_moe_a8_vec(X: torch.Tensor, W: torch.Tensor, topk_ids: torch.Tensor,
                    top_k: int, quant_type: int, row: torch.SymInt,
                    tokens: torch.SymInt) -> torch.Tensor:
    """Thin wrapper over the `_C.ggml_moe_a8_vec` custom op.

    All arguments are forwarded positionally, in declaration order, and the
    op's result is returned unchanged.
    """
    moe_vec_op = torch.ops._C.ggml_moe_a8_vec
    return moe_vec_op(X, W, topk_ids, top_k, quant_type, row, tokens)


1462
1463
1464
1465
def ggml_moe_get_block_size(quant_type: int) -> int:
    """Return the result of the `_C.ggml_moe_get_block_size` op for
    `quant_type`."""
    block_size_op = torch.ops._C.ggml_moe_get_block_size
    return block_size_op(quant_type)


1466
1467
1468
# mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    """Thin wrapper over the `_C.causal_conv1d_fwd` custom op.

    All arguments are forwarded positionally, in declaration order. The op's
    return value is discarded, so this wrapper returns None.
    """
    # NOTE(review): presumably the kernel writes its results into the tensors
    # it is handed -- confirm against the kernel implementation.
    conv_fwd_op = torch.ops._C.causal_conv1d_fwd
    conv_fwd_op(x, weight, bias_, conv_states, query_start_loc, cache_indices,
                has_initial_state, silu_activation, pad_slot_id)


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Thin wrapper over the `_C.causal_conv1d_update` custom op.

    All arguments are forwarded positionally, in declaration order. The op's
    return value is discarded, so this wrapper returns None.
    """
    conv_update_op = torch.ops._C.causal_conv1d_update
    conv_update_op(x, conv_state, weight, bias_, silu_activation,
                   cache_seqlens, conv_state_indices, pad_slot_id)


def selective_scan_fwd(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D_: Optional[torch.Tensor],
    z_: Optional[torch.Tensor],
    delta_bias_: Optional[torch.Tensor],
    delta_softplus: bool,
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    ssm_states: torch.Tensor,
    pad_slot_id: int,
):
    """Thin wrapper over the `_C.selective_scan_fwd` custom op.

    All arguments are forwarded positionally, in declaration order. The op's
    return value is discarded, so this wrapper returns None.
    """
    scan_op = torch.ops._C.selective_scan_fwd
    scan_op(u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus,
            query_start_loc, cache_indices, has_initial_state, ssm_states,
            pad_slot_id)
1504
1505


1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
# ROCm skinny gemms
def LLMM1(
    a: torch.Tensor,
    b: torch.Tensor,
    rows_per_block: int,
) -> torch.Tensor:
    """Thin wrapper over the `_rocm_C.LLMM1` custom op; returns its result."""
    llmm1_op = torch.ops._rocm_C.LLMM1
    return llmm1_op(a, b, rows_per_block)


def wvSplitK(
    a: torch.Tensor,
    b: torch.Tensor,
    cu_count: int,
) -> torch.Tensor:
    """Thin wrapper over the `_rocm_C.wvSplitK` custom op; returns its
    result."""
    split_k_op = torch.ops._rocm_C.wvSplitK
    return split_k_op(a, b, cu_count)


def wvSplitKQ(
    a: torch.Tensor,
    b: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    cu_count: int,
) -> torch.Tensor:
    """Run the `_rocm_C.wvSplitKQ` custom op into a freshly allocated tensor.

    The output buffer has shape `(b.shape[0], a.shape[0])`, dtype `out_dtype`
    and lives on `b`'s device. It is handed to the op as the third argument
    and then returned.
    """
    result_shape = (b.shape[0], a.shape[0])
    result = torch.empty(result_shape, dtype=out_dtype, device=b.device)
    torch.ops._rocm_C.wvSplitKQ(a, b, result, scale_a, scale_b, cu_count)
    return result


1526
# moe
1527
1528
1529
1530
def moe_sum(
    input: torch.Tensor,
    output: torch.Tensor,
):
    """Thin wrapper over the `_moe_C.moe_sum` custom op.

    The op's return value is discarded, so this wrapper returns None.
    """
    sum_op = torch.ops._moe_C.moe_sum
    sum_op(input, output)


1531
1532
1533
1534
def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Thin wrapper over the `_moe_C.moe_align_block_size` custom op.

    All arguments are forwarded positionally, in declaration order; nothing
    is returned.
    """
    align_op = torch.ops._moe_C.moe_align_block_size
    align_op(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids,
             num_tokens_post_pad)
1538
1539


1540
1541
1542
1543
1544
1545
1546
1547
def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
                   b_qweight: torch.Tensor, b_scales: torch.Tensor,
                   b_qzeros: Optional[torch.Tensor],
                   topk_weights: Optional[torch.Tensor],
                   sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor,
                   num_tokens_post_pad: torch.Tensor, top_k: int,
                   BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
                   bit: int) -> torch.Tensor:
    """Run the `_moe_C.moe_wna16_gemm` custom op and return `output`.

    All arguments are forwarded positionally, in declaration order.

    Args:
        output: Tensor handed to the kernel and returned from this wrapper.
            NOTE(review): presumably the kernel writes its result into this
            tensor -- confirm against the kernel implementation.

    Returns:
        The `output` tensor that was passed in, so the declared
        `torch.Tensor` return type is honoured.

    Raises:
        NotImplementedError: If the current platform is not CUDA.
    """
    if not current_platform.is_cuda():
        raise NotImplementedError(
            "The optimized moe_wna16_gemm kernel is only "
            "available on CUDA platforms")
    torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales,
                                    b_qzeros, topk_weights, sorted_token_ids,
                                    experts_ids, num_tokens_post_pad, top_k,
                                    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
                                    bit)
    # Previously the function fell off the end and returned None despite the
    # `-> torch.Tensor` annotation.
    return output


1559
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Thin wrapper over the `_moe_C.topk_softmax` custom op.

    All arguments are forwarded positionally, in declaration order; nothing
    is returned.
    """
    softmax_op = torch.ops._moe_C.topk_softmax
    softmax_op(topk_weights, topk_ids, token_expert_indices, gating_output)
1564
1565


1566
1567
def moe_wna16_marlin_gemm(
    input: torch.Tensor,
    output: Optional[torch.Tensor],
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    global_scale: Optional[torch.Tensor],
    b_qzeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
) -> torch.Tensor:
    """Thin wrapper over the `_moe_C.moe_wna16_marlin_gemm` custom op.

    Arguments are forwarded positionally, in declaration order, except that
    `b_q_type` is replaced by its `.id` attribute. The op's result is
    returned unchanged.
    """
    marlin_op = torch.ops._moe_C.moe_wna16_marlin_gemm
    # The op receives the scalar type's id rather than the ScalarType object.
    q_type_id = b_q_type.id
    return marlin_op(input, output, b_qweight, b_scales, global_scale,
                     b_qzeros, g_idx, perm, workspace, sorted_token_ids,
                     expert_ids, num_tokens_past_padded, topk_weights,
                     moe_block_size, top_k, mul_topk_weights, is_ep,
                     q_type_id, size_m, size_n, size_k, is_k_full,
                     use_atomic_add, use_fp32_reduce, is_zp_float)
1588
1589


1590
1591
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                             sorted_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                             perm: torch.Tensor, workspace: torch.Tensor,
                             b_q_type: ScalarType, size_m: torch.SymInt,
                             size_n: torch.SymInt, size_k: torch.SymInt,
                             is_k_full: bool, num_experts: int, topk: int,
                             moe_block_size: int, replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
        """Fake implementation registered for `_moe_C::marlin_gemm_moe`.

        Performs no computation: returns an uninitialized tensor of shape
        `(size_m, topk, size_n)` with the dtype and device of `a`.
        """
        return torch.empty((size_m, topk, size_n),
                           dtype=a.dtype,
                           device=a.device)

    @register_fake("_moe_C::moe_wna16_marlin_gemm")
    def moe_wna16_marlin_gemm_fake(input: torch.Tensor,
                                   output: Optional[torch.Tensor],
                                   b_qweight: torch.Tensor,
                                   b_scales: torch.Tensor,
                                   global_scale: Optional[torch.Tensor],
                                   b_qzeros: Optional[torch.Tensor],
                                   g_idx: Optional[torch.Tensor],
                                   perm: Optional[torch.Tensor],
                                   workspace: torch.Tensor,
                                   sorted_token_ids: torch.Tensor,
                                   expert_ids: torch.Tensor,
                                   num_tokens_past_padded: torch.Tensor,
                                   topk_weights: torch.Tensor,
                                   moe_block_size: int, top_k: int,
                                   mul_topk_weights: bool, is_ep: bool,
                                   b_q_type: ScalarType, size_m: int,
                                   size_n: int, size_k: int, is_k_full: bool,
                                   use_atomic_add: bool, use_fp32_reduce: bool,
                                   is_zp_float: bool) -> torch.Tensor:
        """Fake implementation registered for `_moe_C::moe_wna16_marlin_gemm`.

        The parameter list mirrors the positional argument order used by the
        `moe_wna16_marlin_gemm` wrapper in this module, including
        `global_scale` between `b_scales` and `b_qzeros`. Without that
        parameter the wrapper's 25 positional arguments cannot bind to this
        function.

        Performs no computation: returns an uninitialized tensor of shape
        `(size_m * top_k, size_n)` with the dtype and device of `input`.
        """
        # NOTE(review): the wrapper passes `b_q_type.id` for `b_q_type`, so
        # the ScalarType annotation here may not match what is received.
        return torch.empty((size_m * top_k, size_n),
                           dtype=input.dtype,
                           device=input.device)

1631

1632
1633
1634
1635
1636
1637
1638
def reshape_and_cache(key: torch.Tensor, value: torch.Tensor,
                      key_cache: torch.Tensor, value_cache: torch.Tensor,
                      slot_mapping: torch.Tensor, kv_cache_dtype: str,
                      k_scale: torch.Tensor, v_scale: torch.Tensor) -> None:
    """Thin wrapper over the `_C_cache_ops.reshape_and_cache` custom op.

    All arguments are forwarded positionally, in declaration order; nothing
    is returned.
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache
    cache_op(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale)
1645
1646


1647
1648
1649
1650
1651
1652
1653
def reshape_and_cache_flash(key: torch.Tensor, value: torch.Tensor,
                            key_cache: torch.Tensor,
                            value_cache: torch.Tensor,
                            slot_mapping: torch.Tensor, kv_cache_dtype: str,
                            k_scale: torch.Tensor,
                            v_scale: torch.Tensor) -> None:
    """Thin wrapper over the `_C_cache_ops.reshape_and_cache_flash` custom op.

    All arguments are forwarded positionally, in declaration order; nothing
    is returned.
    """
    cache_flash_op = torch.ops._C_cache_ops.reshape_and_cache_flash
    cache_flash_op(key, value, key_cache, value_cache, slot_mapping,
                   kv_cache_dtype, k_scale, v_scale)
1661
1662


1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
def concat_and_cache_mla(kv_c: torch.Tensor, k_pe: torch.Tensor,
                         kv_cache: torch.Tensor, slot_mapping: torch.Tensor,
                         kv_cache_dtype: str, scale: torch.Tensor) -> None:
    """Thin wrapper over the `_C_cache_ops.concat_and_cache_mla` custom op.

    All arguments are forwarded positionally, in declaration order; nothing
    is returned.
    """
    concat_op = torch.ops._C_cache_ops.concat_and_cache_mla
    concat_op(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)


1676
1677
def copy_blocks(
    key_caches: list[torch.Tensor],
    value_caches: list[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    """Thin wrapper over the `_C_cache_ops.copy_blocks` custom op.

    Forwards `(key_caches, value_caches, block_mapping)`; nothing is
    returned.
    """
    copy_op = torch.ops._C_cache_ops.copy_blocks
    copy_op(key_caches, value_caches, block_mapping)
1680
1681


1682
def copy_blocks_mla(
    kv_caches: list[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    """Thin wrapper over the `_C_cache_ops.copy_blocks_mla` custom op.

    Forwards `(kv_caches, block_mapping)`; nothing is returned.
    """
    copy_mla_op = torch.ops._C_cache_ops.copy_blocks_mla
    copy_mla_op(kv_caches, block_mapping)


1687
def swap_blocks(
    src: torch.Tensor,
    dst: torch.Tensor,
    block_mapping: torch.Tensor,
) -> None:
    """Thin wrapper over the `_C_cache_ops.swap_blocks` custom op.

    Forwards `(src, dst, block_mapping)`; nothing is returned.
    """
    swap_op = torch.ops._C_cache_ops.swap_blocks
    swap_op(src, dst, block_mapping)
1690
1691


1692
1693
1694
1695
def convert_fp8(output: torch.Tensor, input: torch.Tensor,
                scale: float = 1.0, kv_dtype: str = "fp8") -> None:
    """Thin wrapper over the `_C_cache_ops.convert_fp8` custom op.

    Forwards `(output, input, scale, kv_dtype)`; nothing is returned. `scale`
    defaults to 1.0 and `kv_dtype` to "fp8".
    """
    convert_op = torch.ops._C_cache_ops.convert_fp8
    convert_op(output, input, scale, kv_dtype)


1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
def gather_cache(
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
    batch_size: int,
    seq_starts: Optional[torch.Tensor] = None,
) -> None:
    """Thin wrapper over the `_C_cache_ops.gather_cache` custom op.

    All arguments are forwarded positionally, in declaration order; nothing
    is returned. `seq_starts` is optional and defaults to None.
    """
    gather_op = torch.ops._C_cache_ops.gather_cache
    gather_op(src_cache, dst, block_table, cu_seq_lens, batch_size,
              seq_starts)


1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
def get_device_attribute(
    attribute: int,
    device: int,
) -> int:
    """Return the result of the `_C_cuda_utils.get_device_attribute` op for
    the given `(attribute, device)` pair."""
    attribute_op = torch.ops._C_cuda_utils.get_device_attribute
    return attribute_op(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the result of the `_C_cuda_utils` op of the same name for
    `device`."""
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
1720
def init_custom_ar(
    ipc_tensors: list[torch.Tensor],
    rank_data: torch.Tensor,
    rank: int,
    fully_connected: bool,
) -> int:
    """Thin wrapper over the `_C_custom_ar.init_custom_ar` custom op.

    Forwards all arguments in declaration order and returns the op's result.
    """
    init_op = torch.ops._C_custom_ar.init_custom_ar
    return init_op(ipc_tensors, rank_data, rank, fully_connected)
1724
1725


1726
1727
1728
1729
def all_reduce(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    reg_buffer: int,
    reg_buffer_sz_bytes: int,
) -> None:
    """Thin wrapper over the `_C_custom_ar.all_reduce` custom op.

    Forwards all arguments in declaration order; nothing is returned.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce
    reduce_op(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
1730

1731
1732
1733
1734
1735
1736
1737
1738
1739

def dispose(fa: int) -> None:
    """Call the `_C_custom_ar.dispose` custom op with `fa`; nothing is
    returned."""
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)


def meta_size() -> int:
    """Return the result of the `_C_custom_ar.meta_size` custom op."""
    meta_size_op = torch.ops._C_custom_ar.meta_size
    return meta_size_op()


1740
def register_buffer(
    fa: int,
    ipc_tensors: list[int],
) -> None:
    """Thin wrapper over the `_C_custom_ar.register_buffer` custom op.

    Forwards `(fa, ipc_tensors)` and passes the op's return value through.
    """
    register_op = torch.ops._C_custom_ar.register_buffer
    return register_op(fa, ipc_tensors)
1742
1743


1744
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
    """Return the result of the `_C_custom_ar.get_graph_buffer_ipc_meta`
    custom op for `fa`."""
    ipc_meta_op = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    return ipc_meta_op(fa)


1748
1749
def register_graph_buffers(
    fa: int,
    handles: list[list[int]],
    offsets: list[list[int]],
) -> None:
    """Thin wrapper over the `_C_custom_ar.register_graph_buffers` custom op.

    Forwards `(fa, handles, offsets)`; nothing is returned.
    """
    register_graph_op = torch.ops._C_custom_ar.register_graph_buffers
    register_graph_op(fa, handles, offsets)
1751
1752


1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
    """Return the result of the
    `_C_custom_ar.allocate_shared_buffer_and_handle` custom op for `size`."""
    allocate_op = torch.ops._C_custom_ar.allocate_shared_buffer_and_handle
    return allocate_op(size)


def open_mem_handle(mem_handle: torch.Tensor):
    """Return the result of the `_C_custom_ar.open_mem_handle` custom op for
    `mem_handle`."""
    open_op = torch.ops._C_custom_ar.open_mem_handle
    return open_op(mem_handle)


def free_shared_buffer(ptr: int) -> None:
    """Call the `_C_custom_ar.free_shared_buffer` custom op with `ptr`;
    nothing is returned."""
    free_op = torch.ops._C_custom_ar.free_shared_buffer
    free_op(ptr)


1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
# quick all reduce
def init_custom_qr(rank: int, world_size: int,
                   qr_max_size: Optional[int] = None) -> int:
    """Thin wrapper over the `_C_custom_ar.init_custom_qr` custom op.

    Forwards `(rank, world_size, qr_max_size)` and returns the op's result.
    `qr_max_size` is optional and defaults to None.
    """
    init_qr_op = torch.ops._C_custom_ar.init_custom_qr
    return init_qr_op(rank, world_size, qr_max_size)


def qr_destroy(fa: int) -> None:
    """Call the `_C_custom_ar.qr_destroy` custom op with `fa`; nothing is
    returned."""
    destroy_op = torch.ops._C_custom_ar.qr_destroy
    destroy_op(fa)


def qr_all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor,
                  quant_level: int, cast_bf2half: bool = False) -> None:
    """Thin wrapper over the `_C_custom_ar.qr_all_reduce` custom op.

    Forwards all arguments in declaration order; nothing is returned.
    `cast_bf2half` defaults to False.
    """
    qr_reduce_op = torch.ops._C_custom_ar.qr_all_reduce
    qr_reduce_op(fa, inp, out, quant_level, cast_bf2half)


def qr_get_handle(fa: int) -> torch.Tensor:
    """Return the result of the `_C_custom_ar.qr_get_handle` custom op for
    `fa`."""
    get_handle_op = torch.ops._C_custom_ar.qr_get_handle
    return get_handle_op(fa)


def qr_open_handles(
    fa: int,
    handles: list[torch.Tensor],
) -> None:
    """Thin wrapper over the `_C_custom_ar.qr_open_handles` custom op.

    Forwards `(fa, handles)` and passes the op's return value through.
    """
    open_handles_op = torch.ops._C_custom_ar.qr_open_handles
    return open_handles_op(fa, handles)


def qr_max_size() -> int:
    """Return the result of the `_C_custom_ar.qr_max_size` custom op."""
    max_size_op = torch.ops._C_custom_ar.qr_max_size
    return max_size_op()


1797
1798
1799
1800
def get_flash_mla_metadata(cache_seqlens: torch.Tensor,
                           num_heads_per_head_k: int,
                           num_heads_k: int
                           ) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Thin wrapper over the `_C.get_flash_mla_metadata` custom op.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to
            seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: Number of k heads.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    metadata_op = torch.ops._C.get_flash_mla_metadata
    return metadata_op(cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache(q: torch.Tensor,
                           k_cache: torch.Tensor,
                           block_table: torch.Tensor,
                           cache_seqlens: torch.Tensor,
                           head_dim_v: int,
                           tile_scheduler_metadata: torch.Tensor,
                           num_splits: torch.Tensor,
                           softmax_scale: Optional[float] = None,
                           causal: bool = False
                           ) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the `_C.flash_mla_fwd_kvcache` custom op.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, as returned by get_flash_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, as returned by
            get_flash_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            When None, q.shape[-1] ** -0.5 is used, i.e. 1 / sqrt(head_dim).
        causal: bool. Whether to apply a causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        qk_scale = q.shape[-1]**(-0.5)
    else:
        qk_scale = softmax_scale
    mla_op = torch.ops._C.flash_mla_fwd_kvcache
    # The third positional argument is always passed as None by this wrapper.
    attn_out, lse = mla_op(q, k_cache, None, head_dim_v, cache_seqlens,
                           block_table, qk_scale, causal,
                           tile_scheduler_metadata, num_splits)
    return attn_out, lse
1859
1860
1861
1862
1863
1864
1865
1866
1867


def cutlass_mla_decode(
    out: torch.Tensor,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Run the `_C.cutlass_mla_decode` custom op and return `out`.

    All arguments are forwarded positionally, in declaration order. The op's
    own return value is ignored; the `out` tensor that was passed in is
    returned.
    """
    decode_op = torch.ops._C.cutlass_mla_decode
    decode_op(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table,
              scale)
    return out
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916


if hasattr(torch.ops._C, "weight_packed_linear"):

    @register_fake("_C::weight_packed_linear")
    def weight_packed_linear_fake(
        mat1: torch.Tensor,
        mat2: torch.Tensor,
        bias: Optional[torch.Tensor],
        is_vnni: bool,
    ) -> torch.Tensor:
        """Fake implementation registered for `_C::weight_packed_linear`.

        Performs no computation: returns an uninitialized tensor of shape
        `(mat1.size(0), mat2.size(0))` with `mat1`'s dtype on `mat2`'s
        device.
        """
        out_rows = mat1.size(0)
        out_cols = mat2.size(0)
        return torch.empty((out_rows, out_cols),
                           dtype=mat1.dtype,
                           device=mat2.device)


if hasattr(torch.ops._C, "fused_experts_cpu"):

    @register_fake("_C::fused_experts_cpu")
    def fused_experts_cpu_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
                               w2: torch.Tensor, topk_weights: torch.Tensor,
                               topk_ids: torch.Tensor, inplace: bool,
                               use_int8_w8a8: bool, use_fp8_w8a16: bool,
                               w1_scale: Optional[torch.Tensor],
                               w2_scale: Optional[torch.Tensor],
                               block_size: Optional[list[int]],
                               a1_scale: Optional[torch.Tensor],
                               a2_scale: Optional[torch.Tensor],
                               is_vnni: bool) -> torch.Tensor:
        """Fake implementation registered for `_C::fused_experts_cpu`.

        Performs no computation: returns an uninitialized tensor created
        with `torch.empty_like(hidden_states)`. Every other argument is
        ignored.
        """
        placeholder = torch.empty_like(hidden_states)
        return placeholder


if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):

    @register_fake("_C::int8_scaled_mm_with_quant")
    def int8_scaled_mm_with_quant_fake(
        mat1: torch.Tensor,
        mat2: torch.Tensor,
        scales2: torch.Tensor,
        bias: Optional[torch.Tensor],
        out_dtype: torch.dtype,
        is_vnni: bool,
    ) -> torch.Tensor:
        """Fake implementation registered for `_C::int8_scaled_mm_with_quant`.

        Performs no computation: returns an uninitialized tensor of shape
        `(mat1.size(0), mat2.size(0))` with dtype `out_dtype`.

        The result is allocated on `mat1`'s device rather than the process
        default device, so the fake output follows its inputs in the same
        way `weight_packed_linear_fake` does.
        """
        M = mat1.size(0)
        N = mat2.size(0)
        return torch.empty((M, N), dtype=out_dtype, device=mat1.device)