# _custom_ops.py (24.5 KB)
1
import contextlib
2
import functools
3
from typing import List, Optional, Tuple, Union
4
5
6

import torch

7
import vllm.envs as envs
8
from vllm._core_ext import ScalarType
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
12
13

logger = init_logger(__name__)

14
15
16
17
18
# The compiled extensions are imported for their side effects: the ops used
# below are looked up as ``torch.ops._C.*`` / ``torch.ops._moe_C.*``
# (presumably registered by these imports — confirm against the build).
# ``vllm._C`` is not attempted on TPU, and a failed import is only logged.
if not current_platform.is_tpu():
    try:
        import vllm._C  # noqa: F401
    except ImportError as exc:
        logger.warning("Failed to import from vllm._C with %r", exc)

# The MoE extension is optional: a missing build is silently ignored.
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
22

23

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def hint_on_error(fn):
    """Wrap ``fn`` so an ``AttributeError`` is logged with a rebuild hint.

    The wrapper calls ``fn`` with the same arguments and returns its result.
    If ``fn`` raises ``AttributeError``, an error explaining that the
    installed/built vllm may be obsolete is logged, and then the original
    exception is re-raised unchanged. A missing ``torch.ops.<ns>.<op>``
    lookup is presumably the case this targets. Other exceptions propagate
    untouched.

    Args:
        fn: The function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper around ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            # Adjacent string literals are concatenated verbatim, so the
            # trailing space after "vllm," is required between sentences.
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare ``raise`` re-raises the active exception with its
            # original traceback.
            raise

    return wrapper


43
44
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the compiled ``_C.silu_and_mul`` op (returns None).

    ``out`` is forwarded as the op's first argument; presumably it receives
    the result in place.
    """
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)
46
47
48


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the compiled ``_C.gelu_and_mul`` op (returns None).

    ``out`` is forwarded as the op's first argument; presumably it receives
    the result in place.
    """
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)
50
51
52


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the compiled ``_C.gelu_tanh_and_mul`` op (returns None).

    ``out`` is forwarded as the op's first argument; presumably it receives
    the result in place.
    """
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)
54
55
56


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the compiled ``_C.gelu_fast`` op (returns None).

    ``out`` is forwarded as the op's first argument; presumably it receives
    the result in place.
    """
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)
58
59
60


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the compiled ``_C.gelu_new`` op (returns None).

    ``out`` is forwarded as the op's first argument; presumably it receives
    the result in place.
    """
    kernel = torch.ops._C.gelu_new
    kernel(out, x)
62
63


64
65
66
67
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Dispatch to the compiled ``_C.gelu_quick`` op (returns None).

    ``out`` is forwarded as the op's first argument; presumably it receives
    the result in place.
    """
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


68
69
70
71
72
73
74
75
76
# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Run the compiled ``_C.paged_attention_v1`` kernel (returns None).

    Every argument is forwarded positionally, in signature order. The
    ``tp_rank`` and ``blocksparse_*`` knobs default to the values shown in the
    signature. ``out`` is presumably written in place.
    """
    op_args = (out, query, key_cache, value_cache, num_kv_heads, scale,
               block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
               kv_cache_dtype, k_scale, v_scale, tp_rank,
               blocksparse_local_blocks, blocksparse_vert_stride,
               blocksparse_block_size, blocksparse_head_sliding_step)
    torch.ops._C.paged_attention_v1(*op_args)
96
97
98
99
100
101
102
103
104
105
106
107
108


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Run the compiled ``_C.paged_attention_v2`` kernel (returns None).

    Same forwarding as v1, plus three extra leading buffers: ``exp_sum``,
    ``max_logits`` and ``tmp_out``. Their role is presumably scratch space
    for a partitioned reduction, but that is not visible here. Every argument
    is passed positionally in signature order.
    """
    op_args = (out, exp_sum, max_logits, tmp_out, query, key_cache,
               value_cache, num_kv_heads, scale, block_tables, seq_lens,
               block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
               k_scale, v_scale, tp_rank, blocksparse_local_blocks,
               blocksparse_vert_stride, blocksparse_block_size,
               blocksparse_head_sliding_step)
    torch.ops._C.paged_attention_v2(*op_args)
128
129
130
131
132
133
134
135
136
137
138


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Apply the compiled ``_C.rotary_embedding`` op (returns None).

    ``query``/``key`` are presumably rotated in place. ``is_neox`` selects the
    kernel's NeoX-style variant.
    """
    rope = torch.ops._C.rotary_embedding
    rope(positions, query, key, head_size, cos_sin_cache, is_neox)
