_custom_ops.py 35.8 KB
Newer Older
1
import contextlib
2
import functools
3
from typing import List, Optional, Tuple, Union
4
5
6

import torch

7
import vllm.envs as envs
8
from vllm._core_ext import ScalarType
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
12
13

logger = init_logger(__name__)

14
15
16
17
18
# The vllm._C extension module is only imported when the current platform
# is not a TPU.
if not current_platform.is_tpu():
    try:
        import vllm._C
    except ImportError as e:
        # Best effort: a failed import is logged as a warning instead of
        # aborting the import of this module.
        logger.warning("Failed to import from vllm._C with %r", e)
19

20
21
22
# On ROCm platforms the ROCm extension module is imported as well. An
# ImportError raised here is not caught and propagates to the importer.
if current_platform.is_rocm():
    import vllm._rocm_C  # noqa: F401

23
with contextlib.suppress(ImportError):
24
    import vllm._moe_C  # noqa: F401
25

26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def hint_on_error(fn):
    """Decorate ``fn`` so that an ``AttributeError`` logs a stale-build hint.

    The returned wrapper calls ``fn`` with the arguments it receives and
    returns ``fn``'s result. If ``fn`` raises ``AttributeError``, an error
    naming ``fn`` and suggesting a clean rebuild is logged, and the same
    exception is then re-raised. Any other exception type propagates
    without being logged.

    Args:
        fn: The callable to wrap.

    Returns:
        The wrapping callable (metadata copied from ``fn`` via
        ``functools.wraps``).
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            # Adjacent literals are concatenated, so the separating space
            # after "vllm," has to be spelled out explicitly.
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare ``raise`` re-raises the active exception unchanged.
            raise

    return wrapper


46
47
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.silu_and_mul`` custom op with ``(out, x)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)
49
50
51


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_and_mul`` custom op with ``(out, x)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)
53
54
55


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_tanh_and_mul`` custom op with ``(out, x)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)
57
58
59


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_fast`` custom op with ``(out, x)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)
61
62
63


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_new`` custom op with ``(out, x)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.gelu_new
    kernel(out, x)
65
66


67
68
69
70
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the ``_C.gelu_quick`` custom op with ``(out, x)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


71
72
73
74
75
76
77
78
79
# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in signature order, to ``_C.paged_attention_v1``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    op_args = (
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v1(*op_args)
99
100
101
102
103
104
105
106
107
108
109
110
111


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in signature order, to ``_C.paged_attention_v2``.

    Returns nothing.
    NOTE(review): ``out``, ``exp_sum``, ``max_logits`` and ``tmp_out`` look
    like output/scratch buffers filled by the kernel - confirm.
    """
    op_args = (
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v2(*op_args)
131
132


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Forward every argument, in order, to ``_rocm_C.paged_attention``.

    Unlike the ``_C`` paged-attention wrappers in this file, this one
    dispatches to the ``_rocm_C`` op namespace. Returns nothing.
    """
    op_args = (
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    torch.ops._rocm_C.paged_attention(*op_args)
157
158


159
160
161
162
163
164
165
166
167
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Call the ``_C.rotary_embedding`` custom op with all six arguments.

    Returns nothing.
    NOTE(review): ``query`` and ``key`` are presumably updated in place -
    confirm against the kernel.
    """
    kernel = torch.ops._C.rotary_embedding
    kernel(positions, query, key, head_size, cos_sin_cache, is_neox)
170
171
172
173
174
175
176


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Call the ``_C.batched_rotary_embedding`` custom op.

    All eight arguments are forwarded positionally in signature order.
    Returns nothing.
    """
    kernel = torch.ops._C.batched_rotary_embedding
    kernel(positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim,
           cos_sin_cache_offsets)
180
181
182
183
184


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call the ``_C.rms_norm`` custom op with ``(out, input, weight, epsilon)``.

    Returns nothing.
    NOTE(review): the result is presumably written into ``out`` - confirm
    against the kernel.
    """
    kernel = torch.ops._C.rms_norm
    kernel(out, input, weight, epsilon)
186
187
188
189


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call ``_C.fused_add_rms_norm`` with ``(input, residual, weight, epsilon)``.

    Returns nothing.
    NOTE(review): with no separate output argument, ``input`` and/or
    ``residual`` are presumably modified in place - confirm against the kernel.
    """
    kernel = torch.ops._C.fused_add_rms_norm
    kernel(input, residual, weight, epsilon)
191
192


193
194
195
196
197
198
def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Forwards all arguments, in signature order, to the
    ``_C.advance_step_flashattn`` custom op and returns whatever the op
    returns.
    """
    op_args = (
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
    )
    return torch.ops._C.advance_step_flashattn(*op_args)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    """Call the ``_C.advance_step_flashinfer`` custom op.

    All thirteen arguments are forwarded positionally in signature order,
    and the op's return value is returned to the caller.
    """
    op_args = (
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        paged_kv_indices,
        paged_kv_indptr,
        paged_kv_last_page_len,
        block_table_bound,
    )
    return torch.ops._C.advance_step_flashinfer(*op_args)
223
224


225
226
227
228
229
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Dequantize via the Triton implementation or the ``_C`` custom op.

    When ``envs.VLLM_USE_TRITON_AWQ`` is truthy, ``awq_dequantize_triton`` is
    imported lazily and called with ``(qweight, scales, zeros)`` only;
    ``split_k_iters``, ``thx`` and ``thy`` are not used on that path.
    Otherwise all six arguments go to ``_C.awq_dequantize``.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_dequantize(qweight, scales, zeros,
                                           split_k_iters, thx, thy)

    # Imported lazily so the Triton module is only loaded when selected.
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_dequantize_triton)
    return awq_dequantize_triton(qweight, scales, zeros)
