_custom_ops.py 17.1 KB
Newer Older
1
import contextlib
2
import functools
3
from typing import List, Optional, Tuple, Type
4
5
6

import torch

7
8
9
10
from vllm.logger import init_logger

logger = init_logger(__name__)

# Compiled extensions are imported for their side effect of registering
# torch.ops namespaces. A failure to import the core extension is logged;
# the MoE and Punica extensions are skipped silently on ImportError.
try:
    import vllm._C
except ImportError as import_err:
    logger.warning("Failed to import from vllm._C with %r", import_err)

try:
    import vllm._moe_C
except ImportError:
    pass

try:
    # NOTE: ruff treats the directive below as file-wide, not line-local.
    # ruff: noqa: F401
    import vllm._punica_C
except ImportError:
    pass


def is_custom_op_supported(op_name: str) -> bool:
    """Return whether ``op_name`` resolves to a registered torch operator.

    Args:
        op_name: Qualified operator name such as ``"_C::rms_norm"``.

    Returns:
        True when ``torch._C._jit_get_operation`` finds an operator for the
        name, False otherwise -- including when the name cannot be parsed.
    """
    try:
        op, _ = torch._C._jit_get_operation(op_name)
    except RuntimeError:
        # For names it cannot parse (e.g. missing the "namespace::" part)
        # the lookup raises instead of returning None; a predicate should
        # answer "not supported" rather than propagate the error.
        return False
    return op is not None

28

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def hint_on_error(fn):
    """Decorate ``fn`` so an ``AttributeError`` it raises is logged with a hint.

    The hint suggests that the installed/built vllm extension may be stale.
    The original exception is re-raised unchanged after logging; any other
    exception type propagates untouched.

    Args:
        fn: The callable to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper around ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = ("Error in calling custom op %s: %s\n"
                   "Possibly you have built or installed an obsolete "
                   "version of vllm.\n"
                   "Please try a clean build and install of vllm, "
                   "or remove old built files such as vllm/*cpython*.so "
                   "and build/ .")
            logger.error(msg, fn.__name__, e)
            # Bare ``raise`` re-raises the active exception without adding
            # an extra traceback entry for this line (``raise e`` would).
            raise

    return wrapper


48
49
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled ``_C.silu_and_mul`` op with ``(out, x)``.

    Returns None; ``out`` is presumably filled in place (assumed from the
    naming -- confirm against the kernel).
    """
    kernel = torch.ops._C.silu_and_mul
    kernel(out, x)
51
52
53


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled ``_C.gelu_and_mul`` op with ``(out, x)``.

    Returns None; ``out`` is presumably filled in place.
    """
    kernel = torch.ops._C.gelu_and_mul
    kernel(out, x)
55
56
57


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled ``_C.gelu_tanh_and_mul`` op with ``(out, x)``.

    Returns None; ``out`` is presumably filled in place.
    """
    kernel = torch.ops._C.gelu_tanh_and_mul
    kernel(out, x)
59
60
61


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled ``_C.gelu_fast`` op with ``(out, x)``.

    Returns None; ``out`` is presumably filled in place.
    """
    kernel = torch.ops._C.gelu_fast
    kernel(out, x)
63
64
65


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call the compiled ``_C.gelu_new`` op with ``(out, x)``.

    Returns None; ``out`` is presumably filled in place.
    """
    kernel = torch.ops._C.gelu_new
    kernel(out, x)
67
68
69
70
71
72
73
74
75
76
77


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``_C.paged_attention_v1`` op.

    Every argument is forwarded positionally in declaration order and
    nothing is returned. The result presumably lands in ``out`` (assumed
    from the naming -- confirm against the kernel). The trailing knobs
    default to ``tp_rank=0``, ``blocksparse_local_blocks=0``,
    ``blocksparse_vert_stride=0``, ``blocksparse_block_size=64`` and
    ``blocksparse_head_sliding_step=0``.
    """
    attention_op = torch.ops._C.paged_attention_v1
    attention_op(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        kv_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
95
96
97
98
99
100
101
102
103
104
105
106
107


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``_C.paged_attention_v2`` op.

    Compared with the v1 wrapper, this op additionally takes ``exp_sum``,
    ``max_logits`` and ``tmp_out`` right after ``out``; presumably these
    are scratch buffers (unverified). All arguments are forwarded
    positionally in declaration order and nothing is returned. The trailing
    knobs default to ``tp_rank=0``, ``blocksparse_local_blocks=0``,
    ``blocksparse_vert_stride=0``, ``blocksparse_block_size=64`` and
    ``blocksparse_head_sliding_step=0``.
    """
    attention_op = torch.ops._C.paged_attention_v2
    attention_op(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        kv_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
126
127
128
129
130
131
132
133
134
135
136


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Call the compiled ``_C.rotary_embedding`` op.

    Arguments are forwarded positionally and nothing is returned; the op
    presumably rotates ``query``/``key`` in place (unverified here).
    """
    rope_op = torch.ops._C.rotary_embedding
    rope_op(positions, query, key, head_size, cos_sin_cache, is_neox)
139
140
141
142
143
144
145


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Call the compiled ``_C.batched_rotary_embedding`` op.

    Same inputs as ``rotary_embedding`` plus ``rot_dim`` and
    ``cos_sin_cache_offsets``, all forwarded positionally; returns None.
    """
    rope_op = torch.ops._C.batched_rotary_embedding
    rope_op(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox,
        rot_dim,
        cos_sin_cache_offsets,
    )
149
150
151
152
153


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call the compiled ``_C.rms_norm`` op; returns None.

    ``out`` is presumably written in place (assumed from the naming).
    """
    norm_op = torch.ops._C.rms_norm
    norm_op(out, input, weight, epsilon)
155
156
157
158


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call the compiled ``_C.fused_add_rms_norm`` op; returns None.

    There is no separate output tensor, so ``input``/``residual`` are
    presumably updated in place (confirm against the kernel).
    """
    norm_op = torch.ops._C.fused_add_rms_norm
    norm_op(input, residual, weight, epsilon)
160
161
162
163
164
165
166


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.awq_dequantize`` op."""
    dequant = torch.ops._C.awq_dequantize
    result = dequant(qweight, scales, zeros, split_k_iters, thx, thy)
    return result
169
170
171
172


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.awq_gemm`` op."""
    gemm = torch.ops._C.awq_gemm
    result = gemm(input, qweight, qzeros, scales, split_k_iters)
    return result
174
175
176
177
178
179
180


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_gemm`` op."""
    gemm = torch.ops._C.gptq_gemm
    result = gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                  use_exllama, bit)
    return result
183
184
185
186


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call the compiled ``_C.gptq_shuffle`` op; returns None.

    ``q_weight`` is presumably reordered in place (unverified).
    """
    shuffle = torch.ops._C.gptq_shuffle
    shuffle(q_weight, q_perm, bit)
188
189
190
191
192


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Call the compiled ``_C.squeezellm_gemm`` op; returns None.

    No value is returned, so the product presumably lands in ``mul``
    (assumption -- confirm against the kernel).
    """
    gemm = torch.ops._C.squeezellm_gemm
    gemm(vec, mat, mul, lookup_table)
194
195
196
197
198
199


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.marlin_gemm`` op."""
    gemm = torch.ops._C.marlin_gemm
    result = gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)
    return result
202
203


204
205
206
207
208
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_marlin_24_gemm`` op."""
    gemm = torch.ops._C.gptq_marlin_24_gemm
    result = gemm(
        a,
        b_q_weight,
        b_meta,
        b_scales,
        workspace,
        num_bits,
        size_m,
        size_n,
        size_k,
    )
    return result
212
213


214
# cutlass
215
216
217
def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype) -> torch.Tensor:
    """Run the compiled ``_C.cutlass_scaled_mm`` op into a new output tensor.

    Args:
        a: Left operand; ``a.shape[0]`` is the number of output rows.
        b: Right operand; both dimensions must be multiples of 16 and
            ``b.shape[1]`` is the number of output columns.
        scale_a: Forwarded unchanged to the kernel (expected shape is
            defined by the kernel and not checked here).
        scale_b: Forwarded unchanged to the kernel (see ``scale_a``).
        out_dtype: A dtype instance, either ``torch.bfloat16`` or
            ``torch.float16``.

    Returns:
        A tensor of shape ``(a.shape[0], b.shape[1])`` with dtype
        ``out_dtype`` on ``a.device``, filled by the kernel.

    Raises:
        AssertionError: If ``b``'s dimensions are not multiples of 16 or
            ``out_dtype`` is unsupported.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0, (
        "cutlass_scaled_mm: both dimensions of b must be multiples of 16, "
        f"got {tuple(b.shape)}")
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16, (
        "cutlass_scaled_mm: out_dtype must be torch.bfloat16 or "
        f"torch.float16, got {out_dtype}")

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)

    return out