141
142
143
144
145
146
147


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Apply the compiled ``_C.batched_rotary_embedding`` op (returns None).

    Same as ``rotary_embedding`` plus ``rot_dim`` and per-entry
    ``cos_sin_cache_offsets``. All arguments are forwarded in order.
    """
    rope = torch.ops._C.batched_rotary_embedding
    rope(positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim,
         cos_sin_cache_offsets)
151
152
153
154
155


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Dispatch to the compiled ``_C.rms_norm`` op (returns None).

    ``out`` is forwarded first and presumably receives the normalized result.
    """
    norm_kernel = torch.ops._C.rms_norm
    norm_kernel(out, input, weight, epsilon)
157
158
159
160


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Dispatch to the compiled ``_C.fused_add_rms_norm`` op (returns None).

    There is no output argument, so ``input``/``residual`` are presumably
    updated in place by the kernel.
    """
    norm_kernel = torch.ops._C.fused_add_rms_norm
    norm_kernel(input, residual, weight, epsilon)
162
163


164
165
166
167
168
169
170
171
172
173
174
175
def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance the existing inputs of a multi-step runner by one step on GPU.

    Thin wrapper: every argument is forwarded in order to
    ``torch.ops._C.advance_step``, and whatever the op returns is returned.
    """
    step_kernel = torch.ops._C.advance_step
    return step_kernel(num_seqs, num_queries, block_size, input_tokens,
                       sampled_token_ids, input_positions, seq_lens,
                       slot_mapping, block_tables)


176
177
178
179
180
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Dequantize AWQ weights with either the compiled op or Triton.

    If ``envs.VLLM_USE_TRITON_AWQ`` is falsy, every argument goes to
    ``torch.ops._C.awq_dequantize``. Otherwise the Triton implementation is
    imported lazily and called with ``qweight``, ``scales`` and ``zeros``
    only; ``split_k_iters``, ``thx`` and ``thy`` are unused on that path.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_dequantize(qweight, scales, zeros,
                                           split_k_iters, thx, thy)
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_dequantize_triton)
    return awq_dequantize_triton(qweight, scales, zeros)
187
188
189
190


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """AWQ GEMM through either the compiled op or the Triton kernel.

    ``envs.VLLM_USE_TRITON_AWQ`` picks the Triton path, which is imported
    lazily. Both paths receive the same five arguments in the same order.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_gemm(input, qweight, qzeros, scales,
                                     split_k_iters)
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_gemm_triton)
    return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
196
197
198
199
200
201
202


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Run ``torch.ops._C.gptq_gemm`` and return its result unchanged."""
    gemm = torch.ops._C.gptq_gemm
    result = gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                  use_exllama, bit)
    return result
205
206
207
208


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Run ``torch.ops._C.gptq_shuffle`` (returns None).

    ``q_weight`` is presumably permuted in place.
    """
    shuffle = torch.ops._C.gptq_shuffle
    shuffle(q_weight, q_perm, bit)
210
211
212
213
214
215


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Run ``torch.ops._C.marlin_gemm`` and return its result unchanged."""
    gemm = torch.ops._C.marlin_gemm
    result = gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)
    return result
218
219


220
221
222
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Run ``torch.ops._C.gptq_marlin_24_gemm`` and return its result.

    ``b_q_type`` is the vllm ``ScalarType`` describing the quantized weight.
    """
    gemm = torch.ops._C.gptq_marlin_24_gemm
    result = gemm(a, b_q_weight, b_meta, b_scales, workspace, b_q_type,
                  size_m, size_n, size_k)
    return result
228
229


230
# cutlass
231
232
233
234
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether FP8 scaled-mm is supported for
    ``cuda_device_capability``; the op's answer is returned as-is."""
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)


