# _custom_ops.py
1
import contextlib
2
import functools
3
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
4
5

import torch
6
import torch.library
7

8
import vllm.envs as envs
9
from vllm._core_ext import ScalarType
10
from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12
13
14

logger = init_logger(__name__)

15
16
17
18
19
if not current_platform.is_tpu():
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)
20

21
22
23
if current_platform.is_rocm():
    import vllm._rocm_C  # noqa: F401

24
# Probe for the optional MoE extension module: the flag stays False when
# the import fails with ImportError, and flips to True only after the
# import succeeded.
supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True
28

29
30
31
32
33
34
35
36
37
38
# Select the decorator used below to register "fake" implementations of
# custom ops.
if TYPE_CHECKING:

    # Stand-in seen only by static type checkers: takes the op name and
    # returns a decorator that hands back the decorated function unchanged.
    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        # Fall back to `impl_abstract` when `register_fake` cannot be
        # imported (presumably an older PyTorch release -- TODO confirm).
        from torch.library import impl_abstract as register_fake

39

40
41
42
43
44
45
def hint_on_error(fn):
    """Decorate ``fn`` so that failures log a troubleshooting hint.

    The returned wrapper forwards all arguments to ``fn`` and returns its
    result. Two exception types are intercepted:

    * ``NotImplementedError``: the hint is logged and a new
      ``NotImplementedError`` carrying the formatted hint is raised,
      chained to the original exception.
    * ``AttributeError``: the hint is logged and the original exception is
      re-raised unchanged.

    Any other exception propagates untouched.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except NotImplementedError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Not implemented or built, most likely because the current "
                "device does not support this kernel (less likely "
                "TORCH_CUDA_ARCH_LIST was set incorrectly while building)")
            logger.error(msg, fn.__name__, e)
            raise NotImplementedError(msg % (fn.__name__, e)) from e
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare raise keeps the original exception and traceback intact.
            raise

    return wrapper


68
69
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.silu_and_mul`` custom op with ``out`` and ``x``.

    Nothing is returned; the kernel presumably writes its result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.silu_and_mul(out, x)
71
72
73


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_and_mul`` custom op with ``out`` and ``x``.

    Nothing is returned; the kernel presumably writes its result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.gelu_and_mul(out, x)
75
76
77


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_tanh_and_mul`` custom op with ``out`` and ``x``.

    Nothing is returned; the kernel presumably writes its result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.gelu_tanh_and_mul(out, x)
79
80
81


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_fast`` custom op with ``out`` and ``x``.

    Nothing is returned; the kernel presumably writes its result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.gelu_fast(out, x)
83
84
85


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_new`` custom op with ``out`` and ``x``.

    Nothing is returned; the kernel presumably writes its result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.gelu_new(out, x)
87
88


89
90
91
92
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the ``_C.gelu_quick`` custom op with ``out`` and ``x``.

    Returns nothing; the result is presumably written into ``out``
    (TODO confirm against the op schema).
    """
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


93
94
95
96
97
98
99
100
101
# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in order, to the ``_C.paged_attention_v1`` op.

    Nothing is returned; the kernel presumably writes the attention result
    into ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)