229
230
231
232
233
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the result of the compiled ``_C.aqlm_gemm`` op.

    ``bias`` may be None and is forwarded as-is.
    """
    gemm = torch.ops._C.aqlm_gemm
    result = gemm(input, codes, codebooks, scales, codebook_partition_sizes,
                  bias)
    return result
236
237
238
239


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    """Return the result of the compiled ``_C.aqlm_dequant`` op."""
    dequant = torch.ops._C.aqlm_dequant
    result = dequant(codes, codebooks, codebook_partition_sizes)
    return result
242
243


244
245
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_marlin_repack`` op."""
    repack = torch.ops._C.gptq_marlin_repack
    result = repack(b_q_weight, perm, size_k, size_n, num_bits)
    return result
250
251
252
253


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, g_idx: torch.Tensor,
                     perm: torch.Tensor, workspace: torch.Tensor,
                     num_bits: int, size_m: int, size_n: int, size_k: int,
                     is_k_full: bool) -> torch.Tensor:
    """Return the result of the compiled ``_C.gptq_marlin_gemm`` op."""
    gemm = torch.ops._C.gptq_marlin_gemm
    result = gemm(
        a,
        b_q_weight,
        b_scales,
        g_idx,
        perm,
        workspace,
        num_bits,
        size_m,
        size_n,
        size_k,
        is_k_full,
    )
    return result
260
261


262
# fp8
263
264
265
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to ``torch.float8_e4m3fn``.

    Static scaling is used when ``scale`` is given. When it is omitted, a
    zero-initialised one-element float32 tensor is allocated and handed to
    the dynamic kernel, which presumably fills in the computed scale.

    Args:
        input: Tensor to quantize.
        scale: Optional precomputed scale; selects static quantization.
        batch_dim_padding: When truthy, the output's first dimension is
            ``max(batch_dim_padding, input.shape[0])``. Other dimensions
            match ``input``.

    Returns:
        ``(output, scale)``: the FP8 tensor and the scale that was used.
    """
    fp8 = torch.float8_e4m3fn
    if not batch_dim_padding:
        output = torch.empty_like(input, dtype=fp8)
    else:
        padded_rows = max(batch_dim_padding, input.shape[0])
        output = torch.empty((padded_rows, *input.shape[1:]),
                             device=input.device,
                             dtype=fp8)

    if scale is not None:
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
        return output, scale

    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    return output, scale


302
# int8
303
304
305
306
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to int8.

    Args:
        input: Tensor to quantize.
        scale: Optional scale. When given, static per-tensor quantization
            is used and the same ``scale`` is returned. When omitted,
            dynamic per-token quantization allocates a float32 scale tensor
            of shape ``(input.numel() // input.shape[-1], 1)`` for the
            kernel to fill.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the int8 output and the scales.
    """
    quantized = torch.empty_like(input, dtype=torch.int8)

    if scale is None:
        # dynamic-per-token quantization: one scale per row of the input
        # viewed as (-1, input.shape[-1]).
        num_rows = input.numel() // input.shape[-1]
        row_scales = torch.empty((num_rows, 1),
                                 device=input.device,
                                 dtype=torch.float32)
        torch.ops._C.dynamic_scaled_int8_quant(quantized, input, row_scales)
        return quantized, row_scales

    # static-per-tensor quantization.
    torch.ops._C.static_scaled_int8_quant(quantized, input, scale)
    return quantized, scale
330
331


332
333
334
335
336
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Call the compiled ``_C.moe_align_block_size`` op; returns None.

    ``sorted_token_ids``, ``experts_ids`` and ``num_tokens_post_pad`` are
    presumably output buffers the op fills (assumed from the naming).
    """
    align = torch.ops._C.moe_align_block_size
    align(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Call the compiled MoE op ``_moe_C.topk_softmax``; returns None.

    Args:
        topk_weights: Tensor forwarded to the op (presumably an output).
        topk_ids: Tensor forwarded to the op (presumably an output).
        token_expert_indicies: Tensor forwarded to the op.
        gating_output: The gating/router tensor. It is a tensor, not a
            ``float``. NOTE(review): the C++ binding is believed to take a
            Tensor here; confirm against the op schema.
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)
347
348
349
350
351
352
353
354
355
356
357


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    """Call the compiled ``_C_cache_ops.reshape_and_cache`` op.

    Returns None. ``key``/``value`` are presumably written into the caches
    at the positions given by ``slot_mapping`` (assumed from the naming).
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache
    cache_op(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        kv_scale,
    )
361
362


363
364
365
366
367
368
369
370
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Call the compiled ``_C_cache_ops.reshape_and_cache_flash`` op.

    Like ``reshape_and_cache`` but without a ``kv_scale`` argument; returns
    None.
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache_flash
    cache_op(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype)
