_custom_ops.py 26.8 KB
Newer Older
1
import contextlib
2
import functools
3
from typing import List, Optional, Tuple, Union
4
5

import torch
gaoqiong's avatar
gaoqiong committed
6

7
from vllm._core_ext import ScalarType
8
from vllm.logger import init_logger
9
from vllm.platforms import current_platform
10

11
logger = init_logger(__name__)

# lmslim provides `quant_ops`, which backs the AWQ/GPTQ wrappers in this file.
# It is optional: when it cannot be imported, `quant_ops` stays unbound and
# those wrappers fail with NameError when called, so we only log a hint here
# (through the module logger rather than printing to stdout).
try:
    from lmslim import quant_ops
except Exception as e:
    logger.info(
        "Please install lmslim if you want to infer gptq or awq model. (%r)",
        e)

if not current_platform.is_tpu():
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def hint_on_error(fn):
    """Wrap ``fn`` so an ``AttributeError`` is logged with a rebuild hint.

    When calling ``fn`` raises ``AttributeError``, an error is logged that
    names the op and suggests a clean rebuild of vllm, and the same exception
    is then re-raised. Return values and any other exception pass through
    untouched.

    Args:
        fn: The function to wrap.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``).
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # A bare ``raise`` re-raises the active exception with its
            # original traceback intact.
            raise

    return wrapper


47
48
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.silu_and_mul`` custom op with ``(out, x)``.

    Nothing is returned; ``out`` is handed to the op first, presumably as
    the destination buffer.
    """
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)
50
51
52


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_and_mul`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)
54
55
56


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_tanh_and_mul`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)
zhuwenwen's avatar
zhuwenwen committed
58
59
60
61
62
63
64
65
66
67
68
69
    
    
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.silu_and_mul_opt`` custom op with ``(out, x)``.

    Same argument list as ``silu_and_mul``; dispatches to the ``_opt`` op.
    """
    kernel = torch.ops._C.silu_and_mul_opt
    kernel(out, x)


def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_and_mul_opt`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_and_mul_opt
    kernel(out, x)


def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_tanh_and_mul_opt`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_tanh_and_mul_opt
    kernel(out, x)
70
71
72


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_fast`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)
74
75
76


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_new`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_new
    kernel(out, x)
78
79


80
81
82
83
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Invoke the ``_C.gelu_quick`` custom op with ``(out, x)``."""
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


