import contextlib
import functools
from typing import List, Optional, Tuple, Type

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

# lmslim supplies the `quant_ops` kernels used by the GPTQ/AWQ wrappers in
# this module. It is optional: if the import fails, `quant_ops` stays
# undefined and only those wrappers fail when called.
try:
    from lmslim import quant_ops
except Exception as e:
    # Broad catch on purpose: a broken lmslim install must not prevent this
    # module from importing. Report through the logger instead of stdout.
    logger.info(
        "lmslim is unavailable (%r). Please install lmslim if you want to "
        "infer gptq or awq model.", e)

try:
    import vllm._C
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

with contextlib.suppress(ImportError):
    import vllm._moe_C

with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import vllm._punica_C


def is_custom_op_supported(op_name: str) -> bool:
    """Report whether ``op_name`` resolves to a registered torch operator.

    Only the operator handle returned by ``torch._C._jit_get_operation`` is
    inspected; the accompanying overload information is ignored.
    """
    resolved_op, _unused_overloads = torch._C._jit_get_operation(op_name)
    return resolved_op is not None


def hint_on_error(fn):
    """Wrap ``fn`` so that an ``AttributeError`` is logged with a rebuild hint.

    Any ``AttributeError`` raised by ``fn`` is logged together with
    ``fn.__name__`` and a suggestion that the vllm build may be obsolete, and
    is then re-raised unchanged. Other exceptions and return values pass
    through untouched. ``functools.wraps`` keeps ``fn``'s metadata.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare `raise` re-raises the active exception with its original
            # traceback.
            raise

    return wrapper


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the compiled ``_C.silu_and_mul`` op on ``x`` with ``out``."""
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the compiled ``_C.gelu_and_mul`` op on ``x`` with ``out``."""
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the compiled ``_C.gelu_tanh_and_mul`` op on ``x`` with ``out``."""
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the compiled ``_C.gelu_fast`` op on ``x`` with ``out``."""
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the compiled ``_C.gelu_new`` op on ``x`` with ``out``."""
    kernel = torch.ops._C.gelu_new
    kernel(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the compiled ``_C.gelu_quick`` op on ``x`` with ``out``."""
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward all arguments, in declaration order, to ``_C.paged_attention_v1``.

    ``out`` is the caller-provided result tensor; this wrapper returns
    nothing. The ``blocksparse_*`` knobs default to 0 except
    ``blocksparse_block_size`` (64).
    """
    kernel_args = (out, query, key_cache, value_cache, num_kv_heads, scale,
                   block_tables, seq_lens, block_size, max_seq_len,
                   alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
                   blocksparse_local_blocks, blocksparse_vert_stride,
                   blocksparse_block_size, blocksparse_head_sliding_step)
    torch.ops._C.paged_attention_v1(*kernel_args)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward all arguments, in declaration order, to ``_C.paged_attention_v2``.

    Compared with v1 it additionally passes the ``exp_sum``, ``max_logits``
    and ``tmp_out`` tensors (presumably scratch buffers for a partitioned
    reduction — TODO confirm). Returns nothing.
    """
    kernel_args = (out, exp_sum, max_logits, tmp_out, query, key_cache,
                   value_cache, num_kv_heads, scale, block_tables, seq_lens,
                   block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                   k_scale, v_scale, tp_rank, blocksparse_local_blocks,
                   blocksparse_vert_stride, blocksparse_block_size,
                   blocksparse_head_sliding_step)
    torch.ops._C.paged_attention_v2(*kernel_args)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Run the ``_C.rotary_embedding`` op over ``query`` and ``key``.

    Nothing is returned, so the op presumably updates ``query``/``key`` in
    place — TODO confirm against the kernel.
    """
    rope = torch.ops._C.rotary_embedding
    rope(positions, query, key, head_size, cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Run the ``_C.batched_rotary_embedding`` op.

    Same argument order as :func:`rotary_embedding`, plus ``rot_dim`` and
    ``cos_sin_cache_offsets``. Returns nothing.
    """
    rope = torch.ops._C.batched_rotary_embedding
    rope(positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim,
         cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call ``_C.rms_norm`` with ``out`` as the destination tensor."""
    norm_op = torch.ops._C.rms_norm
    norm_op(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call ``_C.fused_add_rms_norm``; no separate output tensor is passed.

    Presumably ``input``/``residual`` are updated in place — TODO confirm.
    """
    norm_op = torch.ops._C.fused_add_rms_norm
    norm_op(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Every argument is forwarded, in order, to ``_C.advance_step`` and the
    op's return value is passed back to the caller.
    """
    step_op = torch.ops._C.advance_step
    return step_op(num_seqs, num_queries, block_size, input_tokens,
                   sampled_token_ids, input_positions, seq_lens,
                   slot_mapping, block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the result of ``_C.awq_dequantize`` for the given AWQ weights."""
    dequant_op = torch.ops._C.awq_dequantize
    return dequant_op(qweight, scales, zeros, split_k_iters, thx, thy)


# NOTE: this replaces the earlier five-argument form
#   awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# with lmslim's kernel. That kernel takes packed zeros/scales and explicit
# GEMM sizes.
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
             zeros_and_scales: torch.Tensor,
             m: int, n: int, k: int,
             group_size: int, padding_group: int, splikspace: torch.Tensor,
             splikspacesize: int) -> torch.Tensor:
    """Return ``lmslim.quant_ops.awq_gemm`` applied to the given operands.

    ``splikspace``/``splikspacesize`` are presumably a scratch workspace
    tensor and its size — TODO confirm against lmslim.
    """
    gemm = quant_ops.awq_gemm
    return gemm(input, weight, zeros_and_scales, m, n, k, group_size,
                padding_group, splikspace, splikspacesize)

def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
               group_size: int):
    """Return whatever ``lmslim.quant_ops.convert_s4`` produces for these inputs."""
    converter = quant_ops.convert_s4
    return converter(qw, qz, s, group_size)

def sz_permute(sz: torch.Tensor) -> torch.Tensor:
    """Return ``lmslim.quant_ops.sz_permute(sz)``."""
    permute = quant_ops.sz_permute
    return permute(sz)

def dequant_w4_gemm_colmajor(qweight: torch.Tensor,
                             zeros_and_scale: torch.Tensor,
                             k: int,
                             n: int,
                             group_size: int) -> torch.Tensor:
    """Return ``lmslim.quant_ops.dequant_w4_gemm_colmajor`` for these inputs."""
    dequant = quant_ops.dequant_w4_gemm_colmajor
    return dequant(qweight, zeros_and_scale, k, n, group_size)

# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return ``lmslim.quant_ops.gptq_gemm`` for the given GPTQ operands."""
    # Routed through lmslim rather than ``torch.ops._C.gptq_gemm``.
    gemm = quant_ops.gptq_gemm
    return gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call ``lmslim.quant_ops.gptq_shuffle``; its result is discarded."""
    # Routed through lmslim rather than ``torch.ops._C.gptq_shuffle``.
    shuffle = quant_ops.gptq_shuffle
    shuffle(q_weight, q_perm, bit)

# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
                   row: int, col: int) -> None:
    """Call ``_C.trans_w16_gemm`` with ``dst`` as the destination tensor."""
    kernel = torch.ops._C.trans_w16_gemm
    kernel(dst, src, row, col)

# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Call ``_C.squeezellm_gemm``; ``mul`` is passed where a result is expected."""
    kernel = torch.ops._C.squeezellm_gemm
    kernel(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the output of ``_C.marlin_gemm`` for the given operands."""
    kernel = torch.ops._C.marlin_gemm
    return kernel(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    """Return the output of ``_C.gptq_marlin_24_gemm`` for the given operands."""
    kernel = torch.ops._C.gptq_marlin_24_gemm
    return kernel(a, b_q_weight, b_meta, b_scales, workspace, num_bits,
                  size_m, size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask ``_C.cutlass_scaled_mm_supports_fp8`` about a device capability."""
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run ``_C.cutlass_scaled_mm`` into a freshly allocated output tensor.

    The output has shape ``(a.shape[0], b.shape[1])`` and dtype ``out_dtype``,
    and lives on ``a``'s device. It is passed to the kernel as the first
    argument (presumably filled in place) and then returned.

    Args:
        out_dtype: A dtype *instance*; only ``torch.bfloat16`` and
            ``torch.float16`` are accepted.

    Raises:
        AssertionError: if either dimension of ``b`` is not a multiple of 16,
            or ``out_dtype`` is not ``torch.bfloat16``/``torch.float16``.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0, (
        "cutlass_scaled_mm requires both dimensions of b to be multiples of "
        f"16, got {tuple(b.shape)}")
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16, (
        "cutlass_scaled_mm only supports torch.bfloat16/torch.float16 output, "
        f"got {out_dtype}")

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the output of ``_C.aqlm_gemm``; ``bias`` may be ``None``."""
    kernel = torch.ops._C.aqlm_gemm
    return kernel(input, codes, codebooks, scales, codebook_partition_sizes,
                  bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    """Return the output of ``_C.aqlm_dequant`` for the given codes."""
    kernel = torch.ops._C.aqlm_dequant
    return kernel(codes, codebooks, codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the repacked weight produced by ``_C.gptq_marlin_repack``."""
    repack = torch.ops._C.gptq_marlin_repack
    return repack(b_q_weight, perm, size_k, size_n, num_bits)


# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Return the repacked weight produced by ``_C.awq_marlin_repack``."""
    repack = torch.ops._C.awq_marlin_repack
    return repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, b_zeros: torch.Tensor,
                     g_idx: torch.Tensor, perm: torch.Tensor,
                     workspace: torch.Tensor, num_bits: int, size_m: int,
                     size_n: int, size_k: int, is_k_full: bool,
                     has_zp: bool) -> torch.Tensor:
    """Return the output of ``_C.gptq_marlin_gemm``; arguments pass in order."""
    kernel = torch.ops._C.gptq_marlin_gemm
    return kernel(a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace,
                  num_bits, size_m, size_n, size_k, is_k_full, has_zp)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Return the output of ``_C.fp8_marlin_gemm`` for the given operands."""
    kernel = torch.ops._C.fp8_marlin_gemm
    return kernel(a, b_q_weight, b_scales, workspace, num_bits, size_m,
                  size_n, size_k)


# fp8
# NOTE: `scaled_fp8_quant` is commented out here; the implementation is kept
# below for reference only.
# def scaled_fp8_quant(
#     input: torch.Tensor,
#     scale: Optional[torch.Tensor] = None,
#     batch_dim_padding: Optional[int] = None,
#     scale_ub: Optional[torch.Tensor] = None,
#     use_per_token_if_dynamic: bool = False,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Quantize input tensor to FP8 and return quantized tensor and scale.
#
#     This function supports both static and dynamic quantization: If you
#     provide the scale, it will use static scaling and if you omit it,
#     the scale will be determined dynamically. The function also allows
#     optional padding of the output tensor for downstream kernels that
#     will benefit from padding.
#
#     Args:
#         input: The input tensor to be quantized to FP8
#         scale: Optional scaling factor for the FP8 quantization
#         scale_ub: Optional upper bound for scaling factor in dynamic
#             per token case
#         batch_dim_padding: If specified, pad the first dimension
#             of the output to at least this value.
#         use_per_token_if_dynamic: Whether to do per_tensor or per_token
#             in the dynamic quantization case.
#
#     Returns:
#         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
#             scaling factor.
#     """
#     if batch_dim_padding:
#         shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
#         output = torch.empty(shape,
#                              device=input.device,
#                              dtype=torch.float8_e4m3fn)
#     else:
#         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
#     if scale is None:
#         if use_per_token_if_dynamic:
#             scale = torch.empty((input.numel() // input.shape[-1], 1),
#                                 device=input.device,
#                                 dtype=torch.float32)
#             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
#                 output, input, scale, scale_ub)
#         else:
#             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
#             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
#     else:
#         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
#
#     return output, scale


# int8
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``input`` to int8, returning the quantized tensor and its scales.

    Args:
        input: The tensor to quantize.
        scale: Optional precomputed scale. When given, static per-tensor
            quantization is used and ``scale`` is returned unchanged; when
            omitted, dynamic per-token scales are computed.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: An int8 tensor allocated like
        ``input``, and the scales used. In the dynamic case the scales are a
        float32 tensor of shape ``(input.numel() // input.shape[-1], 1)`` on
        ``input``'s device.
    """
    output = torch.empty_like(input, dtype=torch.int8)

    if scale is None:
        # dynamic-per-token quantization: one scale per row of the last dim.
        num_rows = input.numel() // input.shape[-1]
        token_scales = torch.empty((num_rows, 1),
                                   device=input.device,
                                   dtype=torch.float32)
        torch.ops._C.dynamic_scaled_int8_quant(output, input, token_scales)
        return output, token_scales

    # static-per-tensor quantization.
    torch.ops._C.static_scaled_int8_quant(output, input, scale)
    return output, scale


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Call ``_C.moe_align_block_size``; the result tensors are caller-provided."""
    kernel = torch.ops._C.moe_align_block_size
    kernel(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids,
           num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Call ``_moe_C.topk_softmax`` with caller-provided result tensors.

    ``gating_output`` is forwarded to the kernel together with the other
    tensors. It was previously annotated ``float``, but it is a tensor of
    router logits (annotation only; runtime behavior is unchanged).
    ``token_expert_indicies`` keeps its historical spelling because callers
    may pass it by keyword.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call ``_C_cache_ops.reshape_and_cache`` with all arguments in order.

    Nothing is returned; presumably ``key_cache``/``value_cache`` are written
    at the slots given by ``slot_mapping`` — TODO confirm.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Call ``_C_cache_ops.reshape_and_cache_flash`` (no k/v scale arguments)."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, kv_cache_dtype)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Call ``_C_cache_ops.copy_blocks`` over the given cache lists."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Call ``_C_cache_ops.swap_blocks`` from ``src`` to ``dst``."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Call ``_C_cache_ops.convert_fp8`` with ``output`` as the destination.

    ``scale`` defaults to 1.0 and ``kv_dtype`` to ``"fp8"``.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``_C_cuda_utils.get_device_attribute(attribute, device)``."""
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the op's max-shared-memory-per-block value for ``device``."""
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Create a custom all-reduce context and return its integer handle.

    The handle is presumably the ``fa`` argument accepted by the other
    custom all-reduce helpers — TODO confirm.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank,
                                    full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return the op's verdict on using custom all-reduce for ``inp``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.should_custom_ar(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Call ``_C_custom_ar.all_reduce_reg`` for handle ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_reg(fa, inp, out)

def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Call ``_C_custom_ar.all_reduce_unreg``, staging through ``reg_buffer``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)

def dispose(fa: int) -> None:
    """Release the custom all-reduce context identified by ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.dispose(fa)


def meta_size() -> int:
    """Return ``_C_custom_ar.meta_size()``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Register ``t`` with context ``fa``; the op's return value is passed back."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the ``(handles, offsets)`` pair reported by the op for ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Call ``_C_custom_ar.register_graph_buffers`` for context ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.register_graph_buffers(fa, handles, offsets)


# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    """Call ``_punica_C.dispatch_bgmv`` with ``y`` as the destination tensor."""
    punica = torch.ops._punica_C
    punica.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    """Call ``_punica_C.dispatch_bgmv_low_level``.

    Same arguments as :func:`dispatch_bgmv` plus the explicit ``h_in``,
    ``h_out`` and ``y_offset`` values, forwarded in order.
    """
    punica = torch.ops._punica_C
    punica.dispatch_bgmv_low_level(y, x, w_t_all, indicies, layer_idx, scale,
                                   h_in, h_out, y_offset)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Replace every plain function defined in this file whose annotations mention
# torch.Tensor with its `hint_on_error`-wrapped version. The
# `arg == "torch.Tensor"` check handles annotations stored as strings (e.g.
# under `from __future__ import annotations`).
_function_type = type(lambda x: x)
_wrapped_functions = {}
# Iterate over a snapshot: binding the module-level loop variables below
# would otherwise add keys to globals() while it is being iterated.
for _name, _value in list(globals().items()):
    if isinstance(_value, _function_type) \
            and _value.__code__.co_filename == __file__ \
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in _value.__annotations__.values()):
        _wrapped_functions[_name] = hint_on_error(_value)

globals().update(_wrapped_functions)
# Leave no helper names behind in the module namespace.
del _function_type, _wrapped_functions, _name, _value