# _custom_ops.py
1
import contextlib
2
import importlib
3
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
4
5

import torch
6
import torch.library
7

8
import vllm.envs as envs
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
from vllm.scalar_type import ScalarType
12
13
14

logger = init_logger(__name__)

15
# Load the compiled extension that backs the `torch.ops._C` wrappers below.
# The import is skipped on TPU and HPU platforms. A failed import is only
# logged, so this module can still be imported without the extension.
if not (current_platform.is_tpu() or current_platform.is_hpu()):
    try:
        import vllm._C  # noqa: F401
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)
20

21
# `supports_moe_ops` records whether the optional `vllm._moe_C` extension could
# be imported. It stays False when the import raises ImportError, because the
# assignment inside the `with` block is never reached in that case.
supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True
25

26
if TYPE_CHECKING:
27
28
29
30
31
32
33
34
35

    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

36
37
38

# activation ops
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled `_C.gelu_and_mul` op with `out` and `x`.

    Returns nothing; the op is expected to fill `out` in place (the kernel
    itself is not visible from this file).
    """
    torch.ops._C.gelu_and_mul(out, x)
40
41
42


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled `_C.gelu_tanh_and_mul` op with `out` and `x`.

    Returns nothing; the op is expected to fill `out` in place (the kernel
    itself is not visible from this file).
    """
    torch.ops._C.gelu_tanh_and_mul(out, x)
44
45


46
47
48
49
50
51
def fatrelu_and_mul(out: torch.Tensor,
                    x: torch.Tensor,
                    threshold: float = 0.0) -> None:
    """Call the compiled `_C.fatrelu_and_mul` op.

    Args:
        out: Tensor handed to the op as its first argument (expected to
            receive the result; kernel not visible here).
        x: Input tensor.
        threshold: Forwarded unchanged to the op; defaults to 0.0.
    """
    kernel = torch.ops._C.fatrelu_and_mul
    kernel(out, x, threshold)


52
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled `_C.gelu_fast` op with `out` and `x`.

    Returns nothing; the op is expected to fill `out` in place (the kernel
    itself is not visible from this file).
    """
    torch.ops._C.gelu_fast(out, x)
54
55
56


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled `_C.gelu_new` op with `out` and `x`.

    Returns nothing; the op is expected to fill `out` in place (the kernel
    itself is not visible from this file).
    """
    torch.ops._C.gelu_new(out, x)
58
59


60
61
62
63
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled `_C.gelu_quick` op with `out` and `x`.

    Returns nothing; the op is expected to fill `out` in place (the kernel
    itself is not visible from this file).
    """
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


64
65
66
67
68
69
70
71
72
# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Call the compiled `_C.paged_attention_v1` op.

    Every argument is forwarded positionally, in signature order, to the op.
    Returns nothing; `out` is expected to receive the attention result
    (kernel not visible here). The `blocksparse_*` arguments and `tp_rank`
    have defaults and are passed through unchanged.
    """
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)
92
93
94
95
96
97
98
99
100
101
102
103
104


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Call the compiled `_C.paged_attention_v2` op.

    Compared with `paged_attention_v1`, this variant also passes the
    `exp_sum`, `max_logits` and `tmp_out` tensors to the op. Every argument
    is forwarded positionally, in signature order. Returns nothing; `out` is
    expected to receive the attention result (kernel not visible here).
    """
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)
124
125


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call the `_rocm_C.paged_attention` op.

    Note the op lives in the `_rocm_C` namespace, not `_C`. Every argument is
    forwarded positionally, in signature order. Returns nothing; `out` is
    expected to receive the attention result (kernel not visible here).
    """
    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                      key_cache, value_cache, num_kv_heads,
                                      scale, block_tables, seq_lens,
                                      block_size, max_seq_len, alibi_slopes,
                                      kv_cache_dtype, k_scale, v_scale)
150
151