236
237
238
239


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Run the AWQ GEMM via the Triton implementation or the ``_C`` custom op.

    ``envs.VLLM_USE_TRITON_AWQ`` selects the backend; both backends receive
    the same five arguments in the same order.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_gemm(input, qweight, qzeros, scales,
                                     split_k_iters)

    # Imported lazily so the Triton module is only loaded when selected.
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_gemm_triton)
    return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
245
246
247
248
249
250
251


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Call the ``_C.gptq_gemm`` custom op and return its result.

    All seven arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.gptq_gemm
    return kernel(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                  use_exllama, bit)
254
255


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# TODO: has to be a better way to do this
try:
    # Looking the op up raises when it is not available, which skips the
    # registration below.
    torch.ops._C.gptq_gemm  # noqa B018

    @torch.library.register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_gptq_qzeros: torch.Tensor,
                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                        use_exllama: bool, bit: int) -> torch.Tensor:
        """Fake impl: uninitialized ``(a.size(0), b_q_weight.size(1))`` tensor
        with the dtype and device of ``a``."""
        out_shape = (a.size(0), b_q_weight.size(1))
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)
except Exception:
    # Best effort: any failure above leaves the fake impl unregistered.
    pass


272
273
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call the ``_C.gptq_shuffle`` custom op with ``(q_weight, q_perm, bit)``.

    Returns nothing.
    NOTE(review): ``q_weight`` is presumably shuffled in place - confirm
    against the kernel.
    """
    kernel = torch.ops._C.gptq_shuffle
    kernel(q_weight, q_perm, bit)
275
276
277
278
279
280


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Call the ``_C.marlin_gemm`` custom op and return its result.

    All seven arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.marlin_gemm
    return kernel(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)
283
284


285
286
287
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Call the ``_C.gptq_marlin_24_gemm`` custom op and return its result.

    All nine arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.gptq_marlin_24_gemm
    return kernel(a, b_q_weight, b_meta, b_scales, workspace, b_q_type,
                  size_m, size_n, size_k)
293
294


295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
# TODO: has to be a better way to do this
# Registers fake implementations for the ``_C`` custom ops listed below.
# Each fake only allocates an uninitialized tensor (or list of tensors) of
# the shape/dtype/device the real op is expected to return; no kernel runs.
# The whole group is best effort: if the probe on the first line or any
# registration raises, the remaining registrations are skipped silently.
try:
    # Looking the op up raises when it is not available, which skips every
    # registration below.
    torch.ops._C.gptq_marlin_24_gemm  # noqa B018

    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: int,
                                  size_n: int, size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: int,
                               size_n: int,
                               size_k: int,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                              n: int) -> torch.Tensor:
        return torch.empty((m, n), dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        return torch.empty((1, row), dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        batch = X.size(0)
        return torch.empty((batch, row), dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @torch.library.register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: int, size_n: int,
                          size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @torch.library.register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: int, thx: int,
                             thy: int) -> torch.Tensor:
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        # NOTE(review): the factor 8 presumably reflects eight packed values
        # per qweight element - confirm against the kernel.
        out_c = qout_c * 8
        return torch.empty((in_c, out_c),
                           dtype=scales.dtype,
                           device=scales.device)

    @torch.library.register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: int) -> torch.Tensor:
        num_in_feats = input.size(0)
        # Allocates one slice per split-k iteration and sums over that
        # leading dimension, giving (num_in_feats, qweight.size(1) * 8).
        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                           dtype=input.dtype,
                           device=input.device).sum(0)

    @torch.library.register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: List[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        out_features = codes.size(0) * codebooks.size(2)
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)

        # Output keeps the leading dimensions of ``input``; only the last
        # dimension is replaced (inferred through -1).
        output_sizes = list(input.shape)
        output_sizes.pop()
        output_sizes.append(-1)
        return flat_output.reshape(tuple(output_sizes))

    @torch.library.register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: List[int]) -> torch.Tensor:
        in_features = codes.size(1) * 8
        out_features = codes.size(0)
        return torch.empty((out_features, in_features),
                           dtype=codebooks.dtype,
                           device=codebooks.device)

    @torch.library.register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @torch.library.register_fake("_C::machete_gemm")
    def machete_gemm_fake(
        a: torch.Tensor,
        b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
        b_type: ScalarType,
        b_scales: Optional[torch.Tensor] = None,
        b_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        c: Optional[torch.Tensor] = None,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        m = a.size(0)
        n = b_q.size(1)
        return torch.empty((m, n), device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                               b_type: ScalarType) -> torch.Tensor:
        return torch.empty_like(b_q_weight)

    @torch.library.register_fake("_C::causal_conv1d_fwd")
    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                               bias_: Optional[torch.Tensor],
                               seq_idx_: Optional[torch.Tensor],
                               initial_states_: Optional[torch.Tensor],
                               final_states_out_: Optional[torch.Tensor],
                               silu_activation: bool) -> torch.Tensor:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::causal_conv1d_update")
    def causal_conv1d_update_fake(
            x: torch.Tensor,
            conv_state: torch.Tensor,
            weight: torch.Tensor,
            bias_: Optional[torch.Tensor],
            silu_activation: bool,
            conv_state_indices: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # ``conv_state_indices`` mirrors the trailing parameter of the
        # ``causal_conv1d_update`` wrapper in this file; it defaults to None
        # so the fake accepts calls with or without it.
        return torch.empty_like(x)

    @torch.library.register_fake("_C::selective_scan_fwd")
    def selective_scan_fwd_fake(
            u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
            B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
            z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
            delta_softplus: bool, index_: Optional[torch.Tensor],
            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
        a = torch.empty_like(u)
        # A caller-supplied ``x`` is returned as-is; otherwise a fresh
        # (u.size(0), u.size(1), A.size(1)) tensor is allocated.
        if x is not None:
            b = x
        else:
            b = torch.empty((u.size(0), u.size(1), A.size(1)),
                            dtype=u.dtype,
                            device=u.device)
        # A third output is only produced when ``z_`` is given.
        if z_ is not None:
            c = torch.empty_like(z_)
            return [a, b, c]
        else:
            return [a, b]

except Exception:
    # Deliberate best effort: leave the fakes unregistered on any failure.
    pass


483
# cutlass
484
485
486
487
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return the result of ``_C.cutlass_scaled_mm_supports_fp8`` for the
    given ``cuda_device_capability``."""
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)