84
85
86
87
88
89
90
91
92
# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in declaration order, to ``_C.paged_attention_v1``.

    The trailing ``tp_rank`` and ``blocksparse_*`` arguments are optional
    and default to 0 (``blocksparse_block_size`` defaults to 64).
    """
    op_args = (
        out, query, key_cache, value_cache, num_kv_heads, scale,
        block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
        kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v1(*op_args)
112
113
114
115
116
117
118
119
120
121
122
123
124


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, in declaration order, to ``_C.paged_attention_v2``.

    Compared with ``paged_attention_v1`` this also passes the ``exp_sum``,
    ``max_logits`` and ``tmp_out`` tensors right after ``out``.
    """
    op_args = (
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size,
        max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale,
        tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v2(*op_args)
144
145


zhuwenwen's avatar
zhuwenwen committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# page attention ops (opt)
def paged_attention_v1_opt(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Same argument list as ``paged_attention_v1``; calls the ``_opt`` op."""
    op_args = (
        out, query, key_cache, value_cache, num_kv_heads, scale,
        block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
        kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v1_opt(*op_args)


def paged_attention_v2_opt(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Same argument list as ``paged_attention_v2``; calls the ``_opt`` op."""
    op_args = (
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size,
        max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale,
        tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v2_opt(*op_args)
    
    
208
209
210
211
212
213
214
215
216
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Call ``_C.rotary_embedding`` with the arguments in declaration order."""
    op = torch.ops._C.rotary_embedding
    op(positions, query, key, head_size, cos_sin_cache, is_neox)
219
220
221
222
223
224
225


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Call ``_C.batched_rotary_embedding``.

    Takes the same leading arguments as ``rotary_embedding``, followed by
    ``rot_dim`` and ``cos_sin_cache_offsets``.
    """
    op_args = (positions, query, key, head_size, cos_sin_cache, is_neox,
               rot_dim, cos_sin_cache_offsets)
    torch.ops._C.batched_rotary_embedding(*op_args)
229
230
231
232
233


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call ``_C.rms_norm`` with ``(out, input, weight, epsilon)``."""
    op = torch.ops._C.rms_norm
    op(out, input, weight, epsilon)
235
236
237
238


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call ``_C.fused_add_rms_norm`` with ``(input, residual, weight, epsilon)``.

    Nothing is returned; presumably ``input``/``residual`` are updated in
    place by the op.
    """
    op = torch.ops._C.fused_add_rms_norm
    op(input, residual, weight, epsilon)
zhuwenwen's avatar
zhuwenwen committed
240
241
242
243
244
245
246
247
248
249
250
    

# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
    """Same arguments as ``rms_norm``; dispatches to ``_C.rms_norm_opt``."""
    op = torch.ops._C.rms_norm_opt
    op(out, input, weight, epsilon)


def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
    """Same arguments as ``fused_add_rms_norm``; dispatches to the ``_opt`` op."""
    op = torch.ops._C.fused_add_rms_norm_opt
    op(input, residual, weight, epsilon)
251
252


253
254
255
256
257
258
259
260
261
262
263
def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Returns whatever ``_C.advance_step`` returns. It is annotated ``None``.
    """
    step_args = (num_seqs, num_queries, block_size, input_tokens,
                 sampled_token_ids, input_positions, seq_lens, slot_mapping,
                 block_tables)
    return torch.ops._C.advance_step(*step_args)

zhuwenwen's avatar
zhuwenwen committed
264
265
266
267
268
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
                   row: int, col: int) -> None:
    """Call ``_C.trans_w16_gemm`` with ``(dst, src, row, col)``."""
    op = torch.ops._C.trans_w16_gemm
    op(dst, src, row, col)
    
269

270
271
# quantization ops
# awq
zhuwenwen's avatar
zhuwenwen committed
272
273
274
275
276
277
def GetAWQShareWorkspaceSize() -> int:
    """Return ``quant_ops.GetAWQShareWorkspaceSize()`` (lmslim must be installed)."""
    size_getter = quant_ops.GetAWQShareWorkspaceSize
    return size_getter()

def GetAWQShareWorkspace() -> torch.Tensor:
    """Return ``quant_ops.GetAWQShareWorkspace()`` (lmslim must be installed)."""
    workspace_getter = quant_ops.GetAWQShareWorkspace
    return workspace_getter()

278
279
280
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the tensor produced by ``_C.awq_dequantize`` for these arguments."""
    op = torch.ops._C.awq_dequantize
    return op(qweight, scales, zeros, split_k_iters, thx, thy)
283
284


gaoqiong's avatar
gaoqiong committed
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
#              scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
#     return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)

def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
             zeros_and_scales: torch.Tensor,
             m: int, n: int, k: int,
             group_size: int, padding_group: int,
             splikspace: torch.Tensor,
             splikspacesize: int) -> torch.Tensor:
    """Return ``quant_ops.awq_gemm`` (lmslim) applied to these arguments.

    NOTE(review): ``splikspace``/``splikspacesize`` look like a shared
    workspace tensor and its size (compare ``GetAWQShareWorkspace`` /
    ``GetAWQShareWorkspaceSize``). This needs confirming.
    """
    problem_dims = (m, n, k)
    return quant_ops.awq_gemm(input, weight, zeros_and_scales,
                              *problem_dims,
                              group_size, padding_group,
                              splikspace, splikspacesize)

def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
               group_size: int):
    """Return ``quant_ops.convert_s4(qw, qz, s, group_size)`` (lmslim)."""
    converter = quant_ops.convert_s4
    return converter(qw, qz, s, group_size)

def sz_permute(sz: torch.Tensor) -> torch.Tensor:
    """Return ``quant_ops.sz_permute(sz)`` (lmslim)."""
    permuter = quant_ops.sz_permute
    return permuter(sz)

def dequant_w4_gemm_colmajor(qweight: torch.Tensor,
                             zeros_and_scale: torch.Tensor,
                             k: int,
                             n: int,
                             group_size: int) -> torch.Tensor:
    """Return ``quant_ops.dequant_w4_gemm_colmajor`` (lmslim) for these arguments."""
    dequant = quant_ops.dequant_w4_gemm_colmajor
    return dequant(qweight, zeros_and_scale, k, n, group_size)
319
320
321
322
323
324

# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return ``quant_ops.gptq_gemm`` (lmslim) for these arguments.

    This fork routes GPTQ GEMM through lmslim rather than
    ``torch.ops._C.gptq_gemm``; the argument order is the same.
    """
    gemm_args = (a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                 use_exllama, bit)
    return quant_ops.gptq_gemm(*gemm_args)
329
330
331
332


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call ``quant_ops.gptq_shuffle(q_weight, q_perm, bit)`` (lmslim).

    Routed through lmslim instead of ``torch.ops._C.gptq_shuffle``.
    """
    shuffler = quant_ops.gptq_shuffle
    shuffler(q_weight, q_perm, bit)
335
336
337
338
339


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Call ``_C.squeezellm_gemm`` with ``(vec, mat, mul, lookup_table)``."""
    op = torch.ops._C.squeezellm_gemm
    op(vec, mat, mul, lookup_table)
341
342
343
344
345
346


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of ``_C.marlin_gemm`` for these arguments."""
    sizes = (size_m, size_n, size_k)
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace,
                                    *sizes)