121
122
123
124
125
126
127
128
129
130
131
132
133


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in order, to the ``_C.paged_attention_v2`` op.

    Compared with ``paged_attention_v1`` this variant additionally passes
    the ``exp_sum``, ``max_logits`` and ``tmp_out`` tensors. Nothing is
    returned; the kernel presumably writes the result into ``out``
    (TODO confirm against the op schema).
    """
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)
153
154


155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Forward every argument, in order, to the ``_rocm_C.paged_attention`` op.

    Nothing is returned; the kernel presumably writes the result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                      key_cache, value_cache, num_kv_heads,
                                      scale, block_tables, seq_lens,
                                      block_size, max_seq_len, alibi_slopes,
                                      kv_cache_dtype, k_scale, v_scale)
179
180


181
182
183
184
185
186
187
188
189
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Forward the arguments to the ``_C.rotary_embedding`` custom op.

    Nothing is returned; ``query`` and ``key`` are presumably updated in
    place by the kernel (TODO confirm against the op schema).
    """
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)
192
193
194
195
196
197
198


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Forward the arguments to the ``_C.batched_rotary_embedding`` custom op.

    Nothing is returned; ``query`` and ``key`` are presumably updated in
    place by the kernel (TODO confirm against the op schema).
    """
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)
202
203
204
205
206


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call the ``_C.rms_norm`` custom op with the given arguments.

    Nothing is returned; the kernel presumably writes its result into
    ``out`` (TODO confirm against the op schema).
    """
    torch.ops._C.rms_norm(out, input, weight, epsilon)
208
209
210
211


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call the ``_C.fused_add_rms_norm`` custom op with the given arguments.

    Nothing is returned and no output tensor is passed, so the kernel
    presumably updates ``input`` and/or ``residual`` in place (TODO confirm
    against the op schema).
    """
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
213
214


215
216
217
218
219
220
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Forwards all arguments, in order, to the ``_C.advance_step_flashattn``
    custom op and returns whatever the op returns.
    """
    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
                                               block_size, input_tokens,
                                               sampled_token_ids,
                                               input_positions, seq_lens,
                                               slot_mapping, block_tables)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    """Forward all arguments to the ``_C.advance_step_flashinfer`` custom op.

    The arguments are passed positionally in declaration order and the
    op's return value is handed back to the caller.
    """
    step_op = torch.ops._C.advance_step_flashinfer
    return step_op(
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        paged_kv_indices,
        paged_kv_indptr,
        paged_kv_last_page_len,
        block_table_bound,
    )
245
246


247
248
249
250
251
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Dequantize AWQ weights via the Triton or the ``_C`` implementation.

    When ``envs.VLLM_USE_TRITON_AWQ`` is truthy the Triton implementation
    is used (it receives only ``qweight``, ``scales`` and ``zeros``);
    otherwise the ``_C.awq_dequantize`` custom op is called with all
    arguments.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        # Imported lazily so the Triton path is only loaded when requested.
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton)
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)
258
259
260
261


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Run the AWQ GEMM via the Triton or the ``_C`` implementation.

    When ``envs.VLLM_USE_TRITON_AWQ`` is truthy the Triton implementation
    is used; otherwise the ``_C.awq_gemm`` custom op is called. Both
    receive the same arguments in the same order.
    """
    if envs.VLLM_USE_TRITON_AWQ:
        # Imported lazily so the Triton path is only loaded when requested.
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_gemm_triton)
        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
267
268
269
270
271
272
273


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return the result of the ``_C.gptq_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)
276
277


278
# Only register the fake implementation when the compiled op exists.
if hasattr(torch.ops._C, "gptq_gemm"):

    @register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_gptq_qzeros: torch.Tensor,
                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                        use_exllama: bool, bit: int) -> torch.Tensor:
        """Fake ``_C::gptq_gemm``: uninitialised output of the result shape.

        The output has shape ``(a.size(0), b_q_weight.size(1))`` and the
        dtype and device of ``a``.
        """
        return torch.empty((a.size(0), b_q_weight.size(1)),
                           dtype=a.dtype,
                           device=a.device)


290
291
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call the ``_C.gptq_shuffle`` custom op with the given arguments.

    Nothing is returned; ``q_weight`` is presumably modified in place
    (TODO confirm against the op schema).
    """
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
293
294
295
296
297
298


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the ``_C.marlin_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)
301
302


303
304
305
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the ``_C.gptq_marlin_24_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)
311
312


313
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
314

315
    @register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: int,
                                  size_n: int, size_k: int) -> torch.Tensor:
        """Fake ``_C::gptq_marlin_24_gemm``: uninitialised output tensor.

        The output has shape ``(size_m, size_n)`` and the dtype and device
        of ``a``.
        """
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