488
489
490
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Allocate the output and run the ``_C.cutlass_scaled_mm`` custom op.

    Preconditions (checked with ``assert``):
        * both dimensions of ``b`` are multiples of 16;
        * ``out_dtype`` is ``torch.bfloat16`` or ``torch.float16``;
        * ``bias``, when given, has ``bias.shape[0] == b.shape[1]`` and
          dtype ``out_dtype``.

    Returns:
        A new ``(a.shape[0], b.shape[1])`` tensor of ``out_dtype`` on
        ``a.device``, passed to the op as its first argument.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    if bias is not None:
        assert bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    result = torch.empty((a.shape[0], b.shape[1]),
                         dtype=out_dtype,
                         device=a.device)
    torch.ops._C.cutlass_scaled_mm(result, a, b, scale_a, scale_b, bias)
    return result


508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Allocate the output and run the ``_C.cutlass_scaled_mm_azp`` custom op.

    Preconditions (checked with ``assert``):
        * both dimensions of ``b`` are multiples of 16;
        * ``out_dtype`` is ``torch.bfloat16`` or ``torch.float16``;
        * ``bias``, when given, has ``b.shape[1]`` elements and dtype
          ``out_dtype``.

    Returns:
        A new ``(a.shape[0], b.shape[1])`` tensor of ``out_dtype`` on
        ``a.device``, passed to the op as its first argument.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    if bias is not None:
        assert bias.numel() == b.shape[1] and bias.dtype == out_dtype

    result = torch.empty((a.shape[0], b.shape[1]),
                         dtype=out_dtype,
                         device=a.device)
    torch.ops._C.cutlass_scaled_mm_azp(result, a, b, scale_a, scale_b,
                                       azp_adj, azp, bias)
    return result