152
153
154
155
156
157
158
159
160
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Call the compiled `_C.rotary_embedding` op.

    All arguments are forwarded positionally. Returns nothing, so the op
    presumably updates `query` and `key` in place (kernel not visible here).
    """
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)
163
164
165
166
167
168
169


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Call the compiled `_C.batched_rotary_embedding` op.

    Takes the same arguments as `rotary_embedding` plus `rot_dim` and
    `cos_sin_cache_offsets`, all forwarded positionally. Returns nothing, so
    the op presumably updates `query` and `key` in place (kernel not visible
    here).
    """
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)
173
174
175
176
177


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call the compiled `_C.rms_norm` op.

    Returns nothing; `out` is expected to receive the result (kernel not
    visible here).
    """
    torch.ops._C.rms_norm(out, input, weight, epsilon)
179
180
181
182


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call the compiled `_C.fused_add_rms_norm` op.

    No output tensor is passed and nothing is returned, so the op presumably
    updates `input` and/or `residual` in place (kernel not visible here).
    """
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
184
185


186
187
188
189
190
191
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Forwards every argument positionally to `_C.advance_step_flashattn` and
    returns whatever the op returns (annotated as None).
    """
    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
                                               block_size, input_tokens,
                                               sampled_token_ids,
                                               input_positions, seq_lens,
                                               slot_mapping, block_tables)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    """Call the compiled `_C.advance_step_flashinfer` op.

    Takes the arguments of `advance_step_flashattn` plus the `paged_kv_*`
    tensors and `block_table_bound`. Everything is forwarded positionally and
    the op's return value (annotated as None) is returned.
    """
    step_op = torch.ops._C.advance_step_flashinfer
    return step_op(num_seqs, num_queries, block_size, input_tokens,
                   sampled_token_ids, input_positions, seq_lens, slot_mapping,
                   block_tables, paged_kv_indices, paged_kv_indptr,
                   paged_kv_last_page_len, block_table_bound)
216
217


218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    scale_ub: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Allocate outputs and call `_C.rms_norm_dynamic_per_token_quant`.

    Args:
        input: Tensor to normalize and quantize.
        weight: Forwarded to the op.
        epsilon: Forwarded to the op.
        quant_dtype: dtype of the allocated output tensor.
        scale_ub: Optional tensor forwarded to the op.
        residual: Optional tensor forwarded to the op.

    Returns:
        `(output, scales)`: `output` has the shape of `input` with dtype
        `quant_dtype`; `scales` is float32 with one row per slice along the
        last dimension of `input`, i.e. shape `(numel // last_dim, 1)`.
    """
    quantized = torch.empty_like(input, dtype=quant_dtype)
    # One scale per row when `input` is viewed as (rows, last_dim).
    num_rows = input.numel() // input.shape[-1]
    row_scales = torch.empty((num_rows, 1),
                             device=input.device,
                             dtype=torch.float32)

    quant_op = torch.ops._C.rms_norm_dynamic_per_token_quant
    quant_op(quantized, input, weight, row_scales, epsilon, scale_ub, residual)
    return quantized, row_scales


238
239
240
241
242
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Dequantize AWQ weights with the Triton or the compiled implementation.

    When `envs.VLLM_USE_TRITON_AWQ` is truthy, the Triton implementation is
    imported lazily and called with `qweight`, `scales` and `zeros` only;
    `split_k_iters`, `thx` and `thy` are not used on that path. Otherwise all
    six arguments go to the compiled `_C.awq_dequantize` op.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton)
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)
249
250
251
252


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Run the AWQ GEMM with the Triton or the compiled implementation.

    When `envs.VLLM_USE_TRITON_AWQ` is truthy, the Triton implementation is
    imported lazily and used; otherwise the compiled `_C.awq_gemm` op is
    called. Both paths receive the same five arguments.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_gemm_triton)
        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
