Unverified Commit 5467ac31 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

parent 5d7e3d01
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
......
#include <torch/all.h> #include <torch/all.h>
#include <torch/python.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
#pragma once
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "registration.h"
#include <torch/library.h>
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
ops.def("aqlm_gemm", &aqlm_gemm);
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
// Decompression method for AQLM.
ops.def("aqlm_dequant", &aqlm_dequant);
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
// Quantized GEMM for AWQ.
ops.def("awq_gemm", &awq_gemm);
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize);
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops.def("marlin_gemm", &marlin_gemm);
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
// gptq_marlin repack from GPTQ.
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
#endif
// Quantized GEMM for GPTQ.
ops.def("gptq_gemm", &gptq_gemm);
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ.
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Quantized GEMM for SqueezeLLM.
ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
"lookup_table) -> ()");
ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops.def(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
"kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
// Gets the specified device attribute.
cuda_utils.def("get_device_attribute", &get_device_attribute);
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
// Gets the maximum shared memory per block device attribute.
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
torch::kCUDA,
&get_max_shared_memory_per_block_device_attribute);
}
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def("init_custom_ar", &init_custom_ar);
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def("should_custom_ar", &should_custom_ar);
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
custom_ar.def(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()");
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
custom_ar.def("dispose", &dispose);
custom_ar.impl("dispose", torch::kCPU, &dispose);
custom_ar.def("meta_size", &meta_size);
custom_ar.impl("meta_size", torch::kCPU, &meta_size);
custom_ar.def("register_buffer", &register_buffer);
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
&get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.impl("register_graph_buffers", torch::kCPU,
&register_graph_buffers);
}
#endif
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
...@@ -60,7 +60,7 @@ def remove_prefix(text, prefix): ...@@ -60,7 +60,7 @@ def remove_prefix(text, prefix):
class CMakeExtension(Extension): class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
super().__init__(name, sources=[], **kwa) super().__init__(name, sources=[], py_limited_api=True, **kwa)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
......
import pytest import pytest
import torch import torch
from vllm._C import ops # ruff: noqa: F401
import vllm._C
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
...@@ -33,7 +34,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, ...@@ -33,7 +34,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
ops.dynamic_scaled_int8_quant(ops_out, x, scales_out) torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
assert torch.allclose(scales_out, scales) assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out, assert torch.allclose(torch_out, ops_out,
...@@ -60,6 +61,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, ...@@ -60,6 +61,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
out2 = torch.empty_like(x, dtype=torch.int8) out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
ops.static_scaled_int8_quant(out2, x, scale_argument) torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2, assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors atol=1) # big atol to account for rounding errors
from typing import Optional, Tuple, Type import contextlib
from typing import List, Optional, Tuple, Type
import torch import torch
try: try:
from vllm._C import cache_ops as vllm_cache_ops import vllm._C
from vllm._C import ops as vllm_ops
except ImportError as e: except ImportError as e:
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError):
import vllm._moe_C
with contextlib.suppress(ImportError):
# ruff: noqa: F401
import vllm._punica_C
def is_custom_op_supported(op_name: str) -> bool:
op, overloads = torch._C._jit_get_operation(op_name)
return op is not None
# activation ops # activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x) torch.ops._C.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x) torch.ops._C.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x) torch.ops._C.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x) torch.ops._C.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x) torch.ops._C.gelu_new(out, x)
# page attention ops # page attention ops
...@@ -53,7 +65,7 @@ def paged_attention_v1( ...@@ -53,7 +65,7 @@ def paged_attention_v1(
blocksparse_block_size: int = 64, blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0, blocksparse_head_sliding_step: int = 0,
) -> None: ) -> None:
vllm_ops.paged_attention_v1( torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
...@@ -83,7 +95,7 @@ def paged_attention_v2( ...@@ -83,7 +95,7 @@ def paged_attention_v2(
blocksparse_block_size: int = 64, blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0, blocksparse_head_sliding_step: int = 0,
) -> None: ) -> None:
vllm_ops.paged_attention_v2( torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
...@@ -100,8 +112,8 @@ def rotary_embedding( ...@@ -100,8 +112,8 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox: bool, is_neox: bool,
) -> None: ) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, torch.ops._C.rotary_embedding(positions, query, key, head_size,
is_neox) cos_sin_cache, is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
...@@ -109,7 +121,7 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, ...@@ -109,7 +121,7 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache: torch.Tensor, is_neox: bool, cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int, rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None: cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size, torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim, cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets) cos_sin_cache_offsets)
...@@ -117,12 +129,12 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, ...@@ -117,12 +129,12 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
# layer norm ops # layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None: epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon) torch.ops._C.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops # quantization ops
...@@ -130,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, ...@@ -130,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
thy) thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# gptq # gptq
...@@ -144,26 +156,26 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -144,26 +156,26 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit) torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm # squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None: lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin # marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor: size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k) size_n, size_k)
...@@ -172,9 +184,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -172,9 +184,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int, workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int) -> torch.Tensor: size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, num_bits, size_m, size_n, workspace, num_bits, size_m,
size_k) size_n, size_k)
# cutlass # cutlass
...@@ -188,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, ...@@ -188,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1] n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device) out = torch.empty((m, n), dtype=out_dtype, device=a.device)
vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
return out return out
...@@ -198,20 +210,21 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, ...@@ -198,20 +210,21 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: torch.Tensor, codebook_partition_sizes: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias) codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: torch.Tensor) -> torch.Tensor: codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) return torch.ops._C.aqlm_dequant(codes, codebooks,
codebook_partition_sizes)
# gptq_marlin # gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor: num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits) num_bits)
...@@ -220,7 +233,7 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -220,7 +233,7 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int, num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor: is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, num_bits, size_m, size_n, workspace, num_bits, size_m, size_n,
size_k, is_k_full) size_k, is_k_full)
...@@ -259,9 +272,9 @@ def scaled_fp8_quant( ...@@ -259,9 +272,9 @@ def scaled_fp8_quant(
output = torch.empty_like(input, dtype=torch.float8_e4m3fn) output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
if scale is None: if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32) scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else: else:
vllm_ops.static_scaled_fp8_quant(output, input, scale) torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale return output, scale
...@@ -284,14 +297,14 @@ def scaled_int8_quant( ...@@ -284,14 +297,14 @@ def scaled_int8_quant(
output = torch.empty_like(input, dtype=torch.int8) output = torch.empty_like(input, dtype=torch.int8)
if scale is not None: if scale is not None:
# static-per-tensor quantization. # static-per-tensor quantization.
vllm_ops.static_scaled_int8_quant(output, input, scale) torch.ops._C.static_scaled_int8_quant(output, input, scale)
return output, scale return output, scale
# dynamic-per-token quantization. # dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1), input_scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device, device=input.device,
dtype=torch.float32) dtype=torch.float32)
vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales) torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
return output, input_scales return output, input_scales
...@@ -300,11 +313,18 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, ...@@ -300,11 +313,18 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor, experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None: num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids, sorted_token_ids, experts_ids,
num_tokens_post_pad) num_tokens_post_pad)
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies: torch.Tensor,
gating_output: float) -> None:
torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
token_expert_indicies, gating_output)
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
...@@ -314,8 +334,9 @@ def reshape_and_cache( ...@@ -314,8 +334,9 @@ def reshape_and_cache(
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, kv_scale: float,
) -> None: ) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
slot_mapping, kv_cache_dtype, kv_scale) value_cache, slot_mapping,
kv_cache_dtype, kv_scale)
def reshape_and_cache_flash( def reshape_and_cache_flash(
...@@ -326,25 +347,115 @@ def reshape_and_cache_flash( ...@@ -326,25 +347,115 @@ def reshape_and_cache_flash(
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
) -> None: ) -> None:
vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
slot_mapping, kv_cache_dtype) value_cache, slot_mapping,
kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None: block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None: block_mapping: torch.Tensor) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping) torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor, def convert_fp8(output: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
scale: float = 1.0, scale: float = 1.0,
kv_dtype: str = "fp8") -> None: kv_dtype: str = "fp8") -> None:
vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# ruff: noqa: E501
return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
device)
# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
handles: List[str], offsets: List[int], rank: int,
full_nvlink: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
offsets, rank, full_nvlink)
def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
full_nvlink: bool) -> bool:
return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
full_nvlink)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
#TODO: cuda_utils, custom_ar def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
offsets: List[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: List[str],
offsets: List[List[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
# punica
def dispatch_bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
) -> None:
torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
scale)
def dispatch_bgmv_low_level(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.Tensor,
layer_idx: int,
scale: float,
h_in: int,
h_out: int,
y_offset: int,
) -> None:
torch.ops._punica_C.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
h_in,
h_out,
y_offset,
)
...@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm._C import cache_ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata)
...@@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -47,11 +47,11 @@ class FlashAttentionBackend(AttentionBackend):
) -> None: ) -> None:
src_key_cache = src_kv_cache[0] src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0] dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1] src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1] dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
...@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass @dataclass
...@@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -285,7 +285,7 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory profiling run.
cache_ops.reshape_and_cache_flash( ops.reshape_and_cache_flash(
key, key,
value, value,
key_cache, key_cache,
......
...@@ -6,6 +6,7 @@ import torch.distributed as dist ...@@ -6,6 +6,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import ( from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check) gpu_p2p_access_check)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -15,7 +16,11 @@ from vllm.logger import init_logger ...@@ -15,7 +16,11 @@ from vllm.logger import init_logger
try: try:
import pynvml import pynvml
from vllm._C import custom_ar # Simulate ImportError if custom_ar ops are not supported.
if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
raise ImportError("custom_ar", __file__)
custom_ar = True
@contextmanager @contextmanager
def _nvml(): def _nvml():
...@@ -27,7 +32,7 @@ try: ...@@ -27,7 +32,7 @@ try:
except ImportError: except ImportError:
# For AMD GPUs # For AMD GPUs
custom_ar = None custom_ar = False
pynvml = None pynvml = None
@contextmanager @contextmanager
...@@ -97,7 +102,7 @@ class CustomAllreduce: ...@@ -97,7 +102,7 @@ class CustomAllreduce:
self._IS_CAPTURING = False self._IS_CAPTURING = False
self.disabled = True self.disabled = True
if custom_ar is None: if not custom_ar:
# disable because of missing custom allreduce library # disable because of missing custom allreduce library
# e.g. in a non-cuda environment # e.g. in a non-cuda environment
return return
...@@ -175,7 +180,7 @@ class CustomAllreduce: ...@@ -175,7 +180,7 @@ class CustomAllreduce:
# meta data composes of two parts: meta data for synchronization # meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate # (256 bytes) and a temporary buffer for storing intermediate
# allreduce results. # allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size, self.meta = torch.zeros(ops.meta_size() + max_size,
dtype=torch.uint8, dtype=torch.uint8,
device=self.device) device=self.device)
# This is a pre-registered IPC buffer. In eager mode, input tensors # This is a pre-registered IPC buffer. In eager mode, input tensors
...@@ -196,9 +201,8 @@ class CustomAllreduce: ...@@ -196,9 +201,8 @@ class CustomAllreduce:
self.world_size = world_size self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta) handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = full_nvlink self.full_nvlink = full_nvlink
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
handles, offsets, rank, offsets, rank, self.full_nvlink)
self.full_nvlink)
self.register_buffer(self.buffer) self.register_buffer(self.buffer)
@contextmanager @contextmanager
...@@ -252,16 +256,16 @@ class CustomAllreduce: ...@@ -252,16 +256,16 @@ class CustomAllreduce:
def register_buffer(self, inp: torch.Tensor): def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp) handles, offsets = self._get_ipc_meta(inp)
custom_ar.register_buffer(self._ptr, inp, handles, offsets) ops.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self): def register_graph_buffers(self):
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset)) logger.info("Registering %d cuda graph addresses", len(offset))
custom_ar.register_graph_buffers(self._ptr, handles, offsets) ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor): def should_custom_ar(self, inp: torch.Tensor):
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, return ops.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink) self.full_nvlink)
# all reduce, assuming inp tensor is IPC registered with register_buffer, # all reduce, assuming inp tensor is IPC registered with register_buffer,
...@@ -269,14 +273,14 @@ class CustomAllreduce: ...@@ -269,14 +273,14 @@ class CustomAllreduce:
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None: if out is None:
out = torch.empty_like(inp) out = torch.empty_like(inp)
custom_ar.all_reduce_reg(self._ptr, inp, out) ops.all_reduce_reg(self._ptr, inp, out)
return out return out
# all reduce, assuming inp tensor is NOT IPC registered # all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None: if out is None:
out = torch.empty_like(inp) out = torch.empty_like(inp)
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
...@@ -304,7 +308,7 @@ class CustomAllreduce: ...@@ -304,7 +308,7 @@ class CustomAllreduce:
def close(self): def close(self):
if not self.disabled and self._ptr: if not self.disabled and self._ptr:
custom_ar.dispose(self._ptr) ops.dispose(self._ptr)
self._ptr = 0 self._ptr = 0
def __del__(self): def __del__(self):
......
...@@ -4,16 +4,21 @@ from typing import Optional ...@@ -4,16 +4,21 @@ from typing import Optional
import torch import torch
from vllm import _custom_ops as ops
def _check_punica_support():
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
return
def _raise_import_error(e):
if torch.cuda.get_device_capability() < (8, 0): if torch.cuda.get_device_capability() < (8, 0):
raise ImportError( raise ImportError(
"punica LoRA kernels require compute capability >= 8.0") from e "punica LoRA kernels require compute capability >= 8.0")
else: else:
raise ImportError( raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM " "punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.") from e "was set.")
def bgmv( def bgmv(
...@@ -41,12 +46,9 @@ def bgmv( ...@@ -41,12 +46,9 @@ def bgmv(
layer_idx: Layer index of the weight matrices. layer_idx: Layer index of the weight matrices.
scale: Scaling factor. scale: Scaling factor.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
...@@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, ...@@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
y_offset: Offset to apply to the starting column of y. y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice. y_slice_size: Size of the y column slice.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e: ops.dispatch_bgmv_low_level(
_raise_import_error(e)
punica_kernels.dispatch_bgmv_low_level(
y, y,
x, x,
w_t_all, w_t_all,
...@@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor, ...@@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor,
scale: Scaling factor. scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer. buffer: Optional. Shape: `[B, R]`. Temporary buffer.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
r = wb_t_all.size(-1) r = wb_t_all.size(-1)
if buffer is None: if buffer is None:
...@@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor, ...@@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor,
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
scale)
def add_lora_slice(y: torch.Tensor, def add_lora_slice(y: torch.Tensor,
...@@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor, ...@@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor,
y_offset: Offset to apply to the starting column of y. y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice. y_slice_size: Size of the y column slice.
""" """
try: _check_punica_support()
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
r = wb_t_all.size(-1) r = wb_t_all.size(-1)
if buffer is None: if buffer is None:
...@@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor, ...@@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor,
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
punica_kernels.dispatch_bgmv_low_level( ops.dispatch_bgmv_low_level(
buffer, buffer,
x, x,
wa_t_all, wa_t_all,
...@@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor, ...@@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor,
buffer.size(1), buffer.size(1),
0, 0,
) )
punica_kernels.dispatch_bgmv_low_level( ops.dispatch_bgmv_low_level(
y, y,
buffer, buffer,
wb_t_all, wb_t_all,
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
import vllm._moe_C as moe_kernels
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -355,7 +354,7 @@ def fused_topk( ...@@ -355,7 +354,7 @@ def fused_topk(
topk, topk,
dtype=torch.int32, dtype=torch.int32,
device=hidden_states.device) device=hidden_states.device)
moe_kernels.topk_softmax( ops.topk_softmax(
topk_weights, topk_weights,
topk_ids, topk_ids,
token_expert_indicies, token_expert_indicies,
......
...@@ -22,6 +22,7 @@ import psutil ...@@ -22,6 +22,7 @@ import psutil
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T") T = TypeVar("T")
...@@ -148,12 +149,8 @@ def is_neuron() -> bool: ...@@ -148,12 +149,8 @@ def is_neuron() -> bool:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
max_shared_mem = ( max_shared_mem = (
cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)) ops.get_max_shared_memory_per_block_device_attribute(gpu))
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail # will fail
assert max_shared_mem > 0, "max_shared_mem can not be zero" assert max_shared_mem > 0, "max_shared_mem can not be zero"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment