# _custom_ops.py: Python wrappers around vLLM's compiled custom ops.
import contextlib
2
import functools
3
from typing import List, Optional, Tuple, Union
4
5
6

import torch

7
from vllm._core_ext import ScalarType
8
9
10
11
from vllm.logger import init_logger

logger = init_logger(__name__)

# Importing the compiled extensions is done for its side effect: the wrappers
# below look their ops up under torch.ops._C / torch.ops._moe_C. A failed core
# import is only logged here; calling a wrapper whose op is missing will then
# presumably fail with an AttributeError (see hint_on_error).
try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# The MoE extension is optional; silently skip it when it is not built.
# Line-level `# noqa` is used because `# ruff: noqa: F401` would exempt the
# whole file from F401, not just this import.
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401


def is_custom_op_supported(op_name: str) -> bool:
    """Report whether ``op_name`` resolves to a registered operator.

    ``op_name`` is a qualified name such as ``"aten::add"``. The lookup goes
    through ``torch._C._jit_get_operation``; only the operator object it
    returns is inspected (the overload names are ignored).
    """
    resolved, _overload_names = torch._C._jit_get_operation(op_name)
    found = resolved is not None
    return found

def hint_on_error(fn):
    """Decorate ``fn`` so an AttributeError is logged with a rebuild hint.

    The wrappers in this module look ops up as attributes of ``torch.ops.*``,
    so a missing or stale compiled op presumably surfaces as an
    AttributeError. When that happens the decorated function logs the
    wrapper name and the error, with advice to rebuild vllm. It then
    re-raises the original exception. Other exceptions pass through untouched.

    Args:
        fn: The function to wrap.

    Returns:
        A ``functools.wraps``-preserving wrapper around ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                # Keep the trailing space so the two sentences don't run
                # together ("vllm,or ...").
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare `raise` re-raises the active exception with its original
            # traceback instead of re-raising it from this line.
            raise

    return wrapper


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``_C.silu_and_mul`` op on ``x``.

    ``out`` is passed to the op as its output buffer (presumably written in
    place); this wrapper returns None.
    """
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``_C.gelu_and_mul`` op on ``x``.

    ``out`` is passed to the op as its output buffer (presumably written in
    place); this wrapper returns None.
    """
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``_C.gelu_tanh_and_mul`` op on ``x``.

    ``out`` is passed to the op as its output buffer (presumably written in
    place); this wrapper returns None.
    """
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``_C.gelu_fast`` op on ``x``.

    ``out`` is passed to the op as its output buffer (presumably written in
    place); this wrapper returns None.
    """
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``_C.gelu_new`` op on ``x``.

    ``out`` is passed to the op as its output buffer (presumably written in
    place); this wrapper returns None.
    """
    torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``_C.gelu_quick`` op on ``x``.

    ``out`` is passed to the op as its output buffer (presumably written in
    place); this wrapper returns None.
    """
    torch.ops._C.gelu_quick(out, x)


# paged attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Dispatch to the compiled ``_C.paged_attention_v1`` op.

    Every argument is forwarded positionally in signature order. ``out`` is
    the op's output buffer and this wrapper returns None. The ``tp_rank`` and
    ``blocksparse_*`` defaults are forwarded unchanged. They presumably select
    the non-block-sparse path (TODO confirm against the kernel).
    """
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Dispatch to the compiled ``_C.paged_attention_v2`` op.

    Like ``paged_attention_v1`` but also takes the ``exp_sum``,
    ``max_logits`` and ``tmp_out`` buffers. These are presumably scratch
    space for a partitioned reduction (TODO confirm). All arguments are
    forwarded positionally in signature order; this wrapper returns None.
    """
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Apply the compiled ``_C.rotary_embedding`` op.

    ``query`` and ``key`` are passed to the op and the wrapper returns None,
    so they are presumably rotated in place. ``is_neox`` is forwarded as-is
    to select the op's rotation variant.
    """
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Apply the compiled ``_C.batched_rotary_embedding`` op.

    Same as ``rotary_embedding`` plus ``rot_dim`` and per-token
    ``cos_sin_cache_offsets``, which are forwarded unchanged. Returns None.
    """
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Run the compiled ``_C.rms_norm`` op on ``input`` with ``weight``.

    ``out`` is passed as the op's output buffer; ``epsilon`` is forwarded
    unchanged. Returns None.
    """
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Run the compiled ``_C.fused_add_rms_norm`` op.

    No output tensor is allocated and the wrapper returns None. ``input``
    and ``residual`` are therefore presumably updated in place
    (TODO confirm against the kernel).
    """
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Thin wrapper over ``_C.advance_step``; whatever the op returns is
    passed back to the caller.
    """
    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
                                     input_tokens, sampled_token_ids,
                                     input_positions, seq_lens, slot_mapping,
                                     block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``_C.awq_dequantize`` op.

    All arguments, including the ``split_k_iters``/``thx``/``thy`` tuning
    values, are forwarded unchanged.
    """
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.awq_gemm`` op.

    Multiplies ``input`` by the AWQ-quantized ``qweight`` using ``qzeros``
    and ``scales``; ``split_k_iters`` is forwarded unchanged.
    """
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_gemm`` op.

    ``use_exllama`` and ``bit`` are forwarded unchanged to select the
    op's kernel variant and weight bit-width.
    """
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Run the compiled ``_C.gptq_shuffle`` op.

    Nothing is returned, so ``q_weight`` is presumably reordered in place
    according to ``q_perm`` (TODO confirm).
    """
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Run the compiled ``_C.squeezellm_gemm`` op.

    Returns None. ``mul`` is presumably the output buffer (TODO confirm
    against the kernel).
    """
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.marlin_gemm`` op.

    ``workspace`` and the ``size_m``/``size_n``/``size_k`` problem sizes are
    forwarded unchanged.
    """
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_marlin_24_gemm`` op.

    ``b_meta`` and the quantized weight type ``b_q_type`` are forwarded
    unchanged together with the problem sizes.
    """
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled CUTLASS extension whether FP8 scaled-mm is supported.

    ``cuda_device_capability`` is forwarded unchanged to
    ``_C.cutlass_scaled_mm_supports_fp8``.
    """
    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute a scaled matmul of ``a`` and ``b`` with the CUTLASS op.

    A fresh ``(a.shape[0], b.shape[1])`` tensor of ``out_dtype`` is allocated
    on ``a.device``. It is passed to ``_C.cutlass_scaled_mm`` along with
    ``scale_a``, ``scale_b`` and the optional ``bias``, and then returned.

    Raises:
        AssertionError: if either dimension of ``b`` is not a multiple of 16,
            or if ``out_dtype`` is neither ``torch.bfloat16`` nor
            ``torch.float16``.
    """
    # Checked in Python before dispatch: 16-aligned B dims, 16-bit output.
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the result of the compiled ``_C.aqlm_gemm`` op.

    ``bias`` may be None and is forwarded unchanged, like all other arguments.
    """
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    """Return the tensor produced by the compiled ``_C.aqlm_dequant`` op."""
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return ``b_q_weight`` repacked by the compiled ``_C`` op.

    ``perm``, the ``size_k``/``size_n`` dims and ``num_bits`` are forwarded
    unchanged to ``_C.gptq_marlin_repack``.
    """
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Return ``b_q_weight`` repacked by the compiled ``_C`` op.

    Unlike ``gptq_marlin_repack`` no permutation is taken; ``size_k``,
    ``size_n`` and ``num_bits`` go to ``_C.awq_marlin_repack`` unchanged.
    """
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_marlin_gemm`` op.

    All arguments are forwarded positionally in signature order. ``has_zp``
    and ``use_fp32_reduce`` default to False.
    """
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.fp8_marlin_gemm`` op."""
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, size_m, size_n, size_k)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization. If you
    provide the scale, it uses static scaling; if you omit it, the scale is
    determined dynamically. The function also allows optional padding of the
    output tensors for downstream kernels that benefit from padding.

    Args:
        input: The 2-D input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization. When given,
            static quantization is used and ``scale`` is returned as-is.
        num_token_padding: If set, pad the first dimension of the output to
            at least this value. A falsy value (None or 0) means no padding.
            Not supported with a multi-element static ``scale``.
        scale_ub: Optional upper bound for the scaling factor; it is only
            forwarded in the dynamic per-token case.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8
            (``torch.float8_e4m3fn``) and the scaling factor.

    Raises:
        AssertionError: if ``input`` is not 2-D, or if a multi-element static
            ``scale`` is combined with ``num_token_padding``.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)

    if scale is None:
        if use_per_token_if_dynamic:
            # One float32 scale per (possibly padded) output row.
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            # Single per-tensor scale; zero-initialised, unlike the
            # per-token buffer above.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale


# int8
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When provided, static per-tensor quantization is used and
            ``scale`` is returned unchanged. When not provided, we invoke
            dynamic-per-token quantization.

    Returns:
      Tuple[torch.Tensor, torch.Tensor]: The int8 output tensor (same shape
      as ``input``) and the scales. In the dynamic case the scales are a new
      float32 tensor of shape ``(input.numel() // input.shape[-1], 1)``,
      i.e. one scale per row of the last dimension.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales


# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.marlin_qqq_gemm`` op.

    The three scale tensors (``s_tok``, ``s_ch``, ``s_group``) are forwarded
    unchanged together with the workspace and problem sizes.
    """
    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
                                        workspace, size_m, size_n, size_k)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
    """Return the result of the compiled ``_C.ggml_dequantize`` op.

    ``quant_type`` selects the GGML quantization format. ``m`` and ``n`` are
    forwarded as-is; they are presumably the output dims (TODO confirm).
    """
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec(W: torch.Tensor, X: torch.Tensor, quant_type: int,
                     row: int):
    """Return ``_C.ggml_mul_mat_vec`` applied to weights ``W`` and ``X``.

    ``quant_type`` and ``row`` are forwarded unchanged.
    """
    kernel = torch.ops._C.ggml_mul_mat_vec
    return kernel(W, X, quant_type, row)


def ggml_mul_mat_vec_a8(W: torch.Tensor, X: torch.Tensor, quant_type: int,
                        row: int):
    """Return ``_C.ggml_mul_mat_vec_a8`` applied to weights ``W`` and ``X``.

    ``quant_type`` and ``row`` are forwarded unchanged.
    """
    kernel = torch.ops._C.ggml_mul_mat_vec_a8
    return kernel(W, X, quant_type, row)


def ggml_mul_mat_a8(W: torch.Tensor, X: torch.Tensor, quant_type: int,
                    row: int):
    """Return ``_C.ggml_mul_mat_a8`` applied to weights ``W`` and ``X``.

    ``quant_type`` and ``row`` are forwarded unchanged.
    """
    kernel = torch.ops._C.ggml_mul_mat_a8
    return kernel(W, X, quant_type, row)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Run the compiled ``_C.moe_align_block_size`` op.

    Returns None. ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` are presumably output buffers filled by the op
    (TODO confirm).
    """
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Run the compiled ``_moe_C.topk_softmax`` op.

    Returns None. ``topk_weights``, ``topk_ids`` and
    ``token_expert_indicies`` are presumably filled by the op from
    ``gating_output`` (TODO confirm). ``gating_output`` was previously
    annotated ``float``. A scalar cannot feed a per-token top-k softmax, so
    it is annotated as a tensor; the annotation has no runtime effect.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Run the compiled ``_C_cache_ops.reshape_and_cache`` op.

    Returns None. ``key``/``value`` are presumably written into
    ``key_cache``/``value_cache`` at the slots given by ``slot_mapping``
    (TODO confirm). ``kv_cache_dtype`` and the scales are forwarded unchanged.
    """
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Run the compiled ``_C_cache_ops.reshape_and_cache_flash`` op.

    Same arguments as ``reshape_and_cache``, dispatched to the ``_flash``
    variant of the op. Returns None.
    """
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Run the compiled ``_C_cache_ops.copy_blocks`` op.

    The lists of key and value caches and ``block_mapping`` are forwarded
    unchanged. Returns None.
    """
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Run the compiled ``_C_cache_ops.swap_blocks`` op.

    Returns None. Blocks are presumably copied from ``src`` into ``dst``
    according to ``block_mapping`` (TODO confirm).
    """
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Run the compiled ``_C_cache_ops.convert_fp8`` op.

    ``output`` is passed as the op's destination; ``scale`` defaults to 1.0
    and ``kv_dtype`` to ``"fp8"``. Returns None.
    """
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return the value the CUDA utils op reports for ``attribute``."""
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the max shared memory per block reported for ``device``."""
    # ruff: noqa: E501
    # NOTE(review): ruff treats `# ruff: noqa: <code>` as a file-level
    # directive, so this line exempts the whole module from E501. It is kept
    # because other long lines in this file depend on it.
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom all-reduce (ar)
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Create a custom all-reduce context and return its integer handle.

    The handle is presumably the ``fa`` argument taken by the other custom
    all-reduce wrappers in this module.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank,
                                    full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Ask the op whether custom all-reduce should be used for ``inp``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.should_custom_ar(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """All-reduce ``inp`` into ``out`` via the ``fa`` custom-AR context."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_reg(fa, inp, out)

def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """All-reduce ``inp`` into ``out`` via the ``fa`` custom-AR context.

    Unlike ``all_reduce_reg``, a separate ``reg_buffer`` is passed to the op.
    Presumably ``inp`` is staged through it because ``inp`` itself is not
    registered (TODO confirm). Returns None.
    """
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)

def dispose(fa: int) -> None:
    """Release the custom all-reduce context identified by ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.dispose(fa)


def meta_size() -> int:
    """Return the size reported by ``_C_custom_ar.meta_size``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Register tensor ``t`` with the ``fa`` custom-AR context.

    The op's return value is passed through unchanged.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the ``(handles, offsets)`` IPC metadata for ``fa``'s buffers."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Register graph buffers described by ``handles``/``offsets`` on ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Wrap every function defined in this file that has torch.Tensor among its
# annotations with `hint_on_error`. `ann == "torch.Tensor"` handles the case
# where `from __future__ import annotations` turns type hints into strings.
# Iterating over a snapshot of globals() means rebinding names cannot resize
# the dict mid-iteration. Only underscore temporaries are created, and they
# are deleted afterwards, so no helper names leak into the module namespace.
_fn_type = type(lambda x: x)
_to_wrap = {
    name: hint_on_error(obj)
    for name, obj in list(globals().items())
    if isinstance(obj, _fn_type) and obj.__code__.co_filename == __file__
    and any(ann is torch.Tensor or ann == "torch.Tensor"
            for ann in obj.__annotations__.values())
}
globals().update(_to_wrap)
del _fn_type, _to_wrap