374
375


376
377
def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    """Call the compiled ``_C_cache_ops.copy_blocks`` op; returns None."""
    copy_op = torch.ops._C_cache_ops.copy_blocks
    copy_op(key_caches, value_caches, block_mapping)
380
381
382


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Call the compiled ``_C_cache_ops.swap_blocks`` op; returns None."""
    swap_op = torch.ops._C_cache_ops.swap_blocks
    swap_op(src, dst, block_mapping)
385
386


387
388
389
390
def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Call the compiled ``_C_cache_ops.convert_fp8`` op; returns None.

    ``scale`` defaults to ``1.0`` and ``kv_dtype`` to ``"fp8"``. ``output``
    is presumably written in place (assumed from the naming).
    """
    convert = torch.ops._C_cache_ops.convert_fp8
    convert(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``_C_cuda_utils.get_device_attribute(attribute, device)``.

    Both arguments are plain ints; presumably an attribute id and a device
    index (unverified here).
    """
    query_attribute = torch.ops._C_cuda_utils.get_device_attribute
    return query_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the op's max-shared-memory-per-block value for ``device``."""
    # NOTE: ruff applies the directive below to the whole file.
    # ruff: noqa: E501
    cuda_utils = torch.ops._C_cuda_utils
    return cuda_utils.get_max_shared_memory_per_block_device_attribute(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Return the int produced by ``_C_custom_ar.init_custom_ar``.

    The returned int is what the other custom-allreduce wrappers take as
    ``fa`` (presumably an opaque handle -- unverified).
    """
    init_op = torch.ops._C_custom_ar.init_custom_ar
    return init_op(meta, rank_data, handles, offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return ``_C_custom_ar.should_custom_ar`` for the given arguments."""
    decide = torch.ops._C_custom_ar.should_custom_ar
    return decide(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Call ``_C_custom_ar.all_reduce_reg(fa, inp, out)``; returns None."""
    reduce_op = torch.ops._C_custom_ar.all_reduce_reg
    reduce_op(fa, inp, out)

421

422
423
424
def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Call ``_C_custom_ar.all_reduce_unreg``; returns None.

    Unlike ``all_reduce_reg`` it also takes ``reg_buffer``.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce_unreg
    reduce_op(fa, inp, reg_buffer, out)
425

426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483

def dispose(fa: int) -> None:
    """Call ``_C_custom_ar.dispose(fa)``; returns None."""
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)


def meta_size() -> int:
    """Return the value reported by ``_C_custom_ar.meta_size()``."""
    size_op = torch.ops._C_custom_ar.meta_size
    return size_op()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Call ``_C_custom_ar.register_buffer`` and pass through its result."""
    register_op = torch.ops._C_custom_ar.register_buffer
    return register_op(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return ``_C_custom_ar.get_graph_buffer_ipc_meta(fa)``."""
    meta_op = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    return meta_op(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Call ``_C_custom_ar.register_graph_buffers``; returns None."""
    register_op = torch.ops._C_custom_ar.register_graph_buffers
    register_op(fa, handles, offsets)


# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    """Call the Punica op ``_punica_C.dispatch_bgmv``; returns None.

    ``y`` is presumably the output written by the op (assumed from the
    naming).
    """
    bgmv = torch.ops._punica_C.dispatch_bgmv
    bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    """Call the Punica op ``_punica_C.dispatch_bgmv_low_level``.

    Same leading arguments as ``dispatch_bgmv`` plus ``h_in``, ``h_out``
    and ``y_offset``; returns None.
    """
    bgmv = torch.ops._punica_C.dispatch_bgmv_low_level
    bgmv(y, x, w_t_all, indicies, layer_idx, scale, h_in, h_out, y_offset)
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Wrap every function defined in this file that has torch.Tensor among its
# annotations with `hint_on_error`. The `arg == "torch.Tensor"` comparison
# handles `from __future__ import annotations`, which turns type hints into
# strings.
#
# The scan iterates over a snapshot (`list(...)`) of the module namespace,
# so it can never trip "dictionary changed size during iteration" and no
# loop variables have to be pre-declared as (and then leak into) globals.
names_and_values = globals()
fn_type = type(lambda x: x)
names_and_values_to_update = {
    k: hint_on_error(v)
    for k, v in list(names_and_values.items())
    if isinstance(v, fn_type) and v.__code__.co_filename == __file__
    and any(arg is torch.Tensor or arg == "torch.Tensor"
            for arg in v.__annotations__.values())
}

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, fn_type