530
531
532
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Call the ``_C.aqlm_gemm`` custom op and return its result.

    All six arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.aqlm_gemm
    return kernel(input, codes, codebooks, scales, codebook_partition_sizes,
                  bias)
537
538
539


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    """Call the ``_C.aqlm_dequant`` custom op and return its result.

    All three arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.aqlm_dequant
    return kernel(codes, codebooks, codebook_partition_sizes)
543
544


545
546
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Call the ``_C.gptq_marlin_repack`` custom op and return its result.

    All five arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.gptq_marlin_repack
    return kernel(b_q_weight, perm, size_k, size_n, num_bits)
551
552


553
554
555
556
557
558
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Call the ``_C.awq_marlin_repack`` custom op and return its result.

    All four arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.awq_marlin_repack
    return kernel(b_q_weight, size_k, size_n, num_bits)


559
560
561
562
563
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Repack every expert's weights with ``_C.gptq_marlin_repack``.

    The leading dimension of ``b_q_weight`` is taken as the number of
    experts. For each index ``e`` the op is called on
    ``(b_q_weight[e], perm[e], size_k, size_n, num_bits)`` and the result is
    stored in slot ``e`` of the output.

    Precondition (checked with ``assert``): ``size_k`` is a multiple of 16.

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with the
        dtype and device of ``b_q_weight``.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0

    packed_shape = (num_experts, size_k // 16, size_n * (num_bits // 2))
    output = torch.empty(packed_shape,
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)

    # The op is looked up inside the loop, so with zero experts it is never
    # touched at all.
    for expert_idx in range(num_experts):
        repacked = torch.ops._C.gptq_marlin_repack(b_q_weight[expert_idx],
                                                   perm[expert_idx], size_k,
                                                   size_n, num_bits)
        output[expert_idx] = repacked
    return output


573
574
575
576
577
578
579
580
581
582
583
584
585
586
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    """Call the ``_C.gptq_marlin_gemm`` custom op and return its result.

    All fourteen arguments are forwarded positionally in signature order.
    """
    op_args = (
        a,
        b_q_weight,
        b_scales,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type,
        size_m,
        size_n,
        size_k,
        is_k_full,
        has_zp,
        use_fp32_reduce,
    )
    return torch.ops._C.gptq_marlin_gemm(*op_args)
591
592


593
594
595
596
597
598
599
600
601
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Call the ``_C.fp8_marlin_gemm`` custom op and return its result.

    All eight arguments are forwarded positionally in signature order.
    """
    kernel = torch.ops._C.fp8_marlin_gemm
    return kernel(a, b_q_weight, b_scales, workspace, num_bits, size_m,
                  size_n, size_k)


602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    """Return the result of ``_C.machete_supported_schedules`` for ``b_type``."""
    query = torch.ops._C.machete_supported_schedules
    return query(b_type)


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    """Call the ``_C.machete_gemm`` custom op and return its result.

    All ten arguments are forwarded positionally in signature order;
    optional ones left at their defaults are passed as ``None``.
    """
    op_args = (
        a,
        b_q,
        b_type,
        b_scales,
        b_zeros,
        b_group_size,
        c,
        alpha,
        beta,
        schedule,
    )
    return torch.ops._C.machete_gemm(*op_args)


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    """Return the result of ``_C.machete_prepack_B(b_q_weight, b_type)``."""
    kernel = torch.ops._C.machete_prepack_B
    return kernel(b_q_weight, b_type)


