# _custom_ops.py (39.1 KB)
# Newer Older
1
import contextlib
2
import functools
3
import importlib
4
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
5
6

import torch
7
import torch.library
8

9
import vllm.envs as envs
10
from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12
from vllm.scalar_type import ScalarType
13
14
15

logger = init_logger(__name__)

# Load the compiled `vllm._C` extension (presumably what registers the
# `torch.ops._C` ops used below). Loading is skipped on TPU and HPU, and a
# failed import is only logged, not raised.
if not (current_platform.is_tpu() or current_platform.is_hpu()):
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

# `supports_moe_ops` becomes True only if the optional `vllm._moe_C`
# extension imports successfully; an ImportError leaves it False.
supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True
26

27
28
# Pick the decorator used below to register fake (meta) implementations of
# the compiled ops.
#
# Neuron has a torch version that doesn't even have impl_abstract, and type
# checkers should not depend on which name the installed torch provides; in
# both cases a no-op stand-in is used instead.
if TYPE_CHECKING or current_platform.is_neuron():

    def register_fake(name):
        """No-op stand-in for ``torch.library.register_fake``.

        Used as ``@register_fake("_C::op")``: calling it with the op name
        returns a decorator that hands the decorated function back
        unchanged, so the decorated name stays bound to the function (the
        previous version returned the op-name string instead).
        """
        return lambda fn: fn
else:
    try:
        # Newer torch exposes `register_fake`.
        from torch.library import register_fake
    except ImportError:
        # Older torch only has the previous name, `impl_abstract`.
        from torch.library import impl_abstract as register_fake

38

39
40
41
42
43
44
def hint_on_error(fn):
    """Wrap a custom-op caller so common failures carry a diagnostic hint.

    The returned wrapper calls ``fn`` with the same arguments and returns its
    result. If the call raises:

    * ``NotImplementedError``: logs a hint that the kernel is likely not
      supported on (or not built for) the current device, then raises a new
      ``NotImplementedError`` with that hint, chained to the original.
    * ``AttributeError``: logs a hint that the installed build may be
      stale, then re-raises the original exception unchanged.

    Any other exception propagates untouched.
    """
    # Resolved once up front so the error handlers cannot themselves fail
    # (e.g. with an AttributeError) on callables lacking `__name__`.
    op_name = getattr(fn, "__name__", repr(fn))

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)

        except NotImplementedError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Not implemented or built, most likely because the current "
                "device does not support this kernel (less likely "
                "TORCH_CUDA_ARCH_LIST was set incorrectly while building)")
            logger.error(msg, op_name, e)
            raise NotImplementedError(msg % (op_name, e)) from e
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of "
                "vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and "
                "build/ .")
            logger.error(msg, op_name, e)
            # Bare `raise` re-raises the original exception and traceback.
            raise

    return wrapper