349
350


351
352
353
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of ``_C.gptq_marlin_24_gemm`` for these arguments."""
    sizes = (size_m, size_n, size_k)
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, *sizes)
359
360


361
# cutlass
362
363
364
365
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask ``_C.cutlass_scaled_mm_supports_fp8`` about the given capability."""
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)


366
367
368
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Allocate an ``(a.shape[0], b.shape[1])`` output and fill it via CUTLASS.

    Asserts that both dimensions of ``b`` are multiples of 16, that
    ``out_dtype`` is bfloat16 or float16, and that ``bias`` (if given) has
    ``b.shape[1]`` leading elements and dtype ``out_dtype``. Then it calls
    ``_C.cutlass_scaled_mm`` on a freshly allocated output on ``a.device``
    and returns that output.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    result_shape = (a.shape[0], b.shape[1])
    result = torch.empty(result_shape, dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm(result, a, b, scale_a, scale_b, bias)
    return result


386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Like ``cutlass_scaled_mm`` but also passes ``azp_adj`` and ``azp``.

    Asserts that both dimensions of ``b`` are multiples of 16, that
    ``out_dtype`` is bfloat16 or float16, and that ``bias`` (if given) has
    exactly ``b.shape[1]`` elements and dtype ``out_dtype``. Returns a new
    ``(a.shape[0], b.shape[1])`` tensor filled by ``_C.cutlass_scaled_mm_azp``.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)

    result_shape = (a.shape[0], b.shape[1])
    result = torch.empty(result_shape, dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm_azp(result, a, b, scale_a, scale_b,
                                       azp_adj, azp, bias)
    return result


408
409
410
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the result of ``_C.aqlm_gemm`` for these arguments."""
    op = torch.ops._C.aqlm_gemm
    return op(input, codes, codebooks, scales, codebook_partition_sizes,
              bias)
415
416
417


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    """Return the result of ``_C.aqlm_dequant`` for these arguments."""
    op = torch.ops._C.aqlm_dequant
    return op(codes, codebooks, codebook_partition_sizes)
421
422


423
424
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the result of ``_C.gptq_marlin_repack`` for these arguments."""
    op = torch.ops._C.gptq_marlin_repack
    return op(b_q_weight, perm, size_k, size_n, num_bits)
429
430


431
432
433
434
435
436
# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Return the result of ``_C.awq_marlin_repack`` for these arguments."""
    op = torch.ops._C.awq_marlin_repack
    return op(b_q_weight, size_k, size_n, num_bits)


437
438
439
440
441
442
443
444
445
446
447
448
449
450
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    """Return the result of ``_C.gptq_marlin_gemm``.

    ``has_zp`` and ``use_fp32_reduce`` default to ``False``.
    """
    op_args = (a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace,
               b_q_type, size_m, size_n, size_k, is_k_full, has_zp,
               use_fp32_reduce)
    return torch.ops._C.gptq_marlin_gemm(*op_args)
455
456


457
458
459
460
461
462
463
464
465
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Return the result of ``_C.fp8_marlin_gemm`` for these arguments."""
    sizes = (size_m, size_n, size_k)
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, *sizes)


466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    """Return what ``_C.machete_supported_schedules`` reports for ``b_type``."""
    query = torch.ops._C.machete_supported_schedules
    return query(b_type)


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    """Return the result of ``_C.machete_gemm``.

    Every argument after ``b_type`` is optional and defaults to ``None``.
    All arguments are forwarded positionally.
    """
    op_args = (a, b_q, b_type, b_scales, b_zeros, b_group_size, c, alpha,
               beta, schedule)
    return torch.ops._C.machete_gemm(*op_args)


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    """Return ``_C.machete_prepack_B(b_q_weight, b_type)``."""
    prepack = torch.ops._C.machete_prepack_B
    return prepack(b_q_weight, b_type)


492
# fp8
zhuwenwen's avatar
zhuwenwen committed
493
494
495
# def scaled_fp8_quant(
#     input: torch.Tensor,
#     scale: Optional[torch.Tensor] = None,
496
#     num_token_padding: Optional[int] = None,
497
498
#     scale_ub: Optional[torch.Tensor] = None,
#     use_per_token_if_dynamic: bool = False,
zhuwenwen's avatar
zhuwenwen committed
499
# ) -> Tuple[torch.Tensor, torch.Tensor]:
zhuwenwen's avatar
zhuwenwen committed
500
501
502
503
504
505
#     """
#     Quantize input tensor to FP8 and return quantized tensor and scale.

#     This function supports both static and dynamic quantization: If you
#     provide the scale, it will use static scaling and if you omit it,
#     the scale will be determined dynamically. The function also allows
506
#     optional padding of the output tensors for downstream kernels that
zhuwenwen's avatar
zhuwenwen committed
507
508
509
510
511
#     will benefit from padding.

#     Args:
#         input: The input tensor to be quantized to FP8
#         scale: Optional scaling factor for the FP8 quantization
512
513
#         scale_ub: Optional upper bound for scaling factor in dynamic 
#             per token case
514
#         num_token_padding: If specified, pad the first dimension
zhuwenwen's avatar
zhuwenwen committed
515
#             of the output to at least this value.
516
517
#         use_per_token_if_dynamic: Whether to do per_tensor or per_token 
#             in the dynamic quantization case.
zhuwenwen's avatar
zhuwenwen committed
518
519
520
521
522

#     Returns:
#         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
#             scaling factor.
#     """
523
524
525
#     # This code assumes batch_dim and num_tokens are flattened
#     assert (input.ndim == 2)
#     shape: Union[Tuple[int, int], torch.Size] = input.shape
526
527
528
#     # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
#     out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
#         else torch.float8_e4m3fn
529
530
#     if num_token_padding:
#         shape = (max(num_token_padding, input.shape[0]), shape[1])
531
#     output = torch.empty(shape, device=input.device, dtype=out_dtype)
532

zhuwenwen's avatar
zhuwenwen committed
533
#     if scale is None:
534
#         if use_per_token_if_dynamic:
535
#             scale = torch.empty((shape[0], 1),
536
537
538
539
540
541
542
#                                 device=input.device,
#                                 dtype=torch.float32)
#             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
#                 output, input, scale, scale_ub)
#         else:
#             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
#             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
zhuwenwen's avatar
zhuwenwen committed
543
#     else:
544
545
#         # num_token_padding not implemented for this case
#         assert (scale.numel() == 1 or num_token_padding is None)
zhuwenwen's avatar
zhuwenwen committed
546
#         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
547

zhuwenwen's avatar
zhuwenwen committed
548
#     return output, scale
549
550


551
# int8
552
553
554
555
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
      Tuple[torch.Tensor, torch.Tensor] : Output int8 tensor and scales.
    """
    quantized = torch.empty_like(input, dtype=torch.int8)

    if scale is None:
        # Dynamic per-token: one float32 scale per slice along the last
        # dimension, stored as a (num_tokens, 1) column.
        num_tokens = input.numel() // input.shape[-1]
        token_scales = torch.empty((num_tokens, 1),
                                   device=input.device,
                                   dtype=torch.float32)
        torch.ops._C.dynamic_scaled_int8_quant(quantized, input, token_scales)
        return quantized, token_scales

    # Static per-tensor: the caller-supplied scale is used and returned as-is.
    torch.ops._C.static_scaled_int8_quant(quantized, input, scale)
    return quantized, scale
579
580


581
582
583
584
585
586
587
588
589
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of ``_C.marlin_qqq_gemm`` for these arguments."""
    sizes = (size_m, size_n, size_k)
    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
                                        workspace, *sizes)


