import contextlib
from typing import List, Optional, Tuple, Type

import torch

try:
    import gptq_kernels
except ImportError as e:
    # Report the underlying failure in the message and chain it, so the
    # original ImportError traceback is not lost.
    raise RuntimeError(
        f"Failed to import gptq_kernels with {e!r}. Please install "
        "gptq_kernels from csrc/quantization/gptq") from e

try:
    import vllm._C
except ImportError as e:
    # Best effort: a missing vllm._C only produces a warning at import time.
    from vllm.logger import init_logger
    logger = init_logger(__name__)
    logger.warning("Failed to import from vllm._C with %r", e)

with contextlib.suppress(ImportError):
    import vllm._moe_C

with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import vllm._punica_C


def is_custom_op_supported(op_name: str) -> bool:
    """Report whether ``torch._C._jit_get_operation`` resolves ``op_name``.

    Args:
        op_name: Name of the operator to look up.

    Returns:
        ``True`` when the lookup yields a non-``None`` operator.
    """
    # The lookup yields an (op, overloads) pair; only the op matters here.
    op, _ = torch._C._jit_get_operation(op_name)
    return op is not None


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call ``torch.ops._C.silu_and_mul(out, x)``; returns nothing.

    NOTE(review): the result is presumably written into ``out`` -- confirm
    against the kernel.
    """
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call ``torch.ops._C.gelu_and_mul(out, x)``; returns nothing."""
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call ``torch.ops._C.gelu_tanh_and_mul(out, x)``; returns nothing."""
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call ``torch.ops._C.gelu_fast(out, x)``; returns nothing."""
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    """Call ``torch.ops._C.gelu_new(out, x)``; returns nothing."""
    torch.ops._C.gelu_new(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Call ``torch.ops._C.paged_attention_v1``.

    Every argument is forwarded positionally, unchanged and in declaration
    order. The wrapper returns nothing.

    NOTE(review): the attention result is presumably written into ``out`` --
    confirm against the op definition.
    """
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Call ``torch.ops._C.paged_attention_v2``.

    Takes the same arguments as ``paged_attention_v1`` plus ``exp_sum``,
    ``max_logits`` and ``tmp_out``, which follow ``out``. Every argument is
    forwarded positionally, unchanged and in declaration order. The wrapper
    returns nothing.
    """
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Call ``torch.ops._C.rotary_embedding`` with the arguments as given.

    The wrapper returns nothing.

    NOTE(review): ``query`` and ``key`` are presumably updated in place --
    confirm against the kernel.
    """
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    """Call ``torch.ops._C.batched_rotary_embedding``.

    Same leading arguments as ``rotary_embedding`` followed by ``rot_dim``
    and ``cos_sin_cache_offsets``; all are forwarded unchanged. The wrapper
    returns nothing.
    """
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    """Call ``torch.ops._C.rms_norm(out, input, weight, epsilon)``.

    The wrapper returns nothing.
    """
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    """Call ``torch.ops._C.fused_add_rms_norm``.

    There is no separate output argument and the wrapper returns nothing.

    NOTE(review): ``input`` and ``residual`` are presumably modified in
    place -- confirm against the kernel.
    """
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    """Return the result of ``torch.ops._C.awq_dequantize``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    """Return the result of ``torch.ops._C.awq_gemm``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Return the result of ``gptq_kernels.gptq_gemm``.

    Dispatches to the standalone ``gptq_kernels`` extension imported at the
    top of this module, not to ``torch.ops._C``. All arguments are forwarded
    positionally, unchanged.
    """
    return gptq_kernels.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)
    # Alternative dispatch through the torch op registry, kept for reference:
    # return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
    #                               b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Call ``gptq_kernels.gptq_shuffle(q_weight, q_perm, bit)``.

    The wrapper returns nothing.
    """
    gptq_kernels.gptq_shuffle(q_weight, q_perm, bit)
    # Alternative dispatch through the torch op registry, kept for reference:
    # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, row: int,
                   col: int) -> None:
    """Call ``torch.ops._C.trans_w16_gemm(dst, src, row, col)``.

    The wrapper returns nothing.
    """
    torch.ops._C.trans_w16_gemm(dst, src, row, col)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    """Call ``torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)``.

    The wrapper returns nothing.
    """
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of ``torch.ops._C.marlin_gemm``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    """Return the result of ``torch.ops._C.gptq_marlin_24_gemm``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, num_bits, size_m,
                                            size_n, size_k)


# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
                         scale_a: torch.Tensor, scale_b: torch.Tensor,
                         out_dtype: torch.dtype) -> torch.Tensor:
    """Allocate an output tensor and run ``torch.ops._C.cutlass_scaled_mm_dq``.

    Args:
        a: Left operand; ``a.shape[0]`` gives the output row count.
        b: Right operand; both of its first two dimensions must be
            multiples of 16, and ``b.shape[1]`` gives the output column count.
        scale_a: Forwarded unchanged to the op.
        scale_b: Forwarded unchanged to the op.
        out_dtype: ``torch.bfloat16`` or ``torch.float16``.

    Returns:
        A new ``(a.shape[0], b.shape[1])`` tensor of ``out_dtype`` on
        ``a.device``, passed to the op as its first argument.

    Raises:
        AssertionError: If the shape or dtype preconditions do not hold.
    """
    # Asserts are kept (rather than raising ValueError) so callers see the
    # same exception type as before.
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)

    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Return the result of ``torch.ops._C.aqlm_gemm``.

    All arguments, including the optional ``bias``, are forwarded
    positionally, unchanged.
    """
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    """Return the result of ``torch.ops._C.aqlm_dequant``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Return the result of ``torch.ops._C.gptq_marlin_repack``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, g_idx: torch.Tensor,
                     perm: torch.Tensor, workspace: torch.Tensor,
                     num_bits: int, size_m: int, size_n: int, size_k: int,
                     is_k_full: bool) -> torch.Tensor:
    """Return the result of ``torch.ops._C.gptq_marlin_gemm``.

    All arguments are forwarded positionally, unchanged.
    """
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
                                         workspace, num_bits, size_m, size_n,
                                         size_k, is_k_full)


# fp8
# def scaled_fp8_quant(
#     input: torch.Tensor,
#     scale: Optional[torch.Tensor] = None,
#     batch_dim_padding: Optional[int] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """
#     Quantize input tensor to FP8 and return quantized tensor and scale.
#
#     This function supports both static and dynamic quantization: If you
#     provide the scale, it will use static scaling and if you omit it,
#     the scale will be determined dynamically. The function also allows
#     optional padding of the output tensor for downstream kernels that
#     will benefit from padding.
#     Args:
#         input: The input tensor to be quantized to FP8
#         scale: Optional scaling factor for the FP8 quantization
#         batch_dim_padding: If specified, pad the first dimension
#             of the output to at least this value.
#
#     Returns:
#         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
#             scaling factor.
#     """
#     if batch_dim_padding:
#         shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
#         output = torch.empty(shape,
#                              device=input.device,
#                              dtype=torch.float8_e4m3fn)
#     else:
#         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
#     if scale is None:
#         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
#         torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
#     else:
#         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
#     return output, scale