235
236
237
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled matmul ``a @ b`` via the compiled CUTLASS op.

    Preconditions (asserted):
      * both dimensions of ``b`` are multiples of 16;
      * ``out_dtype`` is ``torch.bfloat16`` or ``torch.float16``;
      * ``bias``, if given, has ``b.shape[1]`` leading entries and dtype
        ``out_dtype``.

    Returns a freshly allocated ``(a.shape[0], b.shape[1])`` tensor of
    ``out_dtype`` on ``a.device``, filled by the op.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    rows, cols = a.shape[0], b.shape[1]
    result = torch.empty((rows, cols), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm(result, a, b, scale_a, scale_b, bias)
    return result


255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """CUTLASS scaled matmul with ``azp_adj``/``azp`` forwarded to the op.

    Same preconditions as ``cutlass_scaled_mm``, except the bias check uses
    ``bias.numel() == b.shape[1]``. Returns a new ``(a.shape[0], b.shape[1])``
    tensor of ``out_dtype`` on ``a.device``, filled by the op.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)

    rows, cols = a.shape[0], b.shape[1]
    result = torch.empty((rows, cols), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm_azp(result, a, b, scale_a, scale_b,
                                       azp_adj, azp, bias)
    return result


277
278
279
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Run ``torch.ops._C.aqlm_gemm`` and return its result unchanged."""
    gemm = torch.ops._C.aqlm_gemm
    result = gemm(input, codes, codebooks, scales, codebook_partition_sizes,
                  bias)
    return result
284
285
286


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    """Run ``torch.ops._C.aqlm_dequant`` and return its result unchanged."""
    dequant = torch.ops._C.aqlm_dequant
    result = dequant(codes, codebooks, codebook_partition_sizes)
    return result
290
291


292
293
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Run ``torch.ops._C.gptq_marlin_repack`` and return the repacked
    tensor it produces."""
    repack = torch.ops._C.gptq_marlin_repack
    repacked = repack(b_q_weight, perm, size_k, size_n, num_bits)
    return repacked
298
299


300
301
302
303
304
305
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Run ``torch.ops._C.awq_marlin_repack`` and return the repacked tensor
    it produces. Unlike the GPTQ variant, no permutation is passed."""
    repack = torch.ops._C.awq_marlin_repack
    repacked = repack(b_q_weight, size_k, size_n, num_bits)
    return repacked


306
307
308
309
310
311
312
313
314
315
316
317
318
319
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Repack every expert's GPTQ weight with ``_C.gptq_marlin_repack``.

    ``b_q_weight`` and ``perm`` are indexed along dim 0, one slice per expert.
    ``size_k`` must be a multiple of 16 (asserted). The result has shape
    ``(num_experts, size_k // 16, size_n * 2)`` with ``b_q_weight``'s dtype
    and device. With zero experts it is returned without touching the op.
    """
    experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    packed = torch.empty((experts, size_k // 16, size_n * 2),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for idx in range(experts):
        # The op is looked up per iteration, so an empty batch never needs it.
        packed[idx] = torch.ops._C.gptq_marlin_repack(b_q_weight[idx],
                                                      perm[idx], size_k,
                                                      size_n, num_bits)
    return packed


320
321
322
323
324
325
326
327
328
329
330
331
332
333
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    """Run ``torch.ops._C.gptq_marlin_gemm`` and return its result.

    ``has_zp`` and ``use_fp32_reduce`` default to False. All arguments are
    forwarded positionally in signature order.
    """
    op_args = (a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace,
               b_q_type, size_m, size_n, size_k, is_k_full, has_zp,
               use_fp32_reduce)
    return torch.ops._C.gptq_marlin_gemm(*op_args)
338
339


340
341
342
343
344
345
346
347
348
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Run ``torch.ops._C.fp8_marlin_gemm`` and return its result unchanged."""
    gemm = torch.ops._C.fp8_marlin_gemm
    result = gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m,
                  size_n, size_k)
    return result


349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    """Return the schedule names the compiled machete op reports for
    ``b_type``."""
    query = torch.ops._C.machete_supported_schedules
    return query(b_type)


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    """Run ``torch.ops._C.machete_gemm`` and return its result.

    Every optional argument defaults to None and is forwarded as-is.
    ``b_q`` should come from ``machete_prepack_B``.
    """
    op_args = (a, b_q, b_type, b_scales, b_zeros, b_group_size, c, alpha,
               beta, schedule)
    return torch.ops._C.machete_gemm(*op_args)


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    """Prepack ``b_q_weight`` for ``machete_gemm`` via the compiled op and
    return the packed tensor."""
    prepack = torch.ops._C.machete_prepack_B
    return prepack(b_q_weight, b_type)


375
# fp8
376
377
378
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
379
    num_token_padding: Optional[int] = None,
380
    scale_ub: Optional[torch.Tensor] = None,
381
    use_per_token_if_dynamic: bool = False,
382
) -> Tuple[torch.Tensor, torch.Tensor]:
383
384
385
386
387
388
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
389
    optional padding of the output tensors for downstream kernels that