590
# gguf
591
592
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    """Return ``_C.ggml_dequantize(W, quant_type, m, n)``."""
    op = torch.ops._C.ggml_dequantize
    return op(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Return ``_C.ggml_mul_mat_vec_a8(W, X, quant_type, row)``."""
    op = torch.ops._C.ggml_mul_mat_vec_a8
    return op(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    """Return ``_C.ggml_mul_mat_a8(W, X, quant_type, row)``."""
    op = torch.ops._C.ggml_mul_mat_a8
    return op(W, X, quant_type, row)


614
615
616
617
618
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Call ``_C.moe_align_block_size`` with the arguments in order.

    Nothing is returned; the ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` tensors are presumably filled by the op.
    """
    op_args = (topk_ids, num_experts, block_size, sorted_token_ids,
               experts_ids, num_tokens_post_pad)
    torch.ops._C.moe_align_block_size(*op_args)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: float) -> None:
    """Call ``_moe_C.topk_softmax`` (from the optional ``vllm._moe_C`` module).

    NOTE(review): ``gating_output`` is annotated ``float`` here; confirm
    against the op's actual expected type.
    """
    op = torch.ops._moe_C.topk_softmax
    op(topk_weights, topk_ids, token_expert_indicies, gating_output)
629
630
631
632
633
634
635
636
637


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call ``_C_cache_ops.reshape_and_cache`` with the arguments in order."""
    op_args = (key, value, key_cache, value_cache, slot_mapping,
               kv_cache_dtype, k_scale, v_scale)
    torch.ops._C_cache_ops.reshape_and_cache(*op_args)
644
645


646
647
648
649
650
651
652
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Same arguments as ``reshape_and_cache``; calls the ``_flash`` cache op."""
    op_args = (key, value, key_cache, value_cache, slot_mapping,
               kv_cache_dtype, k_scale, v_scale)
    torch.ops._C_cache_ops.reshape_and_cache_flash(*op_args)