# int8
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
            With a caller-supplied ``scale`` that same tensor is returned;
            otherwise a new float32 tensor of shape
            ``(input.numel() // input.shape[-1], 1)`` is allocated on
            ``input.device`` and handed to the dynamic kernel.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization: one scale slot per row of the input
    # once its leading dimensions are flattened.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    """Call ``torch.ops._C.moe_align_block_size``.

    All arguments are forwarded positionally, unchanged. The wrapper returns
    nothing.

    NOTE(review): ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` look like output buffers -- confirm against the
    kernel.
    """
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    """Call ``torch.ops._moe_C.topk_softmax``.

    All arguments are forwarded positionally, unchanged. The wrapper returns
    nothing. ``gating_output`` is annotated as a tensor, like the other
    arguments (it was previously annotated ``float``).
    """
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    """Call ``torch.ops._C_cache_ops.reshape_and_cache``.

    All arguments are forwarded positionally, unchanged. The wrapper returns
    nothing.
    """
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, kv_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Call ``torch.ops._C_cache_ops.reshape_and_cache_flash``.

    Takes the same arguments as ``reshape_and_cache`` except ``kv_scale``.
    The wrapper returns nothing.
    """
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Call ``torch.ops._C_cache_ops.copy_blocks``; returns nothing."""
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Call ``torch.ops._C_cache_ops.swap_blocks``; returns nothing."""
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    """Call ``torch.ops._C_cache_ops.convert_fp8``.

    Args:
        output: Forwarded as the op's first argument.
        input: Forwarded as the op's second argument.
        scale: Forwarded unchanged; defaults to ``1.0``.
        kv_dtype: Forwarded unchanged; defaults to ``"fp8"``.
    """
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    """Return ``torch.ops._C_cuda_utils.get_device_attribute(attribute, device)``."""
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    """Return the value reported by the op of the same name for ``device``.

    The op is looked up under ``torch.ops._C_cuda_utils``.
    """
    # ruff: noqa: E501
    query = torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute
    return query(device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    """Return the result of ``torch.ops._C_custom_ar.init_custom_ar``.

    All arguments are forwarded positionally, unchanged.

    NOTE(review): the returned int is presumably the handle passed as ``fa``
    to the other custom all-reduce wrappers -- confirm.
    """
    init_op = torch.ops._C_custom_ar.init_custom_ar
    return init_op(meta, rank_data, handles, offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    """Return the result of ``torch.ops._C_custom_ar.should_custom_ar``.

    All arguments are forwarded positionally, unchanged.
    """
    check_op = torch.ops._C_custom_ar.should_custom_ar
    return check_op(inp, max_size, world_size, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    """Call ``torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)``.

    The wrapper returns nothing.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce_reg
    reduce_op(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    """Call ``torch.ops._C_custom_ar.all_reduce_unreg``.

    Same arguments as ``all_reduce_reg`` plus ``reg_buffer``, which is passed
    between ``inp`` and ``out``. The wrapper returns nothing.
    """
    reduce_op = torch.ops._C_custom_ar.all_reduce_unreg
    reduce_op(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    """Call ``torch.ops._C_custom_ar.dispose(fa)``; returns nothing."""
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)


def meta_size() -> int:
    """Return the value of ``torch.ops._C_custom_ar.meta_size()``."""
    size_op = torch.ops._C_custom_ar.meta_size
    return size_op()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    """Call ``torch.ops._C_custom_ar.register_buffer``.

    All arguments are forwarded positionally, unchanged. Whatever the op
    returns is passed through, although the annotation declares ``None``.
    """
    register_op = torch.ops._C_custom_ar.register_buffer
    return register_op(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    """Return ``torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)``."""
    meta_op = torch.ops._C_custom_ar.get_graph_buffer_ipc_meta
    return meta_op(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    """Call ``torch.ops._C_custom_ar.register_graph_buffers``.

    All arguments are forwarded positionally, unchanged. The wrapper returns
    nothing.
    """
    register_op = torch.ops._C_custom_ar.register_graph_buffers
    register_op(fa, handles, offsets)


# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    """Call ``torch.ops._punica_C.dispatch_bgmv``.

    All arguments are forwarded positionally, unchanged. The wrapper returns
    nothing.
    """
    bgmv_op = torch.ops._punica_C.dispatch_bgmv
    bgmv_op(y, x, w_t_all, indicies, layer_idx, scale)


def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    """Call ``torch.ops._punica_C.dispatch_bgmv_low_level``.

    Same leading arguments as ``dispatch_bgmv`` followed by ``h_in``,
    ``h_out`` and ``y_offset``; all are forwarded positionally, unchanged.
    The wrapper returns nothing.
    """
    bgmv_op = torch.ops._punica_C.dispatch_bgmv_low_level
    bgmv_op(y, x, w_t_all, indicies, layer_idx, scale, h_in, h_out, y_offset)