67
68
# activation ops
def silu_and_mul(
    out: torch.Tensor,
    x: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.silu_and_mul`` op as ``(out, x)``.

    Returns ``None``; the result is presumably written into ``out`` by the
    kernel (confirm expected shapes against the kernel).
    """
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)
70
71
72


def gelu_and_mul(
    out: torch.Tensor,
    x: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.gelu_and_mul`` op as ``(out, x)``.

    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)
74
75
76


def gelu_tanh_and_mul(
    out: torch.Tensor,
    x: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.gelu_tanh_and_mul`` op as ``(out, x)``.

    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)
78
79


80
81
82
83
84
85
def fatrelu_and_mul(
    out: torch.Tensor,
    x: torch.Tensor,
    threshold: float = 0.0,
) -> None:
    """Invoke the compiled ``_C.fatrelu_and_mul`` op.

    Called as ``(out, x, threshold)``; ``threshold`` defaults to ``0.0``.
    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.fatrelu_and_mul
    kernel(out, x, threshold)


86
def gelu_fast(
    out: torch.Tensor,
    x: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.gelu_fast`` op as ``(out, x)``.

    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)
88
89
90


def gelu_new(
    out: torch.Tensor,
    x: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.gelu_new`` op as ``(out, x)``.

    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.gelu_new
    kernel(out, x)
92
93


94
95
96
97
def gelu_quick(
    out: torch.Tensor,
    x: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.gelu_quick`` op as ``(out, x)``.

    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


98
99
100
101
102
103
104
105
106
# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``_C.paged_attention_v1`` op.

    All arguments are forwarded positionally in signature order; the
    block-sparse parameters and ``tp_rank`` keep their defaults unless
    given. Returns ``None``; the result presumably lands in ``out``.
    """
    attention_args = (out, query, key_cache, value_cache, num_kv_heads,
                      scale, block_tables, seq_lens, block_size,
                      max_seq_len, alibi_slopes)
    cache_quant_args = (kv_cache_dtype, k_scale, v_scale)
    blocksparse_args = (tp_rank, blocksparse_local_blocks,
                        blocksparse_vert_stride, blocksparse_block_size,
                        blocksparse_head_sliding_step)
    torch.ops._C.paged_attention_v1(*attention_args, *cache_quant_args,
                                    *blocksparse_args)
126
127
128
129
130
131
132
133
134
135
136
137
138


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``_C.paged_attention_v2`` op.

    Same arguments as ``paged_attention_v1`` plus ``exp_sum``,
    ``max_logits`` and ``tmp_out``, which are passed right after ``out``
    (presumably scratch buffers for a partitioned reduction; confirm
    against the kernel). Returns ``None``.
    """
    output_buffers = (out, exp_sum, max_logits, tmp_out)
    attention_args = (query, key_cache, value_cache, num_kv_heads, scale,
                      block_tables, seq_lens, block_size, max_seq_len,
                      alibi_slopes)
    cache_quant_args = (kv_cache_dtype, k_scale, v_scale)
    blocksparse_args = (tp_rank, blocksparse_local_blocks,
                        blocksparse_vert_stride, blocksparse_block_size,
                        blocksparse_head_sliding_step)
    torch.ops._C.paged_attention_v2(*output_buffers, *attention_args,
                                    *cache_quant_args, *blocksparse_args)
158
159


160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke the ROCm op ``torch.ops._rocm_C.paged_attention``.

    Takes the ``paged_attention_v2`` argument list without ``tp_rank`` or
    any block-sparse parameters. Returns ``None``.
    """
    output_buffers = (out, exp_sum, max_logits, tmp_out)
    attention_args = (query, key_cache, value_cache, num_kv_heads, scale,
                      block_tables, seq_lens, block_size, max_seq_len,
                      alibi_slopes)
    cache_quant_args = (kv_cache_dtype, k_scale, v_scale)
    torch.ops._rocm_C.paged_attention(*output_buffers, *attention_args,
                                      *cache_quant_args)
184
185


186
187
188
189
190
191
192
193
194
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Invoke the compiled ``_C.rotary_embedding`` op.

    Arguments are forwarded positionally in signature order. Returns
    ``None``; ``query``/``key`` are presumably updated in place (confirm
    against the kernel).
    """
    rope_args = (positions, query, key, head_size, cos_sin_cache, is_neox)
    torch.ops._C.rotary_embedding(*rope_args)
197
198
199
200
201
202
203


def batched_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    rot_dim: int,
    cos_sin_cache_offsets: torch.Tensor,
) -> None:
    """Invoke the compiled ``_C.batched_rotary_embedding`` op.

    Like ``rotary_embedding`` but additionally forwards ``rot_dim`` and
    ``cos_sin_cache_offsets``. Returns ``None``.
    """
    rope_args = (positions, query, key, head_size, cos_sin_cache, is_neox)
    batch_args = (rot_dim, cos_sin_cache_offsets)
    torch.ops._C.batched_rotary_embedding(*rope_args, *batch_args)
207
208
209
210
211


# layer norm ops
def rms_norm(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Invoke the compiled ``_C.rms_norm`` op as
    ``(out, input, weight, epsilon)``.

    Returns ``None``; the result presumably lands in ``out``.
    """
    kernel = torch.ops._C.rms_norm
    kernel(out, input, weight, epsilon)
213
214
215
216


def fused_add_rms_norm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Invoke the compiled ``_C.fused_add_rms_norm`` op as
    ``(input, residual, weight, epsilon)``.

    Returns ``None``; ``input``/``residual`` are presumably updated in place
    (confirm against the kernel).
    """
    kernel = torch.ops._C.fused_add_rms_norm
    kernel(input, residual, weight, epsilon)
218
219


220
221
222
223
224
225
def advance_step_flashattn(
    num_seqs: int,
    num_queries: int,
    block_size: int,
    input_tokens: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    seq_lens: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_tables: torch.Tensor,
) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Thin wrapper over the compiled ``_C.advance_step_flashattn`` op: every
    argument is forwarded positionally and the op's return value is passed
    back to the caller.
    """
    sizes = (num_seqs, num_queries, block_size)
    step_buffers = (input_tokens, sampled_token_ids, input_positions,
                    seq_lens, slot_mapping, block_tables)
    return torch.ops._C.advance_step_flashattn(*sizes, *step_buffers)


def advance_step_flashinfer(
    num_seqs: int,
    num_queries: int,
    block_size: int,
    input_tokens: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    seq_lens: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_tables: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    block_table_bound: torch.Tensor,
) -> None:
    """FlashInfer counterpart of ``advance_step_flashattn``.

    Forwards the same arguments plus the paged-KV index tensors and
    ``block_table_bound`` to the compiled ``_C.advance_step_flashinfer`` op
    and returns the op's return value.
    """
    sizes = (num_seqs, num_queries, block_size)
    step_buffers = (input_tokens, sampled_token_ids, input_positions,
                    seq_lens, slot_mapping, block_tables)
    paged_kv_buffers = (paged_kv_indices, paged_kv_indptr,
                        paged_kv_last_page_len, block_table_bound)
    return torch.ops._C.advance_step_flashinfer(*sizes, *step_buffers,
                                                *paged_kv_buffers)
250
251


252
253
254
255
256
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Dequantize AWQ weights.

    By default this calls the compiled ``_C.awq_dequantize`` op. When
    ``envs.VLLM_USE_TRITON_AWQ`` is set, the Triton implementation is used
    instead; that path only receives ``qweight``, ``scales`` and ``zeros``,
    so ``split_k_iters``, ``thx`` and ``thy`` are ignored there.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_dequantize(qweight, scales, zeros,
                                           split_k_iters, thx, thy)

    # Imported lazily so the Triton module is only loaded when requested.
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_dequantize_triton)
    return awq_dequantize_triton(qweight, scales, zeros)
263
264
265
266


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """GEMM of ``input`` against AWQ-quantized weights.

    Uses the compiled ``_C.awq_gemm`` op unless ``envs.VLLM_USE_TRITON_AWQ``
    is set, in which case the Triton implementation receives the same five
    arguments.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_gemm(input, qweight, qzeros, scales,
                                     split_k_iters)

    # Imported lazily so the Triton module is only loaded when requested.
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_gemm_triton)
    return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
272
273
274
275
276
277
278


# gptq
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_exllama: bool,
    bit: int,
) -> torch.Tensor:
    """GEMM of ``a`` against GPTQ-quantized weights via ``_C.gptq_gemm``.

    The op's result is returned unchanged.
    """
    weight_args = (b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx)
    return torch.ops._C.gptq_gemm(a, *weight_args, use_exllama, bit)
281
282


283
if hasattr(torch.ops._C, "gptq_gemm"):

    @register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_gptq_qzeros: torch.Tensor,
                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                        use_exllama: bool, bit: int) -> torch.Tensor:
        """Fake (meta) implementation of ``_C::gptq_gemm``.

        Returns an uninitialized ``(a.size(0), b_q_weight.size(1))`` tensor
        with ``a``'s dtype and device.
        """
        rows = a.size(0)
        cols = b_q_weight.size(1)
        return torch.empty((rows, cols), dtype=a.dtype, device=a.device)


