# _custom_ops.py — Python wrappers around vLLM's compiled custom ops.
import contextlib
2
import functools
3
from typing import List, Optional, Tuple, Union
4
5
6

import torch

7
8
9
10
from vllm.logger import init_logger

logger = init_logger(__name__)

11
# The compiled extension is imported only for its side effects: the wrappers
# below reach their kernels through `torch.ops._C`, presumably registered by
# this import. A failure is logged rather than raised, so this module still
# imports; the individual wrappers will fail when called.
try:
    import vllm._C
except ImportError as import_err:
    logger.warning("Failed to import from vllm._C with %r", import_err)

# The MoE extension (`torch.ops._moe_C`) is optional; a missing build is
# silently ignored here.
with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import vllm._moe_C


def is_custom_op_supported(op_name: str) -> bool:
    op, overloads = torch._C._jit_get_operation(op_name)
    return op is not None

25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def hint_on_error(fn):
    """Wrap ``fn`` so an AttributeError logs a rebuild hint before propagating.

    The wrappers in this module call compiled kernels through
    ``torch.ops.<namespace>.<op>``; an AttributeError escaping such a call is
    treated as a sign of a stale or missing build. The exception is always
    re-raised; all other exceptions pass through untouched.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare `raise` is the idiomatic way to re-raise the exception
            # currently being handled, keeping its original traceback.
            raise

    return wrapper


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``out`` and ``x`` to the compiled ``silu_and_mul`` op.

    Returns None; the op presumably writes its result into ``out``.
    """
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``out`` and ``x`` to the compiled ``gelu_and_mul`` op.

    Returns None; the op presumably writes its result into ``out``.
    """
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``out`` and ``x`` to the compiled ``gelu_tanh_and_mul`` op.

    Returns None; the op presumably writes its result into ``out``.
    """
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``out`` and ``x`` to the compiled ``gelu_fast`` op.

    Returns None; the op presumably writes its result into ``out``.
    """
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``out`` and ``x`` to the compiled ``gelu_new`` op.

    Returns None; the op presumably writes its result into ``out``.
    """
    kernel = torch.ops._C.gelu_new
    kernel(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    """Forward ``out`` and ``x`` to the compiled ``gelu_quick`` op.

    Returns None; the op presumably writes its result into ``out``.
    """
    kernel = torch.ops._C.gelu_quick
    kernel(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``paged_attention_v1`` op.

    Every argument is forwarded positionally, in signature order. ``tp_rank``
    and the ``blocksparse_*`` knobs have Python-side defaults; their meaning
    is defined by the kernel. Returns None; the result presumably lands in
    ``out``.
    """
    kernel_args = (
        out, query, key_cache, value_cache, num_kv_heads, scale,
        block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
        kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v1(*kernel_args)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``paged_attention_v2`` op.

    Same contract as ``paged_attention_v1`` plus three extra leading buffers
    (``exp_sum``, ``max_logits``, ``tmp_out``), presumably scratch space for
    the kernel. Arguments are forwarded positionally in signature order.
    """
    kernel_args = (
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size,
        max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale,
        tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v2(*kernel_args)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Forward to the compiled ``rotary_embedding`` op.

    Returns None; ``query``/``key`` are presumably updated in place.
    """
    ops = torch.ops._C
    ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                         is_neox)


def batched_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    rot_dim: int,
    cos_sin_cache_offsets: torch.Tensor,
) -> None:
    """Forward to the compiled ``batched_rotary_embedding`` op.

    Like ``rotary_embedding`` with two extra arguments, ``rot_dim`` and
    ``cos_sin_cache_offsets``. Returns None.
    """
    ops = torch.ops._C
    ops.batched_rotary_embedding(positions, query, key, head_size,
                                 cos_sin_cache, is_neox, rot_dim,
                                 cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Forward to the compiled ``rms_norm`` op.

    Returns None; the normalized result presumably lands in ``out``.
    """
    ops = torch.ops._C
    ops.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Forward to the compiled ``fused_add_rms_norm`` op.

    Returns None; ``input`` and ``residual`` are presumably updated in place.
    """
    ops = torch.ops._C
    ops.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance the inputs of a multi-step runner by one step on the GPU.

    Thin forwarder to the compiled ``advance_step`` op; whatever the op
    returns is passed back to the caller.
    """
    kernel = torch.ops._C.advance_step
    return kernel(num_seqs, num_queries, block_size, input_tokens,
                  sampled_token_ids, input_positions, seq_lens,
                  slot_mapping, block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``awq_dequantize`` op."""
    kernel = torch.ops._C.awq_dequantize
    return kernel(qweight, scales, zeros, split_k_iters, thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``awq_gemm`` op."""
    kernel = torch.ops._C.awq_gemm
    return kernel(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``gptq_gemm`` op."""
    kernel = torch.ops._C.gptq_gemm
    return kernel(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                  use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Forward to the compiled ``gptq_shuffle`` op.

    Returns None; ``q_weight`` is presumably rearranged in place.
    """
    kernel = torch.ops._C.gptq_shuffle
    kernel(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Forward to the compiled ``squeezellm_gemm`` op.

    Returns None; the product is presumably written into ``mul``.
    """
    kernel = torch.ops._C.squeezellm_gemm
    kernel(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``marlin_gemm`` op."""
    kernel = torch.ops._C.marlin_gemm
    return kernel(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``gptq_marlin_24_gemm`` op."""
    kernel = torch.ops._C.gptq_marlin_24_gemm
    return kernel(a, b_q_weight, b_meta, b_scales, workspace, num_bits,
                  size_m, size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Ask the compiled extension whether CUTLASS scaled-mm supports FP8.

    The answer comes from ``torch.ops._C.cutlass_scaled_mm_supports_fp8``
    for the given ``cuda_device_capability``.
    """
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the compiled CUTLASS scaled matmul and return a fresh output tensor.

    Allocates ``out`` with shape ``(a.shape[0], b.shape[1])``, dtype
    ``out_dtype`` and ``a``'s device, then calls
    ``torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)``.

    Args:
        a: Left operand; ``a.shape[0]`` gives the output row count m.
        b: Right operand; both of its dims must be multiples of 16 and
            ``b.shape[1]`` gives the output column count n.
        scale_a: Scale forwarded to the kernel for ``a``.
        scale_b: Scale forwarded to the kernel for ``b``.
        out_dtype: Must be ``torch.bfloat16`` or ``torch.float16``.
        bias: Optional bias with exactly n elements and dtype ``out_dtype``
            (presumably one value per output column — confirm vs. kernel).

    Returns:
        The newly allocated ``(m, n)`` output tensor.

    Raises:
        AssertionError: if any of the constraints above is violated.
    """
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

    m = a.shape[0]
    n = b.shape[1]
    # Validate the bias up front so a mismatch fails fast with a clear
    # Python-side error, before the output buffer is allocated, rather than
    # being left to the compiled kernel.
    assert bias is None or (bias.numel() == n and bias.dtype == out_dtype), (
        f"bias must have {n} elements and dtype {out_dtype}")

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the tensor produced by the compiled ``aqlm_gemm`` op.

    ``bias`` may be None and is forwarded unchanged.
    """
    kernel = torch.ops._C.aqlm_gemm
    return kernel(input, codes, codebooks, scales, codebook_partition_sizes,
                  bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    """Return the tensor produced by the compiled ``aqlm_dequant`` op."""
    kernel = torch.ops._C.aqlm_dequant
    return kernel(codes, codebooks, codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``gptq_marlin_repack`` op."""
    kernel = torch.ops._C.gptq_marlin_repack
    return kernel(b_q_weight, perm, size_k, size_n, num_bits)


# awq_marlin
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``awq_marlin_repack`` op.

    Unlike ``gptq_marlin_repack`` it takes no ``perm`` argument.
    """
    kernel = torch.ops._C.awq_marlin_repack
    return kernel(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    has_zp: bool,
    use_fp32_reduce: bool,
) -> torch.Tensor:
    """Return the tensor produced by the compiled ``gptq_marlin_gemm`` op.

    All arguments are forwarded positionally in signature order.
    """
    kernel_args = (
        a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, num_bits,
        size_m, size_n, size_k, is_k_full, has_zp, use_fp32_reduce,
    )
    return torch.ops._C.gptq_marlin_gemm(*kernel_args)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``fp8_marlin_gemm`` op."""
    kernel = torch.ops._C.fp8_marlin_gemm
    return kernel(a, b_q_weight, b_scales, workspace, num_bits, size_m,
                  size_n, size_k)


# fp8
304
305
306
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
307
    num_token_padding: Optional[int] = None,
308
    scale_ub: Optional[torch.Tensor] = None,
309
    use_per_token_if_dynamic: bool = False,
310
) -> Tuple[torch.Tensor, torch.Tensor]:
311
312
313
314
315
316
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
317
    optional padding of the output tensors for downstream kernels that
318
319
320
321
322
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
323
324
        scale_ub: Optional upper bound for scaling factor in dynamic 
            per token case
325
        num_token_padding: If specified, pad the first dimension
326
            of the output to at least this value.
327
328
        use_per_token_if_dynamic: Whether to do per_tensor or per_token 
            in the dynamic quantization case.
329
330
331
332
333

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
334
335
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
336
    shape: Union[Tuple[int, int], torch.Size] = input.shape
337
338
339
340
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)

341
    if scale is None:
342
        if use_per_token_if_dynamic:
343
            scale = torch.empty((shape[0], 1),
344
345
346
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
347
                output, input, scale, scale_ub)
348
349
350
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
351
    else:
352
353
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
354
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
355

356
    return output, scale
357
358


359
# int8
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize ``input`` to int8 and return the quantized tensor and scale(s).

    Args:
        input: The tensor to quantize to int8.
        scale: Optional scaling factor for static per-tensor quantization.
            When omitted, dynamic per-token quantization is used instead.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The int8 output tensor (same shape
            as ``input``) and the scales that were used.
    """
    quantized = torch.empty_like(input, dtype=torch.int8)

    if scale is None:
        # Dynamic per-token: one float32 scale for every row of the input
        # once all leading dimensions are flattened.
        num_rows = input.numel() // input.shape[-1]
        token_scales = torch.empty((num_rows, 1),
                                   device=input.device,
                                   dtype=torch.float32)
        torch.ops._C.dynamic_scaled_int8_quant(quantized, input, token_scales)
        return quantized, token_scales

    # Static per-tensor: the caller's scale is used and returned unchanged.
    torch.ops._C.static_scaled_int8_quant(quantized, input, scale)
    return quantized, scale


# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Return the tensor produced by the compiled ``marlin_qqq_gemm`` op."""
    kernel = torch.ops._C.marlin_qqq_gemm
    return kernel(a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                  size_n, size_k)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Forward to the compiled ``moe_align_block_size`` op in ``torch.ops._C``.

    Returns None; ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` are presumably filled in by the kernel.
    """
    kernel = torch.ops._C.moe_align_block_size
    kernel(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids,
           num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Forward to the compiled ``topk_softmax`` op in ``torch.ops._moe_C``.

    ``torch.ops._moe_C`` comes from the optional ``vllm._moe_C`` extension.
    Returns None; the first three tensors are presumably output buffers
    filled by the kernel.

    Args:
        topk_weights: Tensor forwarded to the kernel.
        topk_ids: Tensor forwarded to the kernel.
        token_expert_indicies: Tensor forwarded to the kernel. The misspelled
            name is kept so keyword callers keep working.
        gating_output: Tensor passed straight to the kernel, which is why it
            is annotated as ``torch.Tensor`` rather than a float
            (presumably the router logits — confirm).
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Forward to ``torch.ops._C_cache_ops.reshape_and_cache``.

    Returns None; ``key``/``value`` are presumably written into the caches
    at the slots given by ``slot_mapping``.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale,
                                v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Forward to ``torch.ops._C_cache_ops.reshape_and_cache_flash``.

    Same argument list as ``reshape_and_cache``; only the target op
    differs. Returns None.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, kv_cache_dtype, k_scale,
                                      v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Forward to ``torch.ops._C_cache_ops.copy_blocks``; returns None."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Forward to ``torch.ops._C_cache_ops.swap_blocks``; returns None."""
    cache_ops = torch.ops._C_cache_ops
    cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Forward to ``torch.ops._C_cache_ops.convert_fp8``.

    ``scale`` defaults to 1.0 and ``kv_dtype`` to ``"fp8"``. Returns None;
    the converted data presumably lands in ``output``.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``attribute`` for ``device`` as reported by the CUDA-utils op."""
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the value of the compiled op of the same name for ``device``."""
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Call ``torch.ops._C_custom_ar.init_custom_ar`` and return its int result.

    NOTE(review): the result is presumably the ``fa`` handle taken by the
    other custom all-reduce wrappers — confirm.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank,
                                    full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return the compiled op's verdict on using custom all-reduce for ``inp``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.should_custom_ar(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Forward to ``torch.ops._C_custom_ar.all_reduce_reg``; returns None."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Forward to ``torch.ops._C_custom_ar.all_reduce_unreg``; returns None.

    Differs from ``all_reduce_reg`` by the extra ``reg_buffer`` argument.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    """Forward ``fa`` to ``torch.ops._C_custom_ar.dispose``; returns None."""
    custom_ar = torch.ops._C_custom_ar
    custom_ar.dispose(fa)


def meta_size() -> int:
    """Return the integer reported by ``torch.ops._C_custom_ar.meta_size``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Forward to ``torch.ops._C_custom_ar.register_buffer``.

    The op's return value is passed back unchanged.
    """
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return the (handles, offsets) pair from the compiled op for ``fa``."""
    custom_ar = torch.ops._C_custom_ar
    return custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Forward to ``torch.ops._C_custom_ar.register_graph_buffers``.

    Returns None.
    """
    custom_ar = torch.ops._C_custom_ar
    custom_ar.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
# Wrap every Tensor-annotated function defined in this file with
# `hint_on_error`, so an AttributeError from a missing compiled op is logged
# with a rebuild hint before it propagates.
names_and_values = globals()
names_and_values_to_update = {}
# Pre-bind the loop variables: assigning a *new* module global inside the
# loop would change the size of the dict being iterated. `arg` below lives in
# the generator expression's own scope, so it needs no pre-binding (binding
# it here would leave a stray `arg` global, since it is not deleted below).
k, v = None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to turn
    # type hints into strings.
    if isinstance(v, fn_type) \
        and v.__code__.co_filename == __file__ \
        and any(arg is torch.Tensor or arg == "torch.Tensor"
                for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type