628
# fp8
629
630
631
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D.
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz.
    # The platform is queried through ``current_platform``, which this
    # module imports explicitly.
    out_dtype: torch.dtype = torch.float8_e4m3fnuz \
        if current_platform.is_rocm() else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            # One scale per (possibly padded) output row.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
685
686


687
# int8
688
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 via the ``_C`` int8 quantization ops.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor. When given, the static op is used;
            when omitted, the dynamic op is used with one float32 scale per
            row of ``input`` flattened over its last dimension.
        azp: Optional zero-point, only consulted on the static path. It must
            be ``None`` exactly when ``symmetric`` is true.
        symmetric: Whether to use symmetric quantization. On the dynamic
            path, an int32 zero-point buffer is allocated only when false.

    Returns:
      Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : the int8
      output, the scales, and the zero-points. The static path returns
      ``(output, scale, None)``; the dynamic path returns the allocated
      scale buffer and zero-point buffer (``None`` when symmetric).
    """
    quantized = torch.empty_like(input, dtype=torch.int8)

    if scale is None:
        # dynamic-per-token quantization.
        num_rows = input.numel() // input.shape[-1]
        row_scales = torch.empty((num_rows, 1),
                                 device=input.device,
                                 dtype=torch.float32)
        if symmetric:
            row_azp = None
        else:
            row_azp = torch.empty_like(row_scales, dtype=torch.int32)
        torch.ops._C.dynamic_scaled_int8_quant(quantized, input, row_scales,
                                               row_azp)
        return quantized, row_scales, row_azp

    # static-per-tensor quantization.
    azp_missing = azp is None
    assert symmetric == azp_missing, (
        "azp must only be provided for asymmetric quantization.")
    torch.ops._C.static_scaled_int8_quant(quantized, input, scale, azp)
    return quantized, scale, None
726
727


728
729
730
731
732
733
734
735
736
# qqq ops
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Call the ``marlin_qqq_gemm`` custom op from ``torch.ops._C``.

    Every argument is forwarded unchanged and in declaration order; the
    op's result is returned as-is.
    """
    qqq_gemm_op = torch.ops._C.marlin_qqq_gemm
    return qqq_gemm_op(a, b_q_weight, s_tok, s_ch, s_group, workspace,
                       size_m, size_n, size_k)


737
# gguf
738
739
def ggml_dequantize(
    W: torch.Tensor,
    quant_type: int,
    m: int,
    n: int,
) -> torch.Tensor:
    """Call the ``ggml_dequantize`` custom op from ``torch.ops._C``.

    Forwards ``W``, ``quant_type``, ``m`` and ``n`` unchanged and returns
    whatever the op returns.
    """
    dequantize_op = torch.ops._C.ggml_dequantize
    return dequantize_op(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``ggml_mul_mat_vec_a8`` custom op from ``torch.ops._C``.

    Forwards ``W``, ``X``, ``quant_type`` and ``row`` unchanged and returns
    whatever the op returns.
    """
    mul_mat_vec_op = torch.ops._C.ggml_mul_mat_vec_a8
    return mul_mat_vec_op(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``ggml_mul_mat_a8`` custom op from ``torch.ops._C``.

    Forwards ``W``, ``X``, ``quant_type`` and ``row`` unchanged and returns
    whatever the op returns.
    """
    mul_mat_op = torch.ops._C.ggml_mul_mat_a8
    return mul_mat_op(W, X, quant_type, row)


761
762
763
764
765
766
767
768
769
770
771
772
# mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    seq_idx_: Optional[torch.Tensor],
    initial_states_: Optional[torch.Tensor],
    final_states_out_: Optional[torch.Tensor],
    silu_activation: bool,
) -> torch.Tensor:
    """Call the ``causal_conv1d_fwd`` custom op from ``torch.ops._C``.

    All arguments, including the optional ones (which may be ``None``),
    are forwarded unchanged and in declaration order. The op's result is
    returned as-is.
    """
    conv_fwd_op = torch.ops._C.causal_conv1d_fwd
    return conv_fwd_op(x, weight, bias_, seq_idx_, initial_states_,
                       final_states_out_, silu_activation)