323
    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: int,
                               size_n: int,
                               size_k: int,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False) -> torch.Tensor:
        """Fake ``_C::gptq_marlin_gemm``: uninitialised output tensor.

        The output has shape ``(size_m, size_n)`` and the dtype and device
        of ``a``.
        """
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

340
    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                              n: int) -> torch.Tensor:
        """Fake ``_C::ggml_dequantize``: uninitialised float16 output tensor.

        The output has shape ``(m, n)`` and lives on the device of ``W``.
        """
        return torch.empty((m, n), dtype=torch.float16, device=W.device)

345
    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        """Fake ``_C::ggml_mul_mat_vec_a8``: uninitialised float16 output.

        The output has shape ``(1, row)`` and lives on the device of ``W``.
        """
        return torch.empty((1, row), dtype=torch.float16, device=W.device)

354
    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        """Fake ``_C::ggml_mul_mat_a8``: uninitialised float16 output.

        The output has shape ``(X.size(0), row)`` and lives on the device
        of ``W``.
        """
        batch = X.size(0)
        return torch.empty((batch, row), dtype=torch.float16, device=W.device)

364
    @register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        """Fake ``_C::marlin_qqq_gemm``: uninitialised float16 output.

        The output has shape ``(size_m, size_n)`` and lives on the device
        of ``a``.
        """
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

374
    @register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: int, size_n: int,
                          size_k: int) -> torch.Tensor:
        """Fake ``_C::marlin_gemm``: uninitialised float16 output.

        The output has shape ``(size_m, size_n)`` and lives on the device
        of ``a``.
        """
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

383
    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: int, thx: int,
                             thy: int) -> torch.Tensor:
        """Fake ``_C::awq_dequantize``: uninitialised output tensor.

        The output has shape ``(qweight.size(0), qweight.size(1) * 8)`` and
        the dtype and device of ``scales``.
        """
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        # Factor of 8 presumably reflects eight packed values per stored
        # element -- TODO confirm against the kernel.
        out_c = qout_c * 8
        return torch.empty((in_c, out_c),
                           dtype=scales.dtype,
                           device=scales.device)

394
    @register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: int) -> torch.Tensor:
        """Fake ``_C::awq_gemm``: output shaped like the real result.

        Builds an uninitialised tensor of shape
        ``(split_k_iters, input.size(0), qweight.size(1) * 8)`` with the
        dtype and device of ``input`` and sums it over the first dimension.
        """
        num_in_feats = input.size(0)
        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                           dtype=input.dtype,
                           device=input.device).sum(0)