258
259
260
261
262
263
264


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Call the compiled `_C.gptq_gemm` op and return its result.

    All arguments are forwarded positionally, in signature order.
    """
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)
267
268


269
# Register the fake implementation only when the compiled op exists.
if hasattr(torch.ops._C, "gptq_gemm"):

    @register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_gptq_qzeros: torch.Tensor,
                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                        use_exllama: bool, bit: int) -> torch.Tensor:
        """Fake `_C::gptq_gemm`: return an uninitialized output placeholder.

        The placeholder has shape `(a.size(0), b_q_weight.size(1))` and the
        dtype and device of `a`; no computation is performed.
        """
        return torch.empty((a.size(0), b_q_weight.size(1)),
                           dtype=a.dtype,
                           device=a.device)


281
282
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call the compiled `_C.gptq_shuffle` op.

    Returns nothing, so the op presumably rearranges `q_weight` in place
    (kernel not visible here).
    """
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
284
285
286
287
288
289


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Call the compiled `_C.marlin_gemm` op and return its result.

    All arguments are forwarded positionally, in signature order.
    """
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)
292
293


294
295
296
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Call the compiled `_C.gptq_marlin_24_gemm` op and return its result.

    `b_q_type` is passed to the op as `b_q_type.id`; every other argument is
    forwarded unchanged.
    """
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type.id, size_m,
                                            size_n, size_k)
302
303


304
# Fake implementations for the quantization ops. Each one returns an
# uninitialized tensor describing the op's output and runs no kernel.
# NOTE(review): the guard only probes `gptq_marlin_24_gemm`; the other ops
# registered below are assumed to exist whenever it does - confirm against
# the extension build.
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

    @register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: torch.SymInt,
                                  size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        """Return a `(size_m, size_n)` placeholder, dtype/device of `a`."""
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        """Return a `(size_m, size_n)` placeholder, dtype/device of `a`."""
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        """Return a float16 `(size_m, size_n)` placeholder on `a`'s device."""
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        """Return a float16 `(size_m, size_n)` placeholder on `a`'s device."""
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: torch.SymInt,
                             thx: int, thy: int) -> torch.Tensor:
        """Return a placeholder with 8x the column count of `qweight`.

        The result has shape `(qweight.size(0), qweight.size(1) * 8)` and the
        dtype and device of `scales`.
        """
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        out_c = qout_c * 8
        return torch.empty((in_c, out_c),
                           dtype=scales.dtype,
                           device=scales.device)

    @register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: torch.SymInt) -> torch.Tensor:
        """Return a `(input.size(0), qweight.size(1) * 8)` placeholder.

        A `(split_k_iters, rows, cols)` tensor is allocated and then summed
        over its first dimension, so the result has the dtype and device of
        `input`.
        """
        num_in_feats = input.size(0)
        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                           dtype=input.dtype,
                           device=input.device).sum(0)

    @register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: List[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Return a placeholder shaped like `input` with a new last dim.

        The last dimension becomes `codes.size(0) * codebooks.size(2)`; all
        leading dimensions of `input` are kept.
        """
        out_features = codes.size(0) * codebooks.size(2)
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)

        # Restore the leading dims of `input`; -1 lets reshape infer the
        # final dimension.
        output_sizes = list(input.shape)
        output_sizes.pop()
        output_sizes.append(-1)
        return flat_output.reshape(tuple(output_sizes))

    @register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: List[int]) -> torch.Tensor:
        """Return a `(codes.size(0), codes.size(1) * 8)` placeholder.

        The result uses the dtype and device of `codebooks`.
        """
        in_features = codes.size(1) * 8
        out_features = codes.size(0)
        return torch.empty((out_features, in_features),
                           dtype=codebooks.dtype,
                           device=codebooks.device)

    @register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        """Return a `(size_m, size_n)` placeholder, dtype/device of `a`."""
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        """Return an `(a.size(0), b_q.size(1))` placeholder.

        The result uses the dtype and device of `a`; `out_type` is not
        consulted here.
        """
        m = a.size(0)
        n = b_q.size(1)
        return torch.empty((m, n), device=a.device, dtype=a.dtype)

    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(
            b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
            group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
        """Return a contiguous placeholder shaped like `b_q_weight`."""
        return torch.empty_like(b_q_weight,
                                memory_format=torch.contiguous_format)
430
431


432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
# Fake implementations for the ggml ops, registered only when the compiled
# extension exposes `ggml_dequantize`. Each returns an uninitialized float16
# tensor on `W`'s device and runs no kernel.
if hasattr(torch.ops._C, "ggml_dequantize"):

    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(
        W: torch.Tensor,
        quant_type: int,
        m: torch.SymInt,
        n: torch.SymInt,
    ) -> torch.Tensor:
        """Return an `(m, n)` float16 placeholder."""
        out_shape = (m, n)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        """Return a `(1, row)` float16 placeholder."""
        out_shape = (1, row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        """Return an `(X.size(0), row)` float16 placeholder."""
        out_shape = (X.size(0), row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)


460
# cutlass
461
462
463
464
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return the answer of `_C.cutlass_scaled_mm_supports_fp8`.

    `cuda_device_capability` is forwarded unchanged to the compiled op.
    """
    support_query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return support_query(cuda_device_capability)


465
466
467
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled matrix multiply via CUTLASS, or via Triton on ROCm.

    Preconditions (checked with `assert`):
        - both dimensions of `b` are multiples of 16;
        - `out_dtype` is `torch.bfloat16` or `torch.float16`;
        - `bias`, when given, has `b.shape[1]` entries along dim 0 and dtype
          `out_dtype`.

    Returns:
        On ROCm, whatever `triton_scaled_mm` returns. Otherwise a newly
        allocated `(a.shape[0], b.shape[1])` tensor of `out_dtype` on
        `a.device`, passed to `_C.cutlass_scaled_mm` as its output argument.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]

    if current_platform.is_rocm():
        # Resolve the Triton implementation by module path at call time.
        triton_scaled_mm_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm")
        triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


493
494
495
496
497
498
499
500
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled matrix multiply with zero-point terms via `_C.cutlass_scaled_mm_azp`.

    :param azp_adj: In the per-tensor case, this should include the azp.
    Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.

    Preconditions (checked with `assert`): both dimensions of `b` are
    multiples of 16; `out_dtype` is bfloat16 or float16; `bias`, when given,
    has `b.shape[1]` elements and dtype `out_dtype`; `azp`, when given, has
    `a.shape[0]` elements.

    Returns a newly allocated `(a.shape[0], b.shape[1])` tensor of
    `out_dtype` on `a.device`, passed to the op as its output argument.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.numel(
    ) == b.shape[1] and bias.dtype == out_dtype
    assert azp is None or azp.numel() == a.shape[0]

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out


521
522
523
524
525
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
    """Return the answer of `_C.cutlass_sparse_scaled_mm_supported`.

    `cuda_device_capability` is forwarded unchanged to the compiled op.
    """
    support_query = torch.ops._C.cutlass_sparse_scaled_mm_supported
    return support_query(cuda_device_capability)


526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
def cutlass_sparse_compress(a: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compresses a sparse matrix for use with Cutlass sparse operations.

    This function takes a dense tensor and compresses it into two components:
    non-zero elements and metadata. The compressed representation is compatible
    with Cutlass sparse kernels.

    Args:
        a (torch.Tensor): 
            The input tensor to be compressed. Must have one of the following data types:
            - `torch.int8`
            - `torch.float8_e4m3fn`
            - `torch.bfloat16`
            - `torch.float16`

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: 
            A tuple containing:
            - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
            - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.

    Raises:
        ValueError: If the compression operation fails.

    Notes:
        - The `a_meta` tensor has a data type of `torch.uint8`.
        - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
        - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
        - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
    """
    assert (a.dtype in [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ])
    assert (a.is_contiguous())

    # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
    elemsPerMetaElem = 4

    m = a.shape[0]
    k = a.shape[1]
    assert (k % 2 == 0)
    a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
    a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
                         dtype=torch.uint8,
                         device=a.device)

    if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
        raise ValueError

    assert (a_nzs.is_contiguous())
    assert (a_meta.is_contiguous())

    return a_nzs, a_meta


def cutlass_scaled_sparse_mm(
        a: torch.Tensor,
        bt_nzs: torch.Tensor,
        bt_meta: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        out_dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Scaled sparse matrix multiplication through Cutlass.

    Typical usage:
    1. Build a dense `a` of shape (m, k) on the CUDA device:
    `a = torch.randn((m, k), device='cuda')`.

    2. Build a dense `b` of shape (k, n) on the CUDA device:
    `b = torch.randn((k, n), device='cuda')`.

    3. Prune `b` to 2:4 sparsity along the chosen dimension:
    `b = prune_to_2_4(b, dim=0)`.

    4. Compress the transposed sparse matrix `b.t()`:
    `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.

    5. Multiply using the compressed matrix, the scaling factors for `a` and
    `b`, and the output data type:
    `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.

    Preconditions (checked with `assert`): both dimensions of `bt_nzs` are
    multiples of 16; `out_dtype` is bfloat16 or float16; `bias`, when given,
    has `bt_nzs.shape[0]` entries along dim 0 and dtype `out_dtype`.

    Returns:
    - A new `(a.shape[0], bt_nzs.shape[0])` tensor of `out_dtype` on
      `a.device`, passed to the op as its output argument.
    """
    assert not (bt_nzs.shape[0] % 16 != 0 or bt_nzs.shape[1] % 16 != 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    if bias is not None:
        assert (bias.shape[0] == bt_nzs.shape[0]
                and bias.dtype == out_dtype)

    result_shape = (a.shape[0], bt_nzs.shape[0])
    result = torch.empty(result_shape, dtype=out_dtype, device=a.device)

    sparse_mm = torch.ops._C.cutlass_scaled_sparse_mm
    sparse_mm(result, a, bt_nzs, bt_meta, scale_a, scale_b, bias)

    return result


629
630
631
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Call the compiled `_C.aqlm_gemm` op and return its result.

    All arguments are forwarded positionally; `bias` may be None.
    """
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)
636
637
638


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    """Call the compiled `_C.aqlm_dequant` op and return its result."""
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)
642
643


644
645
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Call the compiled `_C.gptq_marlin_repack` op and return its result.

    All arguments are forwarded positionally, in signature order.
    """
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)
650
651


652
653
654
655
656
657
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Call the compiled `_C.awq_marlin_repack` op and return its result.

    All arguments are forwarded positionally, in signature order.
    """
    repack_op = torch.ops._C.awq_marlin_repack
    return repack_op(b_q_weight, size_k, size_n, num_bits)


658
659
660
661
662
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Repack each expert's weights with `_C.gptq_marlin_repack`.

    `b_q_weight` and `perm` are indexed along dim 0, one slice per expert.
    `size_k` must be a multiple of 16 (checked with `assert`).

    Returns:
        A tensor of shape
        `(num_experts, size_k // 16, size_n * (num_bits // 2))` with the
        dtype and device of `b_q_weight`, holding one repacked slice per
        expert.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
                                                    size_k, size_n, num_bits)
    return output


672
673
674
675
676
677
678
679
680
681
682
683
684
685
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """Repack each expert's weights with `_C.awq_marlin_repack`.

    `b_q_weight` is indexed along dim 0, one slice per expert. `perm` is
    accepted but not used by this function. `size_k` must be a multiple of
    16 (checked with `assert`).

    Returns:
        A tensor of shape
        `(num_experts, size_k // 16, size_n * (num_bits // 2))` with the
        dtype and device of `b_q_weight`, holding one repacked slice per
        expert.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    packed_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    repacked = torch.empty(packed_shape,
                           device=b_q_weight.device,
                           dtype=b_q_weight.dtype)
    expert_idx = 0
    while expert_idx < expert_count:
        repacked[expert_idx] = torch.ops._C.awq_marlin_repack(
            b_q_weight[expert_idx], size_k, size_n, num_bits)
        expert_idx += 1
    return repacked


686
687
688
689
690
691
692
693
694
695
696
697
698
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False,
                     is_zp_float: bool = False) -> torch.Tensor:
    """Call the compiled `_C.gptq_marlin_gemm` op and return its result.

    `b_q_type` is passed to the op as `b_q_type.id`; every other argument is
    forwarded unchanged, in signature order.
    """
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type.id,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce, is_zp_float)
705
706