390
391
392
393
394
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
395
396
        scale_ub: Optional upper bound for scaling factor in dynamic 
            per token case
397
        num_token_padding: If specified, pad the first dimension
398
            of the output to at least this value.
399
400
        use_per_token_if_dynamic: Whether to do per_tensor or per_token 
            in the dynamic quantization case.
401
402
403
404
405

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
406
407
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
408
    shape: Union[Tuple[int, int], torch.Size] = input.shape
409
410
411
    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
        else torch.float8_e4m3fn
412
413
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
414
    output = torch.empty(shape, device=input.device, dtype=out_dtype)
415

416
    if scale is None:
417
        if use_per_token_if_dynamic:
418
            scale = torch.empty((shape[0], 1),
419
420
421
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
422
                output, input, scale, scale_ub)
423
424
425
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
426
    else:
427
428
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
429
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
430

431
    return output, scale
432
433


434
# int8
435
436
437
438
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``input`` to int8 and return ``(quantized, scales)``.

    Args:
        input: The tensor to quantize.
        scale: Optional static scale. If given, static per-tensor
            quantization is performed and this same tensor is returned.
            If omitted, dynamic per-token quantization computes new scales.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The int8 tensor (same shape as
        ``input``) and the scale tensor. On the dynamic path the scales are
        float32 with shape ``(input.numel() // input.shape[-1], 1)``.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is None:
        # dynamic-per-token quantization: one scale per row of the input
        # flattened to 2-D over its last dimension.
        num_rows = input.numel() // input.shape[-1]
        input_scales = torch.empty((num_rows, 1),
                                   device=input.device,
                                   dtype=torch.float32)
        torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
        return output, input_scales

    # static-per-tensor quantization.
    torch.ops._C.static_scaled_int8_quant(output, input, scale)
    return output, scale
462
463


464
465
466
467
468
469
470
471
472
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Run ``torch.ops._C.marlin_qqq_gemm`` and return its result unchanged."""
    gemm = torch.ops._C.marlin_qqq_gemm
    result = gemm(a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                  size_n, size_k)
    return result


473
# gguf
474
475
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    """Dequantize a GGUF/GGML weight via ``torch.ops._C.ggml_dequantize``
    and return the op's result."""
    dequant = torch.ops._C.ggml_dequantize
    return dequant(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Run ``torch.ops._C.ggml_mul_mat_vec_a8`` and return its result."""
    matvec = torch.ops._C.ggml_mul_mat_vec_a8
    return matvec(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Run ``torch.ops._C.ggml_mul_mat_a8`` and return its result."""
    matmul = torch.ops._C.ggml_mul_mat_a8
    return matmul(W, X, quant_type, row)


497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      seq_idx_: Optional[torch.Tensor],
                      initial_states_: Optional[torch.Tensor],
                      final_states_out_: Optional[torch.Tensor],
                      silu_activation: bool) -> torch.Tensor:
    """Run ``torch.ops._C.causal_conv1d_fwd`` (mamba) and return its result.

    Optional tensors (trailing underscore) are forwarded as-is, None included.
    """
    conv_fwd = torch.ops._C.causal_conv1d_fwd
    return conv_fwd(x, weight, bias_, seq_idx_, initial_states_,
                    final_states_out_, silu_activation)


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool) -> torch.Tensor:
    """Run ``torch.ops._C.causal_conv1d_update`` and return its result."""
    conv_update = torch.ops._C.causal_conv1d_update
    return conv_update(x, conv_state, weight, bias_, silu_activation)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool, index_: Optional[torch.Tensor],
                       x: Optional[torch.Tensor]) -> List[torch.Tensor]:
    """Run ``torch.ops._C.selective_scan_fwd`` and return its result
    (annotated as a list of tensors)."""
    op_args = (u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus,
               index_, x)
    return torch.ops._C.selective_scan_fwd(*op_args)