403
    @register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: List[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Fake ``_C::aqlm_gemm``: uninitialised output tensor.

        The output keeps every leading dimension of ``input`` and replaces
        the last one with ``codes.size(0) * codebooks.size(2)``; dtype and
        device follow ``input``.
        """
        out_features = codes.size(0) * codebooks.size(2)
        # Collapse all leading dimensions into one row dimension.
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)

        # Restore the leading dimensions; -1 lets reshape infer the last one.
        output_sizes = list(input.shape)
        output_sizes.pop()
        output_sizes.append(-1)
        return flat_output.reshape(tuple(output_sizes))

419
    @register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: List[int]) -> torch.Tensor:
        """Fake ``_C::aqlm_dequant``: uninitialised output tensor.

        The output has shape ``(codes.size(0), codes.size(1) * 8)`` and the
        dtype and device of ``codebooks``.
        """
        in_features = codes.size(1) * 8
        out_features = codes.size(0)
        return torch.empty((out_features, in_features),
                           dtype=codebooks.dtype,
                           device=codebooks.device)

429
    @register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        """Fake ``_C::fp8_marlin_gemm``: uninitialised output tensor.

        The output has shape ``(size_m, size_n)`` and the dtype and device
        of ``a``.
        """
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

436
    @register_fake("_C::machete_gemm")
    def machete_gemm_fake(
        a: torch.Tensor,
        # Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        b_scales: Optional[torch.Tensor] = None,
        b_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        c: Optional[torch.Tensor] = None,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        """Fake ``_C::machete_gemm``: uninitialised output tensor.

        The output has shape ``(a.size(0), b_q.size(1))`` and the dtype and
        device of ``a``.
        """
        m = a.size(0)
        n = b_q.size(1)
        return torch.empty((m, n), device=a.device, dtype=a.dtype)

454
    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                               b_type: ScalarType) -> torch.Tensor:
        """Fake ``_C::machete_prepack_B``: uninitialised contiguous tensor.

        The output matches ``b_q_weight`` in shape, dtype and device, using
        the contiguous memory format.
        """
        return torch.empty_like(b_q_weight,
                                memory_format=torch.contiguous_format)
459

460
    @register_fake("_C::causal_conv1d_fwd")
    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                               bias_: Optional[torch.Tensor],
                               conv_states: Optional[torch.Tensor],
                               cu_seq_len: Optional[torch.Tensor],
                               cache_indices: Optional[torch.Tensor],
                               has_initial_state: Optional[torch.Tensor],
                               silu_activation: bool,
                               pad_slot_id: int) -> None:
        """Fake ``_C::causal_conv1d_fwd``: does nothing and returns None."""
        return None
469

470
    @register_fake("_C::causal_conv1d_update")
    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
                                  weight: torch.Tensor,
                                  bias_: Optional[torch.Tensor],
                                  silu_activation: bool,
                                  cache_seqlens: Optional[torch.Tensor],
                                  conv_state_indices: Optional[torch.Tensor],
                                  pad_slot_id: int) -> None:
        """Fake ``_C::causal_conv1d_update``: does nothing and returns None."""
        return None
479

480
    @register_fake("_C::selective_scan_fwd")
    def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
                                A: torch.Tensor, B: torch.Tensor,
                                C: torch.Tensor, D_: Optional[torch.Tensor],
                                z_: Optional[torch.Tensor],
                                delta_bias_: Optional[torch.Tensor],
                                delta_softplus: bool,
                                cu_seq_len: Optional[torch.Tensor],
                                cache_indices: Optional[torch.Tensor],
                                has_initial_state: Optional[torch.Tensor],
                                ssm_states: Optional[torch.Tensor],
                                pad_slot_id: int) -> None:
        """Fake ``_C::selective_scan_fwd``: does nothing and returns None."""
        return None
493
494


495
# cutlass
496
497
498
499
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return the answer of the ``_C.cutlass_scaled_mm_supports_fp8`` op.

    ``cuda_device_capability`` is passed to the op unchanged.
    """
    supports_fp8 = torch.ops._C.cutlass_scaled_mm_supports_fp8(
        cuda_device_capability)
    return supports_fp8


500
501
502
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Allocate the output and run the ``_C.cutlass_scaled_mm`` custom op.

    Preconditions (checked with ``assert``):
        * both dimensions of ``b`` are multiples of 16;
        * ``out_dtype`` is ``torch.bfloat16`` or ``torch.float16``;
        * ``bias``, when given, has ``b.shape[1]`` entries along its first
          dimension and dtype ``out_dtype``.

    Returns:
        A new tensor of shape ``(a.shape[0], b.shape[1])`` with dtype
        ``out_dtype`` on the device of ``a``, after it has been passed to
        the op as the output argument.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[
        1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Allocate the output and run the ``_C.cutlass_scaled_mm_azp`` op.

    Asserts that both dimensions of ``b`` are multiples of 16, that
    ``out_dtype`` is bfloat16 or float16, and that ``bias`` (when given)
    has ``b.shape[1]`` elements and dtype ``out_dtype``.

    Returns:
        A new ``(a.shape[0], b.shape[1])`` tensor of dtype ``out_dtype`` on
        the device of ``a``, after it has been passed to the op as the
        output argument.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)

    result = torch.empty((a.shape[0], b.shape[1]),
                         dtype=out_dtype,
                         device=a.device)
    kernel = torch.ops._C.cutlass_scaled_mm_azp
    kernel(result, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return result


542
543
544
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the result of the ``_C.aqlm_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)
549
550
551


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    """Return the result of the ``_C.aqlm_dequant`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)
555
556


557
558
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the result of the ``_C.gptq_marlin_repack`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)
563
564


565
566
567
568
569
570
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Return the result of the ``_C.awq_marlin_repack`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    repack = torch.ops._C.awq_marlin_repack
    return repack(b_q_weight, size_k, size_n, num_bits)


571
572
573
574
575
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Repack each expert's weights with ``_C.gptq_marlin_repack``.

    The first dimension of ``b_q_weight`` is treated as the expert index;
    ``b_q_weight[e]`` and ``perm[e]`` are repacked one expert at a time.
    ``size_k`` must be a multiple of 16 (asserted).

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with the
        dtype and device of ``b_q_weight``.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
                                                    size_k, size_n, num_bits)
    return output


585
586
587
588
589
590
591
592
593
594
595
596
597
598
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """Repack each expert's weights with ``_C.awq_marlin_repack``.

    The first dimension of ``b_q_weight`` is treated as the expert index.
    ``perm`` is accepted but not used by this function. ``size_k`` must be
    a multiple of 16 (asserted).

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with the
        dtype and device of ``b_q_weight``.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    packed_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    repacked = torch.empty(packed_shape,
                           device=b_q_weight.device,
                           dtype=b_q_weight.dtype)
    expert_idx = 0
    while expert_idx < expert_count:
        repacked[expert_idx] = torch.ops._C.awq_marlin_repack(
            b_q_weight[expert_idx], size_k, size_n, num_bits)
        expert_idx += 1
    return repacked


599
600
601
602
603
604
605
606
607
608
609
610
611
612
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    """Return the result of the ``_C.gptq_marlin_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce)
617
618


619
620
621
622
623
624
625
626
627
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Return the result of the ``_C.fp8_marlin_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    gemm = torch.ops._C.fp8_marlin_gemm
    return gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n,
                size_k)


628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    """Return what ``_C.machete_supported_schedules`` reports for ``b_type``."""
    schedules = torch.ops._C.machete_supported_schedules(b_type)
    return schedules


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    """Return the result of the ``_C.machete_gemm`` custom op.

    All arguments are forwarded positionally in declaration order.
    """
    gemm = torch.ops._C.machete_gemm
    return gemm(
        a,
        b_q,
        b_type,
        b_scales,
        b_zeros,
        b_group_size,
        c,
        alpha,
        beta,
        schedule,
    )


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    """Return the result of the ``_C.machete_prepack_B`` custom op."""
    prepack = torch.ops._C.machete_prepack_B
    return prepack(b_q_weight, b_type)


654
# Only register the fake implementation when the compiled op exists.
if hasattr(torch.ops._C, "permute_cols"):

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        """Fake ``_C::permute_cols``: uninitialised tensor shaped like ``a``."""
        return torch.empty_like(a)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Return the result of the ``_C.permute_cols`` custom op for ``a``, ``perm``."""
    permute = torch.ops._C.permute_cols
    return permute(a, perm)


666
# fp8
667
668
669
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz.
    # Uses current_platform (as the module-level ROCm check does) instead of
    # vllm.utils, which this module never imports explicitly.
    out_dtype: torch.dtype = torch.float8_e4m3fnuz \
        if current_platform.is_rocm() else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            # One scale per output row, filled in by the kernel.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
723
724


725
# int8
726
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and
    scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is
            provided.
        symmetric: Whether to use symmetric quantization (scale only, azp
            ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output
            int8 tensor, scales, and optionally azp. In the static case the
            returned scale/azp are the objects passed in by the caller; in
            the dynamic case they are freshly allocated here and handed to
            the kernel.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (
            azp is
            None), "azp must only be provided for asymmetric quantization."
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
        # Hand back the caller's azp (None when symmetric) so the returned
        # tuple matches the documented "optionally azp" contract instead of
        # always dropping the zero-point.
        return output, scale, azp

    # dynamic-per-token quantization.
    # One scale per row of the input flattened over all but the last dim.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    # The zero-point buffer is only allocated for asymmetric quantization.
    input_azp = None if symmetric else torch.empty_like(input_scales,
                                                        dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
                                           input_azp)
    return output, input_scales, input_azp
764
765


766
767
768
769
770
771
772
773
774
# qqq ops
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Call the ``_C.marlin_qqq_gemm`` custom op and return its result.

    All arguments are forwarded positionally, in signature order.

    NOTE(review): judging by the names only, ``s_tok`` / ``s_ch`` /
    ``s_group`` look like per-token / per-channel / per-group scales and
    ``size_m`` / ``size_n`` / ``size_k`` the GEMM dimensions -- confirm
    against the kernel.
    """
    kernel = torch.ops._C.marlin_qqq_gemm
    gemm_args = (a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                 size_n, size_k)
    return kernel(*gemm_args)


775
# gguf
776
777
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    """Call the ``_C.ggml_dequantize`` custom op and return its result.

    The four arguments are forwarded unchanged, in signature order.

    NOTE(review): presumably ``W`` is a GGML/GGUF-quantized weight,
    ``quant_type`` its quantization type id, and ``(m, n)`` the shape of the
    dequantized result -- confirm against the kernel.
    """
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``_C.ggml_mul_mat_vec_a8`` custom op and return its result.

    The four arguments are forwarded unchanged, in signature order.

    NOTE(review): presumably a quantized matrix-vector product of weight
    ``W`` with activation ``X``, with ``row`` the number of output rows --
    confirm against the kernel.
    """
    return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``_C.ggml_mul_mat_a8`` custom op and return its result.

    The four arguments are forwarded unchanged, in signature order.

    NOTE(review): presumably a quantized matrix-matrix product of weight
    ``W`` with activation ``X``, with ``row`` the number of output rows --
    confirm against the kernel.
    """
    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


799
800
801
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      conv_states: Optional[torch.Tensor],
                      query_start_loc: Optional[torch.Tensor],
                      cache_indices: Optional[torch.Tensor],
                      has_initial_state: Optional[torch.Tensor],
                      silu_activation: bool, pad_slot_id: int):
    """Invoke the ``_C.causal_conv1d_fwd`` custom op.

    Every argument is forwarded unchanged, in signature order. The op's
    result is not returned, so this function returns ``None``.

    NOTE(review): since nothing is returned, the kernel presumably writes
    its output in place into one of the tensor arguments -- confirm which.
    """
    torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
                                   query_start_loc, cache_indices,
                                   has_initial_state, silu_activation,
                                   pad_slot_id)


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Invoke the ``_C.causal_conv1d_update`` custom op.

    Every argument is forwarded unchanged, in signature order. The op's
    result is not returned, so this function returns ``None``.

    NOTE(review): the kernel presumably updates ``x`` and/or ``conv_state``
    in place -- confirm.
    """
    kernel = torch.ops._C.causal_conv1d_update
    kernel_args = (x, conv_state, weight, bias_, silu_activation,
                   cache_seqlens, conv_state_indices, pad_slot_id)
    kernel(*kernel_args)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool,
                       query_start_loc: Optional[torch.Tensor],
                       cache_indices: Optional[torch.Tensor],
                       has_initial_state: Optional[torch.Tensor],
                       ssm_states: torch.Tensor, pad_slot_id: int):
    """Invoke the ``_C.selective_scan_fwd`` custom op.

    Every argument is forwarded unchanged, in signature order. The op's
    result is not returned, so this function returns ``None``.

    NOTE(review): since nothing is returned, the kernel presumably writes
    its results in place (e.g. into ``ssm_states``) -- confirm.
    """
    torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
                                    delta_softplus, query_start_loc,
                                    cache_indices, has_initial_state,
                                    ssm_states, pad_slot_id)