707
708
709
710
711
712
713
714
715
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Call the compiled `_C.fp8_marlin_gemm` op and return its result.

    All arguments are forwarded positionally, in signature order.
    """
    gemm_op = torch.ops._C.fp8_marlin_gemm
    return gemm_op(a, b_q_weight, b_scales, workspace, num_bits, size_m,
                   size_n, size_k)


716
# machete
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
def machete_supported_schedules(
        a_type: torch.dtype,
        b_type: ScalarType,
        group_scales_type: Optional[torch.dtype],
        group_zeros_type: Optional[torch.dtype] = None,
        channel_scales_type: Optional[torch.dtype] = None,
        token_scales_type: Optional[torch.dtype] = None,
        out_type: Optional[torch.dtype] = None) -> List[str]:
    """Return the result of `_C.machete_supported_schedules`.

    `b_type` is passed to the op as `b_type.id`; the dtype arguments are
    forwarded unchanged, in signature order.
    """
    query_op = torch.ops._C.machete_supported_schedules
    return query_op(a_type, b_type.id, group_scales_type, group_zeros_type,
                    channel_scales_type, token_scales_type, out_type)


def machete_mm(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None) -> torch.Tensor:
    """Call the compiled `_C.machete_mm` op and return its result.

    `b_type` is passed to the op as `b_type.id`; every other argument,
    including the optional ones left at None, is forwarded unchanged in
    signature order.
    """
    mm_op = torch.ops._C.machete_mm
    return mm_op(a, b_q, b_type.id, out_type, b_group_scales, b_group_zeros,
                 b_group_size, b_channel_scales, a_token_scales, schedule)


def machete_prepack_B(
    b_q_weight: torch.Tensor,
    a_type: torch.dtype,
    b_type: ScalarType,
    group_scales_type: Optional[torch.dtype],
) -> torch.Tensor:
    """Call the ``_C::machete_prepack_B`` custom op and return its result.

    ``b_type`` is passed to the op as ``b_type.id``; the remaining
    arguments are forwarded unchanged.
    """
    prepack_op = torch.ops._C.machete_prepack_B
    return prepack_op(b_q_weight, a_type, b_type.id, group_scales_type)
752
753


754
if hasattr(torch.ops._C, "permute_cols"):
    # Register the fake implementation only when the loaded extension
    # actually exposes the `permute_cols` op.

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        """Fake implementation registered for ``_C::permute_cols``.

        Returns ``torch.empty_like(a)`` (uninitialized contents) without
        running the real kernel; ``perm`` is not inspected.
        """
        return torch.empty_like(a)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Forward ``a`` and ``perm`` to the ``_C::permute_cols`` custom op.

    Returns:
        The tensor produced by the custom op.
    """
    permute_op = torch.ops._C.permute_cols
    return permute_op(a, perm)