527
528
529
530
531
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Run ``torch.ops._C.moe_align_block_size`` (returns None).

    ``sorted_token_ids``, ``experts_ids`` and ``num_tokens_post_pad`` are
    presumably output buffers filled by the kernel.
    """
    align = torch.ops._C.moe_align_block_size
    align(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids,
          num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Run ``torch.ops._moe_C.topk_softmax`` (returns None).

    Args:
        topk_weights: Forwarded first; presumably receives the selected
            expert weights in place.
        topk_ids: Presumably receives the selected expert ids in place.
        token_expert_indicies: Forwarded unchanged to the op.
        gating_output: The router logits tensor. It was previously annotated
            ``float``, but it is passed to the tensor op alongside the other
            tensors.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)
542
543
544
545
546
547
548
549
550


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Run ``torch.ops._C_cache_ops.reshape_and_cache`` (returns None).

    ``key``/``value`` are presumably written into ``key_cache``/
    ``value_cache`` at the slots given by ``slot_mapping``.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
557
558


559
560
561
562
563
564
565
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Run ``torch.ops._C_cache_ops.reshape_and_cache_flash`` (returns None).

    Same arguments as ``reshape_and_cache``. The flash variant presumably
    targets a different cache layout.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, kv_cache_dtype, k_scale,
                                      v_scale)
573
574


575
576
def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Run ``torch.ops._C_cache_ops.copy_blocks`` over the given cache lists
    (returns None)."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
579
580
581


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Run ``torch.ops._C_cache_ops.swap_blocks`` from ``src`` to ``dst``
    according to ``block_mapping`` (returns None)."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.swap_blocks(src, dst, block_mapping)
584
585


586
587
588
589
def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Run ``torch.ops._C_cache_ops.convert_fp8`` (returns None).

    ``scale`` defaults to 1.0 and ``kv_dtype`` to ``"fp8"``. ``output`` is
    forwarded first and is presumably filled in place.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return the value the compiled ``_C_cuda_utils`` op reports for
    ``attribute`` on ``device``."""
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return what the compiled ``_C_cuda_utils`` op reports for ``device``.

    Per the op name, this is presumably the maximum shared memory per block.
    """
    # ruff: noqa: E501
    # NOTE: ruff treats ``# ruff: noqa: <codes>`` as a file-level exemption,
    # so E501 is disabled for this whole module, and other long lines rely
    # on it.
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Create a custom all-reduce context via the compiled op.

    Returns the integer the op produces; the other ``_C_custom_ar`` wrappers
    take it as ``fa``.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank,
                                    full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return the compiled op's verdict on whether custom all-reduce should
    be used for ``inp``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.should_custom_ar(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Run ``torch.ops._C_custom_ar.all_reduce_reg`` on context ``fa``
    (returns None)."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_reg(fa, inp, out)

620

621
622
623
def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Run ``torch.ops._C_custom_ar.all_reduce_unreg`` on context ``fa``,
    passing ``reg_buffer`` through (returns None)."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
624

625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647

def dispose(fa: int) -> None:
    """Release custom all-reduce context ``fa`` via the compiled op."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.dispose(fa)


def meta_size() -> int:
    """Return the metadata size reported by the compiled custom-AR op."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Register ``t`` with custom all-reduce context ``fa``.

    Whatever the compiled op returns is passed back unchanged; it is
    annotated None.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the ``(handles, offsets)`` pair the compiled op reports for
    context ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Register graph buffers with context ``fa`` via the compiled op
    (returns None)."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.register_graph_buffers(fa, handles, offsets)


648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Replace every function defined in this file that has ``torch.Tensor`` among
# its annotations (parameters or return) with its ``hint_on_error`` wrapper.
# The ``arg == "torch.Tensor"`` comparison covers modules using
# ``from __future__ import annotations``, which turns annotations into
# strings. Iterating over a snapshot of the globals means binding the loop
# variables cannot resize the dict mid-iteration. All temporaries are
# deleted afterwards, so nothing (including the old stray ``arg``) leaks into
# the module namespace.
_fn_type = type(lambda x: x)
_wrapped_by_name = {}
for _name, _value in list(globals().items()):
    if (isinstance(_value, _fn_type)
            and _value.__code__.co_filename == __file__
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in _value.__annotations__.values())):
        _wrapped_by_name[_name] = hint_on_error(_value)

globals().update(_wrapped_by_name)
del _fn_type, _wrapped_by_name, _name, _value