837
838


839
840
841
842
843
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Invoke the ``_C.moe_align_block_size`` custom op.

    Every argument is forwarded unchanged, in signature order; nothing is
    returned.

    NOTE(review): ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` look like caller-allocated output buffers that
    the kernel fills in place -- confirm.
    """
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: float,
) -> None:
    """Invoke the ``_moe_C.topk_softmax`` custom op.

    The four arguments are forwarded unchanged, in signature order; nothing
    is returned.

    NOTE(review): ``gating_output`` is annotated ``float`` but may actually
    be a tensor of gating values -- confirm against the op schema before
    changing the annotation.
    """
    kernel = torch.ops._moe_C.topk_softmax
    kernel(topk_weights, topk_ids, token_expert_indicies, gating_output)
854
855


856
857
# Only register the fake implementation when the MoE extension was imported
# and actually exposes the `marlin_gemm_moe` op.
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                             sorted_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                             perm: torch.Tensor, workspace: torch.Tensor,
                             b_q_type: ScalarType, size_m: int, size_n: int,
                             size_k: int, is_k_full: bool, num_experts: int,
                             topk: int, moe_block_size: int,
                             replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
        """Fake implementation registered for ``_moe_C::marlin_gemm_moe``.

        Performs no computation: it only allocates an uninitialized tensor
        of shape ``(size_m, topk, size_n)`` with the dtype and device of
        ``a``. All other arguments are accepted to mirror the op's
        signature and are otherwise unused.
        """
        return torch.empty((size_m, topk, size_n),
                           dtype=a.dtype,
                           device=a.device)


875
876
877
878
879
880
881
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke the ``_C_cache_ops.reshape_and_cache`` custom op.

    Every argument is forwarded unchanged, in signature order; nothing is
    returned.

    NOTE(review): presumably writes ``key`` / ``value`` into ``key_cache`` /
    ``value_cache`` at the positions given by ``slot_mapping`` -- confirm
    against the kernel.
    """
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)
888
889


890
891
892
893
894
895
896
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke the ``_C_cache_ops.reshape_and_cache_flash`` custom op.

    Every argument is forwarded unchanged, in signature order; nothing is
    returned.

    NOTE(review): presumably the same operation as ``reshape_and_cache`` but
    for a different cache layout -- confirm against the kernel.
    """
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)
904
905


906
907
def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Invoke the ``_C_cache_ops.copy_blocks`` custom op.

    The three arguments are forwarded unchanged, in signature order;
    nothing is returned.

    NOTE(review): presumably copies cache blocks in place according to
    ``block_mapping`` -- confirm the mapping layout against the kernel.
    """
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
910
911
912


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Invoke the ``_C_cache_ops.swap_blocks`` custom op.

    The three arguments are forwarded unchanged, in signature order;
    nothing is returned.

    NOTE(review): presumably copies blocks from ``src`` to ``dst`` according
    to ``block_mapping`` -- confirm against the kernel.
    """
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
915
916


917
918
919
920
def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Invoke the ``_C_cache_ops.convert_fp8`` custom op.

    The four arguments are forwarded unchanged, in signature order; nothing
    is returned. ``scale`` defaults to ``1.0`` and ``kv_dtype`` to ``"fp8"``.

    NOTE(review): presumably converts ``input`` into ``output`` in place --
    confirm the conversion direction and the accepted ``kv_dtype`` values
    against the kernel.
    """
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return the result of the ``_C_cuda_utils.get_device_attribute`` op.

    ``attribute`` and ``device`` are forwarded unchanged, in that order.
    """
    cuda_utils = torch.ops._C_cuda_utils
    value = cuda_utils.get_device_attribute(attribute, device)
    return value


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the result of the matching ``_C_cuda_utils`` custom op.

    Forwards ``device`` to
    ``_C_cuda_utils.get_max_shared_memory_per_block_device_attribute``.
    """
    # The directive below is kept verbatim: a `# ruff: noqa` comment acts as
    # a file-level exemption, so dropping it could affect other lines.
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(
    meta: torch.Tensor,
    rank_data: torch.Tensor,
    handles: List[str],
    offsets: List[int],
    rank: int,
    full_nvlink: bool,
) -> int:
    """Call the ``_C_custom_ar.init_custom_ar`` op and return its result.

    Every argument is forwarded unchanged, in signature order.

    NOTE(review): the returned int is presumably the opaque handle that the
    other custom all-reduce wrappers take as ``fa`` -- confirm.
    """
    init_op = torch.ops._C_custom_ar.init_custom_ar
    init_args = (meta, rank_data, handles, offsets, rank, full_nvlink)
    return init_op(*init_args)