295
296
def gptq_shuffle(
    q_weight: torch.Tensor,
    q_perm: torch.Tensor,
    bit: int,
) -> None:
    """Invoke the compiled ``_C.gptq_shuffle`` op.

    Called as ``(q_weight, q_perm, bit)``; returns ``None`` (``q_weight`` is
    presumably rearranged in place; confirm against the kernel).
    """
    kernel = torch.ops._C.gptq_shuffle
    kernel(q_weight, q_perm, bit)
298
299
300
301
302
303


# marlin
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Marlin GEMM via the compiled ``_C.marlin_gemm`` op.

    The problem sizes are forwarded as given and the op's result is
    returned unchanged.
    """
    problem_size = (size_m, size_n, size_k)
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace,
                                    *problem_size)
306
307


308
309
310
# marlin_24
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Marlin-24 GEMM via the compiled ``_C.gptq_marlin_24_gemm`` op.

    ``b_q_type`` is passed to the op as its integer ``.id``; the op's
    result is returned unchanged.
    """
    return torch.ops._C.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace,
        b_q_type.id,
        size_m, size_n, size_k,
    )
316
317


318
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
    # Fake (meta) implementations of compiled `_C` ops. Each one only
    # allocates an uninitialized tensor with the output shape, dtype and
    # device of the real op; none of them calls into `torch.ops._C`.
    #
    # NOTE(review): only `gptq_marlin_24_gemm` is probed above, yet fakes for
    # many other `_C` ops are registered in this branch; this presumably
    # assumes all of them are built together. Confirm.
    #
    # The Python wrappers pass `ScalarType` arguments to the ops as their
    # integer `.id` (e.g. `b_q_type.id`), so that is what these fakes
    # receive; they are annotated as `int` accordingly.

    @register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor, b_q_type: int,
                                  size_m: torch.SymInt, size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        # (size_m, size_n) result in the activation dtype.
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: int,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        # (size_m, size_n) result in the activation dtype.
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
                              m: torch.SymInt,
                              n: torch.SymInt) -> torch.Tensor:
        # Dequantized (m, n) weights are reported as float16.
        return torch.empty((m, n), dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        # Single float16 output row of length `row`.
        return torch.empty((1, row), dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        # One float16 output row of length `row` per row of `X`.
        batch = X.size(0)
        return torch.empty((batch, row), dtype=torch.float16, device=W.device)

    @register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        # Output is always float16, independent of `a.dtype`.
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        # Output is always float16, independent of `a.dtype`.
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: torch.SymInt,
                             thx: int, thy: int) -> torch.Tensor:
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        # Each packed column of `qweight` expands to 8 output columns
        # (presumably eight 4-bit values per packed element; confirm).
        out_c = qout_c * 8
        return torch.empty((in_c, out_c),
                           dtype=scales.dtype,
                           device=scales.device)

    @register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: torch.SymInt) -> torch.Tensor:
        num_in_feats = input.size(0)
        # Summing the per-split partials over dim 0 leaves a
        # (num_in_feats, qweight.size(1) * 8) result.
        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                           dtype=input.dtype,
                           device=input.device).sum(0)

    @register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: List[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        out_features = codes.size(0) * codebooks.size(2)
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)

        # Restore the leading dims of `input`; the last dim is inferred (-1)
        # and therefore becomes `out_features`.
        output_sizes = list(input.shape)
        output_sizes.pop()
        output_sizes.append(-1)
        return flat_output.reshape(tuple(output_sizes))

    @register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: List[int]) -> torch.Tensor:
        in_features = codes.size(1) * 8
        out_features = codes.size(0)
        return torch.empty((out_features, in_features),
                           dtype=codebooks.dtype,
                           device=codebooks.device)

    @register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        # (size_m, size_n) result in the activation dtype.
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: int,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        m = a.size(0)
        n = b_q.size(1)
        # `machete_mm` forwards an explicit `out_type` to the op. Honour it
        # so the fake's dtype matches the real output, and fall back to the
        # activation dtype when it is not given.
        dtype = a.dtype if out_type is None else out_type
        return torch.empty((m, n), device=a.device, dtype=dtype)

    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(
            b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: int,
            group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
        # Same shape/dtype/device as `b_q_weight`, in contiguous layout.
        return torch.empty_like(b_q_weight,
                                memory_format=torch.contiguous_format)
469
470


471
# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled ``_C`` extension whether its CUTLASS scaled-mm
    kernels support FP8 for ``cuda_device_capability``.

    The op's answer is returned unchanged.
    """
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)


476
477
478
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled matrix multiply of ``a`` and ``b`` producing ``out_dtype``.

    Preconditions (checked with ``assert``): both dimensions of ``b`` are
    multiples of 16, ``out_dtype`` is ``torch.bfloat16`` or
    ``torch.float16``, and ``bias`` (if given) has ``b.shape[1]`` entries in
    dim 0 and dtype ``out_dtype``.

    On ROCm the result of the Triton ``triton_scaled_mm`` implementation is
    returned. Elsewhere a fresh ``(a.shape[0], b.shape[1])`` tensor is
    allocated, passed to the compiled ``_C.cutlass_scaled_mm`` op, and
    returned.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    out_shape = (a.shape[0], b.shape[1])

    if not current_platform.is_rocm():
        out = torch.empty(out_shape, dtype=out_dtype, device=a.device)
        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
        return out

    # ROCm: use the Triton kernel, imported lazily by module path.
    triton_scaled_mm = importlib.import_module(
        "vllm.model_executor.layers.quantization.compressed_tensors."
        "triton_scaled_mm").triton_scaled_mm
    return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)


504
505
506
507
508
509
510
511
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Scaled matmul with zero-point (azp) adjustment via the compiled
    ``_C.cutlass_scaled_mm_azp`` op.

    :param azp_adj: In the per-tensor case, this should include the azp.
    Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.
    :returns: A freshly allocated ``(a.shape[0], b.shape[1])`` tensor of
        dtype ``out_dtype``, passed to the op as its output and returned.
    """
    # Preconditions: both dims of `b` are multiples of 16, the output is a
    # 16-bit float type, `bias` (if any) has `b.shape[1]` elements in
    # `out_dtype`, and `azp` (if any) has one element per row of `a`.
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)
    assert azp is None or azp.numel() == a.shape[0]

    out = torch.empty((a.shape[0], b.shape[1]),
                      dtype=out_dtype,
                      device=a.device)
    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out