660
661


662
663
def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Call ``_C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)``."""
    op = torch.ops._C_cache_ops.copy_blocks
    op(key_caches, value_caches, block_mapping)
666
667
668


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Call ``_C_cache_ops.swap_blocks(src, dst, block_mapping)``."""
    op = torch.ops._C_cache_ops.swap_blocks
    op(src, dst, block_mapping)
671
672


673
674
675
676
def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Call ``_C_cache_ops.convert_fp8(output, input, scale, kv_dtype)``.

    ``scale`` defaults to 1.0 and ``kv_dtype`` to ``"fp8"``.
    """
    op = torch.ops._C_cache_ops.convert_fp8
    op(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``_C_cuda_utils.get_device_attribute(attribute, device)``."""
    query = torch.ops._C_cuda_utils.get_device_attribute
    return query(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the op of the same name from ``torch.ops._C_cuda_utils``."""
    # ruff: noqa: E501
    # (ruff treats the directive above as file-wide, so it is kept as-is.)
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Return ``_C_custom_ar.init_custom_ar`` for these arguments (an int handle)."""
    op_args = (meta, rank_data, handles, offsets, rank, full_nvlink)
    return torch.ops._C_custom_ar.init_custom_ar(*op_args)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return ``_C_custom_ar.should_custom_ar`` for these arguments."""
    decide = torch.ops._C_custom_ar.should_custom_ar
    return decide(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Call ``_C_custom_ar.all_reduce_reg(fa, inp, out)``."""
    op = torch.ops._C_custom_ar.all_reduce_reg
    op(fa, inp, out)

707

708
709
710
def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Call ``_C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)``."""
    op = torch.ops._C_custom_ar.all_reduce_unreg
    op(fa, inp, reg_buffer, out)
711

712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734

def dispose(fa: int) -> None:
    """Call ``_C_custom_ar.dispose(fa)``."""
    op = torch.ops._C_custom_ar.dispose
    op(fa)


def meta_size() -> int:
    """Return ``_C_custom_ar.meta_size()``."""
    query = torch.ops._C_custom_ar.meta_size
    return query()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Call ``_C_custom_ar.register_buffer`` and pass its result through."""
    op = torch.ops._C_custom_ar.register_buffer
    return op(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return ``_C_custom_ar.get_graph_buffer_ipc_meta(fa)``."""
    query = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    return query(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Call ``_C_custom_ar.register_graph_buffers(fa, handles, offsets)``."""
    op = torch.ops._C_custom_ar.register_graph_buffers
    op(fa, handles, offsets)


735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
# Replace every function defined in this file whose annotations mention
# torch.Tensor with a `hint_on_error` wrapper of itself.
names_and_values = globals()
names_and_values_to_update = {}
# The loop variables `k` and `v` are module globals here. Bind them before
# iterating so the loop does not add new keys to globals() (which would
# change the dict's size during iteration).
k, v = None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to turn
    # type hints into strings. `arg` is local to the generator expression,
    # so it needs no pre-binding and does not leak into the module.
    if isinstance(v, fn_type) \
        and v.__code__.co_filename == __file__ \
        and any(arg is torch.Tensor or arg == "torch.Tensor"
                for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type