766
# fp8
767
768
769
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D
            (batch and token dimensions already flattened).
        scale: Optional scaling factor for the FP8 quantization
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.

    Raises:
        AssertionError: If ``input`` is not 2-D, or if a multi-element
            ``scale`` is combined with ``num_token_padding``.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
    out_dtype: torch.dtype = torch.float8_e4m3fnuz \
            if current_platform.is_rocm() else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            # One scale slot per (possibly padded) output row.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            # Single per-tensor scale; `scale_ub` is not used on this path.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
823
824


825
# int8
826
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
      Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.

    Raises:
        AssertionError: If `scale` is provided and `symmetric` disagrees
            with whether `azp` was supplied.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (
            azp is
            None), "azp must only be provided for asymmetric quantization."
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization.
    # One scale per row once all leading dimensions are flattened.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    # A zero-point buffer is only allocated for the asymmetric case.
    input_azp = None if symmetric else torch.empty_like(input_scales,
                                                        dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
                                           input_azp)
    return output, input_scales, input_azp
864
865


866
867
868
869
870
871
872
873
874
# qqq ops
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Call the ``_C::marlin_qqq_gemm`` custom op and return its result.

    All arguments are forwarded positionally, unchanged and in signature
    order.
    """
    gemm_op = torch.ops._C.marlin_qqq_gemm
    return gemm_op(a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                   size_n, size_k)


875
# gguf
876
877
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    """Call the ``_C::ggml_dequantize`` custom op and return its result.

    ``W``, ``quant_type``, ``m`` and ``n`` are forwarded unchanged.
    """
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``_C::ggml_mul_mat_vec_a8`` custom op and return its result.

    ``W``, ``X``, ``quant_type`` and ``row`` are forwarded unchanged.
    """
    return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``_C::ggml_mul_mat_a8`` custom op and return its result.

    ``W``, ``X``, ``quant_type`` and ``row`` are forwarded unchanged.
    """
    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


899
900
901
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      conv_states: Optional[torch.Tensor],
                      query_start_loc: Optional[torch.Tensor],
                      cache_indices: Optional[torch.Tensor],
                      has_initial_state: Optional[torch.Tensor],
                      silu_activation: bool, pad_slot_id: int):
    """Invoke the ``_C::causal_conv1d_fwd`` custom op.

    All arguments are forwarded positionally in signature order. The op's
    return value is discarded and this wrapper returns ``None``.
    NOTE(review): presumably the kernel writes its result into one of the
    tensor arguments in place -- confirm against the kernel.
    """
    torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
                                   query_start_loc, cache_indices,
                                   has_initial_state, silu_activation,
                                   pad_slot_id)


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Invoke the ``_C::causal_conv1d_update`` custom op.

    Arguments are forwarded positionally in signature order; the op's
    return value is discarded and ``None`` is returned.
    """
    update_op = torch.ops._C.causal_conv1d_update
    update_op(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool,
                       query_start_loc: Optional[torch.Tensor],
                       cache_indices: Optional[torch.Tensor],
                       has_initial_state: Optional[torch.Tensor],
                       ssm_states: torch.Tensor, pad_slot_id: int):
    """Invoke the ``_C::selective_scan_fwd`` custom op.

    All arguments are forwarded positionally in signature order. The op's
    return value is discarded and this wrapper returns ``None``.
    NOTE(review): presumably results are written into the tensor
    arguments in place -- confirm against the kernel.
    """
    torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
                                    delta_softplus, query_start_loc,
                                    cache_indices, has_initial_state,
                                    ssm_states, pad_slot_id)