773
774
775
776
777
778
779
780
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    """Call the ``causal_conv1d_update`` custom op from ``torch.ops._C``.

    All arguments, including the optional ones (which may be ``None``),
    are forwarded unchanged and in declaration order. The op's result is
    returned as-is.
    """
    conv_update_op = torch.ops._C.causal_conv1d_update
    return conv_update_op(x, conv_state, weight, bias_, silu_activation,
                          conv_state_indices)
784
785
786
787
788
789
790
791
792
793
794
795
796


def selective_scan_fwd(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D_: Optional[torch.Tensor],
    z_: Optional[torch.Tensor],
    delta_bias_: Optional[torch.Tensor],
    delta_softplus: bool,
    index_: Optional[torch.Tensor],
    x: Optional[torch.Tensor],
) -> List[torch.Tensor]:
    """Call the ``selective_scan_fwd`` custom op from ``torch.ops._C``.

    All arguments, including the optional ones (which may be ``None``),
    are forwarded unchanged and in declaration order. The op's result is
    returned as-is.
    """
    scan_fwd_op = torch.ops._C.selective_scan_fwd
    return scan_fwd_op(u, delta, A, B, C, D_, z_, delta_bias_,
                       delta_softplus, index_, x)


797
798
799
800
801
# moe
def moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Call the ``moe_align_block_size`` custom op from ``torch.ops._C``.

    All arguments are forwarded unchanged and in declaration order.
    Nothing is returned; presumably the op writes its results into the
    tensors passed in -- confirm against the kernel.
    """
    align_op = torch.ops._C.moe_align_block_size
    align_op(topk_ids, num_experts, block_size, sorted_token_ids,
             experts_ids, num_tokens_post_pad)


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indicies: torch.Tensor,
    gating_output: float,
) -> None:
    """Call the ``topk_softmax`` custom op from ``torch.ops._moe_C``.

    All arguments are forwarded unchanged and in declaration order.
    Nothing is returned; presumably the op writes its results into the
    tensors passed in -- confirm against the kernel.

    NOTE(review): ``gating_output`` is annotated as ``float``; it looks
    like it may really be a tensor -- confirm against the op schema.
    """
    softmax_op = torch.ops._moe_C.topk_softmax
    softmax_op(topk_weights, topk_ids, token_expert_indicies, gating_output)
812
813
814
815
816
817
818
819
820


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call the ``reshape_and_cache`` op from ``torch.ops._C_cache_ops``.

    All arguments are forwarded unchanged and in declaration order.
    Nothing is returned; presumably the op writes into ``key_cache`` and
    ``value_cache`` -- confirm against the kernel.
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache
    cache_op(key, value, key_cache, value_cache, slot_mapping,
             kv_cache_dtype, k_scale, v_scale)
827
828


829
830
831
832
833
834
835
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call ``reshape_and_cache_flash`` from ``torch.ops._C_cache_ops``.

    All arguments are forwarded unchanged and in declaration order.
    Nothing is returned; presumably the op writes into ``key_cache`` and
    ``value_cache`` -- confirm against the kernel.
    """
    cache_flash_op = torch.ops._C_cache_ops.reshape_and_cache_flash
    cache_flash_op(key, value, key_cache, value_cache, slot_mapping,
                   kv_cache_dtype, k_scale, v_scale)
843
844


845
846
def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    """Call the ``copy_blocks`` op from ``torch.ops._C_cache_ops``.

    The two cache lists and ``block_mapping`` are forwarded unchanged.
    Nothing is returned.
    """
    copy_op = torch.ops._C_cache_ops.copy_blocks
    copy_op(key_caches, value_caches, block_mapping)
849
850
851


def swap_blocks(
    src: torch.Tensor,
    dst: torch.Tensor,
    block_mapping: torch.Tensor,
) -> None:
    """Call the ``swap_blocks`` op from ``torch.ops._C_cache_ops``.

    ``src``, ``dst`` and ``block_mapping`` are forwarded unchanged.
    Nothing is returned.
    """
    swap_op = torch.ops._C_cache_ops.swap_blocks
    swap_op(src, dst, block_mapping)
854
855