532
533
534
# aqlm
def aqlm_gemm(
    input: torch.Tensor,
    codes: torch.Tensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    codebook_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """AQLM GEMM via the compiled ``_C.aqlm_gemm`` op.

    The op's result is returned unchanged.
    """
    codebook_args = (codes, codebooks, scales, codebook_partition_sizes)
    return torch.ops._C.aqlm_gemm(input, *codebook_args, bias)
539
540
541


def aqlm_dequant(
    codes: torch.Tensor,
    codebooks: torch.Tensor,
    codebook_partition_sizes: List[int],
) -> torch.Tensor:
    """Dequantize AQLM codes via the compiled ``_C.aqlm_dequant`` op.

    The op's result is returned unchanged.
    """
    kernel = torch.ops._C.aqlm_dequant
    return kernel(codes, codebooks, codebook_partition_sizes)
545
546


547
548
# gptq_marlin
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ weights via the compiled ``_C.gptq_marlin_repack`` op.

    The op's result is returned unchanged.
    """
    dims = (size_k, size_n)
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, *dims,
                                           num_bits)
553
554


555
556
557
558
559
560
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Repack AWQ weights via the compiled ``_C.awq_marlin_repack`` op.

    Unlike ``gptq_marlin_repack``, no permutation tensor is passed. The op's
    result is returned unchanged.
    """
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


561
562
563
564
565
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Apply ``_C.gptq_marlin_repack`` to every expert independently.

    ``b_q_weight`` and ``perm`` are indexed per expert along dim 0.
    ``size_k`` must be a multiple of 16 (asserted).

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with
        ``b_q_weight``'s dtype and device; slice ``e`` holds the repacked
        weights of expert ``e``.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    packed_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    repacked = torch.empty(packed_shape,
                           device=b_q_weight.device,
                           dtype=b_q_weight.dtype)
    for expert in range(expert_count):
        repacked[expert] = torch.ops._C.gptq_marlin_repack(
            b_q_weight[expert], perm[expert], size_k, size_n, num_bits)
    return repacked


575
576
577
578
579
580
581
582
583
584
585
586
587
588
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """Apply ``_C.awq_marlin_repack`` to every expert independently.

    ``b_q_weight`` is indexed per expert along dim 0; ``perm`` is accepted
    but not used. ``size_k`` must be a multiple of 16 (asserted).

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with
        ``b_q_weight``'s dtype and device; slice ``e`` holds the repacked
        weights of expert ``e``.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    packed_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    repacked = torch.empty(packed_shape,
                           device=b_q_weight.device,
                           dtype=b_q_weight.dtype)
    for expert in range(expert_count):
        repacked[expert] = torch.ops._C.awq_marlin_repack(
            b_q_weight[expert], size_k, size_n, num_bits)
    return repacked


589
590
591
592
593
594
595
596
597
598
599
600
601
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False,
                     is_zp_float: bool = False) -> torch.Tensor:
    """GPTQ Marlin GEMM via the compiled ``_C.gptq_marlin_gemm`` op.

    ``b_q_type`` is passed to the op as its integer ``.id``. Every other
    argument is forwarded positionally, and the op's result is returned.
    """
    problem_size = (size_m, size_n, size_k)
    flags = (is_k_full, has_zp, use_fp32_reduce, is_zp_float)
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type.id,
                                         *problem_size, *flags)
608
609


610
611
612
613
614
615
616
617
618
# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """FP8 Marlin GEMM via the compiled ``_C.fp8_marlin_gemm`` op.

    The op's result is returned unchanged.
    """
    problem_size = (size_m, size_n, size_k)
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, *problem_size)


619
# machete
def machete_supported_schedules(
        a_type: torch.dtype,
        b_type: ScalarType,
        group_scales_type: Optional[torch.dtype],
        group_zeros_type: Optional[torch.dtype] = None,
        channel_scales_type: Optional[torch.dtype] = None,
        token_scales_type: Optional[torch.dtype] = None,
        out_type: Optional[torch.dtype] = None) -> List[str]:
    """Query ``_C.machete_supported_schedules`` for a type combination.

    ``b_type`` is passed as its integer ``.id``. The returned names are
    presumably valid values for ``machete_mm``'s ``schedule`` argument
    (confirm).
    """
    optional_types = (group_zeros_type, channel_scales_type,
                      token_scales_type, out_type)
    return torch.ops._C.machete_supported_schedules(
        a_type, b_type.id, group_scales_type, *optional_types)


def machete_mm(
        a: torch.Tensor,
        # b_q Should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None) -> torch.Tensor:
    """Machete GEMM of ``a`` with the prepacked weight ``b_q``.

    ``b_type`` is passed to ``_C.machete_mm`` as its integer ``.id``; all
    optional arguments are forwarded as given (``None`` included), and the
    op's result is returned.
    """
    group_args = (b_group_scales, b_group_zeros, b_group_size)
    return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, *group_args,
                                   b_channel_scales, a_token_scales, schedule)


def machete_prepack_B(
        b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
        group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
    """Prepack quantized weights for use as ``machete_mm``'s ``b_q``.

    ``b_type`` is passed to ``_C.machete_prepack_B`` as its integer ``.id``.
    """
    kernel = torch.ops._C.machete_prepack_B
    return kernel(b_q_weight, a_type, b_type.id, group_scales_type)
655
656


657
if hasattr(torch.ops._C, "permute_cols"):

    @register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        """Fake (meta) implementation of ``_C::permute_cols``.

        The output matches ``a`` in shape, dtype and device.
        """
        return torch.empty_like(a)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Invoke the compiled ``_C.permute_cols`` op on ``a`` with ``perm``.

    Presumably reorders the columns of ``a`` by ``perm`` (confirm against
    the kernel); the op's result is returned unchanged.
    """
    kernel = torch.ops._C.permute_cols
    return kernel(a, perm)


669
# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D
            (batch and token dimensions already flattened).
        scale: Optional scaling factor for the FP8 quantization. When
            given, static quantization is used.
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value. With a static scale
            this is only allowed when the scale has a single element.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
    out_dtype: torch.dtype = (torch.float8_e4m3fnuz
                              if current_platform.is_rocm() else
                              torch.float8_e4m3fn)
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    # NOTE(review): allocated uninitialized; padded rows beyond
    # input.shape[0] are presumably left untouched by the kernels. Confirm.
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            # One float32 scale per (possibly padded) output row.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            # Single per-tensor scale, zero-initialized before the op.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Padding is only supported together with a single-element
        # (per-tensor) static scale.
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
726
727


728
# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor, the
    scale, and (for asymmetric quantization) the zero-point.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When provided, static quantization with this scale is used.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for static int8 quantization. When
            ``scale`` is provided, ``azp`` must be given if and only if
            ``symmetric`` is False (asserted). Not used on the dynamic
            path, which allocates its own zero-point tensor.
        symmetric: Whether to use symmetric quantization (scale only, no
            zero-point).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: the int8
        output (same shape as ``input``), the scale(s), and the zero-point
        (``None`` for symmetric quantization).
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # Static quantization with the caller-provided scale (and azp).
        assert symmetric == (
            azp is
            None), "azp must only be provided for asymmetric quantization."
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization: one float32 scale (and, if asymmetric,
    # one int32 zero-point) per row, where all dimensions of `input` except
    # the last are flattened into rows.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    input_azp = None if symmetric else torch.empty_like(input_scales,
                                                        dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
                                           input_azp)
    return output, input_scales, input_azp
767
768


769
770
771
772
773
774
775
776
777
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Dispatch to the ``_C::marlin_qqq_gemm`` custom op.

    Every argument is forwarded positionally and unchanged, and the op's
    result is returned as-is.
    """
    qqq_gemm_op = torch.ops._C.marlin_qqq_gemm
    return qqq_gemm_op(a, b_q_weight, s_tok, s_ch, s_group, workspace,
                       size_m, size_n, size_k)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    """Dequantize ``W`` with the ``_C::ggml_dequantize`` custom op.

    ``quant_type``, ``m`` and ``n`` are passed through unchanged. They are
    presumably the GGML quantization type id and the output dimensions;
    confirm against the op schema.
    """
    dequant_op = torch.ops._C.ggml_dequantize
    return dequant_op(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``_C::ggml_mul_mat_vec_a8`` custom op on ``W`` and ``X``.

    Arguments are forwarded in order, and the op's output tensor is
    returned unchanged.
    """
    mat_vec_op = torch.ops._C.ggml_mul_mat_vec_a8
    return mat_vec_op(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Call the ``_C::ggml_mul_mat_a8`` custom op on ``W`` and ``X``.

    Arguments are forwarded in order, and the op's output tensor is
    returned unchanged.
    """
    mat_mat_op = torch.ops._C.ggml_mul_mat_a8
    return mat_mat_op(W, X, quant_type, row)


# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      conv_states: Optional[torch.Tensor],
                      query_start_loc: Optional[torch.Tensor],
                      cache_indices: Optional[torch.Tensor],
                      has_initial_state: Optional[torch.Tensor],
                      silu_activation: bool, pad_slot_id: int):
    """Run ``_C::causal_conv1d_fwd`` with the given arguments.

    Nothing is returned. Results presumably land in the op's tensor
    arguments (e.g. ``x`` / ``conv_states``); TODO confirm which ones the op
    mutates.
    """
    conv_fwd_op = torch.ops._C.causal_conv1d_fwd
    conv_fwd_op(x, weight, bias_, conv_states, query_start_loc,
                cache_indices, has_initial_state, silu_activation,
                pad_slot_id)


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool,
                         cache_seqlens: Optional[torch.Tensor],
                         conv_state_indices: Optional[torch.Tensor],
                         pad_slot_id: int):
    """Run ``_C::causal_conv1d_update`` with the given arguments.

    Nothing is returned. The op presumably updates ``x`` and/or
    ``conv_state`` in place; TODO confirm against the kernel.
    """
    conv_update_op = torch.ops._C.causal_conv1d_update
    conv_update_op(x, conv_state, weight, bias_, silu_activation,
                   cache_seqlens, conv_state_indices, pad_slot_id)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool,
                       query_start_loc: Optional[torch.Tensor],
                       cache_indices: Optional[torch.Tensor],
                       has_initial_state: Optional[torch.Tensor],
                       ssm_states: torch.Tensor, pad_slot_id: int):
    """Run the ``_C::selective_scan_fwd`` custom op.

    All arguments are forwarded positionally in declaration order, and
    nothing is returned. Outputs presumably go into the op's tensor
    arguments (e.g. ``ssm_states``); TODO confirm.
    """
    scan_op = torch.ops._C.selective_scan_fwd
    scan_op(u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus,
            query_start_loc, cache_indices, has_initial_state, ssm_states,
            pad_slot_id)


# moe
def moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Invoke ``_moe_C::moe_sum`` on ``input`` and ``output``.

    Nothing is returned. The result presumably lands in ``output``;
    confirm against the kernel.
    """
    moe_ops = torch.ops._moe_C
    moe_ops.moe_sum(input, output)


def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Invoke ``_moe_C::moe_align_block_size``.

    ``sorted_token_ids``, ``experts_ids`` and ``num_tokens_post_pad`` are
    presumably output buffers filled by the op (nothing is returned here);
    TODO confirm.
    """
    moe_ops = torch.ops._moe_C
    moe_ops.moe_align_block_size(topk_ids, num_experts, block_size,
                                 sorted_token_ids, experts_ids,
                                 num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Invoke the ``_moe_C::topk_softmax`` custom op.

    Args:
        topk_weights: Tensor handed to the op, presumably filled with the
            selected experts' weights (TODO confirm).
        topk_ids: Tensor handed to the op, presumably filled with the
            selected expert ids (TODO confirm).
        token_expert_indicies: Tensor handed to the op. The misspelled name
            is kept for keyword-argument compatibility.
        gating_output: The router logits tensor. It was previously annotated
            ``float``, but the op takes a tensor here.

    Returns:
        None. The op's results are written into its tensor arguments.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

    @register_fake("_moe_C::marlin_gemm_moe")
    def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                             sorted_ids: torch.Tensor,
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor, b_scales: torch.Tensor,
                             b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                             perm: torch.Tensor, workspace: torch.Tensor,
                             b_q_type: ScalarType, size_m: torch.SymInt,
                             size_n: torch.SymInt, size_k: torch.SymInt,
                             is_k_full: bool, num_experts: int, topk: int,
                             moe_block_size: int, replicate_input: bool,
                             apply_weights: bool) -> torch.Tensor:
        """Fake implementation registered for ``_moe_C::marlin_gemm_moe``.

        This is only defined when the ``_moe_C`` extension imported and
        exposes ``marlin_gemm_moe``. It does no computation. It returns an
        uninitialized tensor of shape ``(size_m, topk, size_n)`` with
        ``a``'s dtype and device, which describes the op's output.
        """
        out_shape = (size_m, topk, size_n)
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke ``_C_cache_ops::reshape_and_cache``.

    All arguments are forwarded in declaration order. ``key`` and ``value``
    are presumably written into ``key_cache`` and ``value_cache`` at the
    slots given by ``slot_mapping`` (TODO confirm against the kernel).
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale,
                                v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Invoke ``_C_cache_ops::reshape_and_cache_flash``.

    This has the same argument list as ``reshape_and_cache`` but dispatches
    to the ``_flash`` variant of the op. Presumably that variant targets the
    flash-attention cache layout; TODO confirm.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, kv_cache_dtype, k_scale,
                                      v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Invoke ``_C_cache_ops::copy_blocks`` on the key/value cache lists.

    ``block_mapping`` is presumably the (src, dst) block-id pairs to copy;
    TODO confirm.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Invoke ``_C_cache_ops::swap_blocks`` from ``src`` to ``dst``.

    ``block_mapping`` is forwarded unchanged, and nothing is returned.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Invoke ``_C_cache_ops::convert_fp8``.

    Args:
        output: Destination tensor handed to the op.
        input: Source tensor handed to the op.
        scale: Scale forwarded to the op (default 1.0).
        kv_dtype: KV-cache dtype string forwarded to the op (default "fp8").
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Query ``attribute`` of ``device`` via ``_C_cuda_utils``."""
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Query the max-shared-memory-per-block attribute of ``device``.

    The value comes from ``_C_cuda_utils`` and is returned unchanged.
    """
    # NOTE: this is ruff's file-level exemption syntax, so it disables E501
    # for the whole module, not just this function. Keep it.
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
                   rank: int, full_nvlink: bool) -> int:
    """Create a custom all-reduce instance via ``_C_custom_ar``.

    Returns the int produced by the op. This is presumably the handle
    (``fa``) accepted by the other custom-ar wrappers; confirm.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
                                    full_nvlink)


def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
               reg_buffer_sz_bytes: int) -> None:
    """Run ``_C_custom_ar::all_reduce`` on ``inp`` into ``out``.

    ``reg_buffer`` and ``reg_buffer_sz_bytes`` are forwarded unchanged. They
    are presumably a registered buffer address and its size in bytes.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)


def dispose(fa: int) -> None:
    """Forward the custom all-reduce handle ``fa`` to ``_C_custom_ar::dispose``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.dispose(fa)