937
938


939
# moe
940
941
942
943
def moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Invoke the ``_moe_C::moe_sum`` custom op with ``input`` and ``output``.

    The op's return value is discarded; this wrapper returns ``None``.
    """
    sum_op = torch.ops._moe_C.moe_sum
    sum_op(input, output)


944
945
946
947
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Invoke the ``_moe_C::moe_align_block_size`` custom op.

    All arguments are forwarded positionally in signature order and
    nothing is returned.
    NOTE(review): presumably ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` are output buffers filled by the kernel --
    confirm against the kernel.
    """
    torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
                                          sorted_token_ids, experts_ids,
                                          num_tokens_post_pad)
951
952
953
954
955
956
957


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: float) -> None:
    """Invoke the ``_moe_C::topk_softmax`` custom op.

    The four arguments are forwarded positionally in signature order and
    nothing is returned.
    NOTE(review): ``gating_output`` is annotated as ``float`` while its
    siblings are tensors; it presumably should be ``torch.Tensor`` --
    confirm against the op schema before changing the annotation.
    """
    softmax_op = torch.ops._moe_C.topk_softmax
    softmax_op(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output,
    )
958
959


960
961
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
    # Register the fake implementation only when the MoE extension was
    # imported and it exposes the `marlin_gemm_moe` op.

    @register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                             sorted_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                             perm: torch.Tensor, workspace: torch.Tensor,
                             b_q_type: ScalarType, size_m: torch.SymInt,
                             size_n: torch.SymInt, size_k: torch.SymInt,
                             is_k_full: bool, num_experts: int, topk: int,
                             moe_block_size: int, replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
        """Fake implementation registered for ``_moe_C::marlin_gemm_moe``.

        Returns an uninitialized tensor of shape
        ``(size_m, topk, size_n)`` with the dtype and device of ``a``.
        Only ``a``, ``size_m``, ``topk`` and ``size_n`` are read; the
        remaining parameters exist to mirror the op's signature.
        """
        return torch.empty((size_m, topk, size_n),
                           dtype=a.dtype,
                           device=a.device)