856
857
858
859
def convert_fp8(
    output: torch.Tensor,
    input: torch.Tensor,
    scale: float = 1.0,
    kv_dtype: str = "fp8",
) -> None:
    """Call the ``convert_fp8`` op from ``torch.ops._C_cache_ops``.

    Args:
        output: Forwarded as the op's first argument.
        input: Forwarded as the op's second argument.
        scale: Forwarded as the op's third argument; defaults to ``1.0``.
        kv_dtype: Forwarded as the op's fourth argument; defaults to
            ``"fp8"``.

    Nothing is returned; presumably the op writes the converted data into
    ``output`` -- confirm against the kernel.
    """
    convert_op = torch.ops._C_cache_ops.convert_fp8
    convert_op(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``torch.ops._C_cuda_utils.get_device_attribute`` for the
    given ``attribute`` and ``device``, both forwarded unchanged."""
    query_op = torch.ops._C_cuda_utils.get_device_attribute
    return query_op(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the result of the identically named op in
    ``torch.ops._C_cuda_utils`` for ``device``."""
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(
    meta: torch.Tensor,
    rank_data: torch.Tensor,
    handles: List[str],
    offsets: List[int],
    rank: int,
    full_nvlink: bool,
) -> int:
    """Call the ``init_custom_ar`` op from ``torch.ops._C_custom_ar``.

    All arguments are forwarded unchanged and in declaration order. The
    op's integer result is returned as-is; the sibling custom all-reduce
    wrappers take an ``fa: int`` argument, which is presumably this value
    -- confirm against callers.
    """
    init_op = torch.ops._C_custom_ar.init_custom_ar
    return init_op(meta, rank_data, handles, offsets, rank, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Call the ``all_reduce_reg`` op from ``torch.ops._C_custom_ar``.

    ``fa``, ``inp`` and ``out`` are forwarded unchanged. Nothing is
    returned.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce_reg
    reduce_op(fa, inp, out)

884

885
886
887
def all_reduce_unreg(
    fa: int,
    inp: torch.Tensor,
    reg_buffer: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Call the ``all_reduce_unreg`` op from ``torch.ops._C_custom_ar``.

    ``fa``, ``inp``, ``reg_buffer`` and ``out`` are forwarded unchanged.
    Nothing is returned.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce_unreg
    reduce_op(fa, inp, reg_buffer, out)
888

889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911

def dispose(fa: int) -> None:
    """Call the ``dispose`` op from ``torch.ops._C_custom_ar`` with ``fa``.

    Nothing is returned.
    """
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)


def meta_size() -> int:
    """Return the result of the ``meta_size`` op from
    ``torch.ops._C_custom_ar``; the op takes no arguments."""
    size_op = torch.ops._C_custom_ar.meta_size
    return size_op()


def register_buffer(
    fa: int,
    t: torch.Tensor,
    handles: List[str],
    offsets: List[int],
) -> None:
    """Call the ``register_buffer`` op from ``torch.ops._C_custom_ar``.

    ``fa``, ``t``, ``handles`` and ``offsets`` are forwarded unchanged.
    Although annotated ``-> None``, the op's result is passed straight
    back to the caller.
    """
    register_op = torch.ops._C_custom_ar.register_buffer
    return register_op(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the result of the ``get_graph_buffer_ipc_meta`` op from
    ``torch.ops._C_custom_ar`` for ``fa``."""
    ipc_meta_op = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    return ipc_meta_op(fa)


def register_graph_buffers(
    fa: int,
    handles: List[str],
    offsets: List[List[int]],
) -> None:
    """Call ``register_graph_buffers`` from ``torch.ops._C_custom_ar``.

    ``fa``, ``handles`` and ``offsets`` are forwarded unchanged. Nothing
    is returned.
    """
    register_op = torch.ops._C_custom_ar.register_graph_buffers
    register_op(fa, handles, offsets)


912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Wrap the qualifying functions defined in this module with `hint_on_error`.
names_and_values = globals()
# Replacements are collected here and applied after the loop, so that the
# module globals are not modified while they are being iterated.
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to
    # turn type hints into strings.
    if isinstance(v, fn_type) \
        and v.__code__.co_filename == __file__ \
        and any(arg is torch.Tensor or arg == "torch.Tensor"
                for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
# Remove every helper name introduced above (including `arg`) so that none
# of them is left behind as a module attribute.
del names_and_values_to_update, names_and_values, v, k, fn_type, arg