def all_reduce_reg(
    fa: int,
    inp: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Invoke the ``_C_custom_ar.all_reduce_reg`` custom op.

    ``fa``, ``inp`` and ``out`` are forwarded unchanged, in that order;
    nothing is returned.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_reg(fa, inp, out)

945

946
947
948
def all_reduce_unreg(
    fa: int,
    inp: torch.Tensor,
    reg_buffer: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Invoke the ``_C_custom_ar.all_reduce_unreg`` custom op.

    ``fa``, ``inp``, ``reg_buffer`` and ``out`` are forwarded unchanged, in
    that order; nothing is returned.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
949

950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972

def dispose(fa: int) -> None:
    """Invoke the ``_C_custom_ar.dispose`` custom op with ``fa``.

    Nothing is returned.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.dispose(fa)


def meta_size() -> int:
    """Return the result of the ``_C_custom_ar.meta_size`` custom op."""
    custom_ar = torch.ops._C_custom_ar
    size = custom_ar.meta_size()
    return size


def register_buffer(
    fa: int,
    t: torch.Tensor,
    handles: List[str],
    offsets: List[int],
) -> None:
    """Call the ``_C_custom_ar.register_buffer`` custom op.

    The four arguments are forwarded unchanged, in signature order.
    Although annotated ``-> None``, whatever the op returns is passed back
    to the caller.
    """
    register_op = torch.ops._C_custom_ar.register_buffer
    result = register_op(fa, t, handles, offsets)
    return result


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the result of ``_C_custom_ar.get_graph_buffer_ipc_meta(fa)``.

    NOTE(review): judging by the annotation and by
    ``register_graph_buffers``, the two lists are presumably handles and
    offsets -- confirm.
    """
    custom_ar = torch.ops._C_custom_ar
    ipc_meta = custom_ar.get_graph_buffer_ipc_meta(fa)
    return ipc_meta


def register_graph_buffers(
    fa: int,
    handles: List[str],
    offsets: List[List[int]],
) -> None:
    """Invoke the ``_C_custom_ar.register_graph_buffers`` custom op.

    ``fa``, ``handles`` and ``offsets`` are forwarded unchanged, in that
    order; nothing is returned.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.register_graph_buffers(fa, handles, offsets)


973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Wrap every plain function defined in this file that has a `torch.Tensor`
# annotation with `hint_on_error`, then rebind the wrapped versions in the
# module namespace.
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to turn
    # type hints into strings.
    if isinstance(v, fn_type) \
        and v.__code__.co_filename == __file__ \
        and any(arg is torch.Tensor or arg == "torch.Tensor"
                for arg in v.__annotations__.values()):
        # Collect into a side dict; globals() must not be mutated while it
        # is being iterated.
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
# NOTE: `arg` is not deleted here, so it stays bound to None at module level.
del names_and_values_to_update, names_and_values, v, k, fn_type