979
980
981
982
983
984
985
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke the ``_C_cache_ops::reshape_and_cache`` custom op.

    All arguments are forwarded positionally in signature order and
    nothing is returned.
    NOTE(review): presumably ``key_cache`` / ``value_cache`` are updated
    in place by the kernel -- confirm against the kernel.
    """
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)
992
993


994
995
996
997
998
999
1000
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke the ``_C_cache_ops::reshape_and_cache_flash`` custom op.

    All arguments are forwarded positionally in signature order and
    nothing is returned.
    NOTE(review): presumably ``key_cache`` / ``value_cache`` are updated
    in place by the kernel -- confirm against the kernel.
    """
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)
1008
1009


1010
1011
def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Invoke the ``_C_cache_ops::copy_blocks`` custom op.

    The three arguments are forwarded unchanged and nothing is returned.
    """
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
1014
1015
1016


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Invoke the ``_C_cache_ops::swap_blocks`` custom op.

    ``src``, ``dst`` and ``block_mapping`` are forwarded unchanged and
    nothing is returned.
    """
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
1019
1020


1021
1022
1023
1024
def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Invoke the ``_C_cache_ops::convert_fp8`` custom op.

    Args:
        output: Forwarded as the op's first argument.
        input: Forwarded as the op's second argument.
        scale: Forwarded unchanged; defaults to ``1.0``.
        kv_dtype: Forwarded unchanged; defaults to ``"fp8"``.

    Nothing is returned.
    """
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return the result of the ``_C_cuda_utils::get_device_attribute`` op.

    ``attribute`` and ``device`` are forwarded unchanged.
    """
    query_op = torch.ops._C_cuda_utils.get_device_attribute
    return query_op(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the result of the matching ``_C_cuda_utils`` op for ``device``."""
    # ruff: noqa: E501
    query_op = torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute
    return query_op(device)


# custom ar
1039
1040
1041
1042
def init_custom_ar(
    ipc_tensors: List[torch.Tensor],
    rank_data: torch.Tensor,
    rank: int,
    full_nvlink: bool,
) -> int:
    """Call the ``_C_custom_ar::init_custom_ar`` op and return its result.

    All arguments are forwarded unchanged. The returned integer is the
    value the other custom all-reduce wrappers accept as ``fa``
    (presumably an opaque handle -- TODO confirm).
    """
    init_op = torch.ops._C_custom_ar.init_custom_ar
    return init_op(ipc_tensors, rank_data, rank, full_nvlink)
1043
1044


1045
1046
1047
1048
def all_reduce(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    reg_buffer: int,
    reg_buffer_sz_bytes: int,
) -> None:
    """Invoke the ``_C_custom_ar::all_reduce`` custom op.

    All arguments are forwarded positionally in signature order and
    nothing is returned.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce
    reduce_op(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
1049

1050
1051
1052
1053
1054
1055
1056
1057
1058

def dispose(fa: int) -> None:
    """Invoke the ``_C_custom_ar::dispose`` custom op for ``fa``."""
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)


def meta_size() -> int:
    """Return the result of the ``_C_custom_ar::meta_size`` custom op."""
    size_op = torch.ops._C_custom_ar.meta_size
    return size_op()


1059
1060
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
    """Invoke the ``_C_custom_ar::register_buffer`` custom op.

    ``fa`` and ``ipc_tensors`` are forwarded unchanged. Whatever the op
    returns is passed back to the caller, although the annotation is
    ``None``.
    """
    register_op = torch.ops._C_custom_ar.register_buffer
    return register_op(fa, ipc_tensors)
1061
1062


1063
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
    """Return the result of ``_C_custom_ar::get_graph_buffer_ipc_meta``.

    ``fa`` is forwarded unchanged; the op's result (annotated as a pair
    of integer lists) is returned as is.
    """
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


1067
def register_graph_buffers(fa: int, handles: List[List[int]],
                           offsets: List[List[int]]) -> None:
    """Invoke the ``_C_custom_ar::register_graph_buffers`` custom op.

    ``fa``, ``handles`` and ``offsets`` are forwarded unchanged and
    nothing is returned.
    """
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)