def meta_size() -> int:
    """Return the int reported by ``_C_custom_ar::meta_size``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.meta_size()


def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
    """Forward ``fa`` and ``ipc_tensors`` to ``_C_custom_ar::register_buffer``.

    Despite its name, ``ipc_tensors`` is annotated ``List[int]``. It
    presumably holds integer pointers or handles; TODO confirm against the
    op schema. The op's return value is passed through.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.register_buffer(fa, ipc_tensors)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
    """Return ``_C_custom_ar::get_graph_buffer_ipc_meta(fa)`` unchanged.

    The pair is presumably (handles, offsets), in the form later passed to
    ``register_graph_buffers``; confirm.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[List[int]],
                           offsets: List[List[int]]) -> None:
    """Forward ``handles`` and ``offsets`` to ``_C_custom_ar::register_graph_buffers``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
def _wrap_ops_with_error_hints() -> None:
    """Wrap this module's tensor-taking functions with ``hint_on_error``.

    A module-level function is wrapped when both of these hold:

    * it was defined in this file (``__code__.co_filename == __file__``);
    * at least one annotation is ``torch.Tensor``. That means either the
      class itself or the string ``"torch.Tensor"``, which is what the
      annotation becomes under ``from __future__ import annotations``.

    Only exact matches count. A function annotated solely with e.g.
    ``Optional[torch.Tensor]`` is left untouched.

    The work runs inside a function so that no loop variables leak into the
    module namespace.
    """
    from types import FunctionType

    module_globals = globals()
    wrapped = {}
    # Iterate over a snapshot so the later update cannot disturb iteration.
    for name, obj in list(module_globals.items()):
        if not isinstance(obj, FunctionType):
            continue
        if obj.__code__.co_filename != __file__:
            continue
        if any(ann is torch.Tensor or ann == "torch.Tensor"
               for ann in obj.__annotations__.values()):
            wrapped[name] = hint_on_error(obj)
    module_globals.update(wrapped)


_wrap_ops_with_error_hints()
del _wrap_ops_with_error_hints