Unverified Commit 5467ac31 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

parent 5d7e3d01
...@@ -66,19 +66,6 @@ endif() ...@@ -66,19 +66,6 @@ endif()
# #
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
# reason. So, add it manually with `find_library` using torch's
# installed library path.
#
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")
# #
# Forward the non-CUDA device extensions to external CMake scripts. # Forward the non-CUDA device extensions to external CMake scripts.
# #
...@@ -171,7 +158,7 @@ set(VLLM_EXT_SRC ...@@ -171,7 +158,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu" "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu" "csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp") "csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent) include(FetchContent)
...@@ -218,6 +205,7 @@ define_gpu_extension_target( ...@@ -218,6 +205,7 @@ define_gpu_extension_target(
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI) WITH_SOABI)
# #
...@@ -225,7 +213,7 @@ define_gpu_extension_target( ...@@ -225,7 +213,7 @@ define_gpu_extension_target(
# #
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu")
define_gpu_extension_target( define_gpu_extension_target(
...@@ -235,6 +223,7 @@ define_gpu_extension_target( ...@@ -235,6 +223,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC} SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
USE_SABI 3
WITH_SOABI) WITH_SOABI)
# #
...@@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC ...@@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu" "csrc/punica/punica_ops.cu"
"csrc/punica/punica_pybind.cpp") "csrc/punica/torch_bindings.cpp")
# #
# Copy GPU compilation flags+update for punica # Copy GPU compilation flags+update for punica
...@@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES) ...@@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
SOURCES ${VLLM_PUNICA_EXT_SRC} SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI) WITH_SOABI)
else() else()
message(WARNING "Unable to create _punica_C target because none of the " message(WARNING "Unable to create _punica_C target because none of the "
......
...@@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ...@@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \ pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \ && python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
&& cd .. && cd ..
......
...@@ -73,7 +73,7 @@ set(VLLM_EXT_SRC ...@@ -73,7 +73,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/cache.cpp" "csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp") "csrc/cpu/torch_bindings.cpp")
define_gpu_extension_target( define_gpu_extension_target(
_C _C
...@@ -81,10 +81,10 @@ define_gpu_extension_target( ...@@ -81,10 +81,10 @@ define_gpu_extension_target(
LANGUAGE CXX LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI WITH_SOABI
) )
add_custom_target(default) add_custom_target(default)
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
add_dependencies(default _C) add_dependencies(default _C)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE) file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE}) set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module) find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND) if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif() endif()
...@@ -294,6 +294,7 @@ endmacro() ...@@ -294,6 +294,7 @@ endmacro()
# INCLUDE_DIRECTORIES <dirs> - Extra include directories. # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries. # LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name. # WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Build against the Python stable ABI (limited API) <version>.
# #
# Note: optimization level/debug info is set via cmake build type. # Note: optimization level/debug info is set via cmake build type.
# #
...@@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) ...@@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1 cmake_parse_arguments(PARSE_ARGV 1
GPU GPU
"WITH_SOABI" "WITH_SOABI"
"DESTINATION;LANGUAGE" "DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm. # Add hipify preprocessing step when building with HIP/ROCm.
...@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) ...@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
set(GPU_WITH_SOABI) set(GPU_WITH_SOABI)
endif() endif()
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()
if (GPU_LANGUAGE STREQUAL "HIP") if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step. # Make this target dependent on the hipify preprocessor step.
......
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cmath> #include <cmath>
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
...@@ -809,15 +809,16 @@ void paged_attention_v1( ...@@ -809,15 +809,16 @@ void paged_attention_v1(
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size] value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads] int64_t num_kv_heads, // [num_heads]
float scale, double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int64_t blocksparse_local_blocks,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
...@@ -973,15 +974,16 @@ void paged_attention_v2( ...@@ -973,15 +974,16 @@ void paged_attention_v2(
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size] value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads] int64_t num_kv_heads, // [num_heads]
float scale, double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int64_t blocksparse_local_blocks,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
......
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -8,14 +8,18 @@ ...@@ -8,14 +8,18 @@
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping); const torch::Tensor& block_mapping);
void copy_blocks(std::vector<torch::Tensor>& key_caches, // Note: the key_caches and value_caches vectors are constant but
std::vector<torch::Tensor>& value_caches, // not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping); const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const float kv_scale); const std::string& kv_cache_dtype,
const double kv_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
...@@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, ...@@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
// Just for unittest // Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype); const double scale, const std::string& kv_cache_dtype);
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
...@@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, ...@@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
} // namespace vllm } // namespace vllm
void copy_blocks(std::vector<torch::Tensor>& key_caches, // Note: the key_caches and value_caches vectors are constant but
std::vector<torch::Tensor>& value_caches, // not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
...@@ -255,7 +258,7 @@ void reshape_and_cache( ...@@ -255,7 +258,7 @@ void reshape_and_cache(
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const float kv_scale) { const std::string& kv_cache_dtype, const double kv_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
...@@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, ...@@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
// Only for testing. // Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float kv_scale, const std::string& kv_cache_dtype) { const double kv_scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
......
...@@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher( ...@@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher(
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int64_t blocksparse_local_blocks,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
...@@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher( ...@@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher(
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int64_t blocksparse_local_blocks,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches, void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor>& value_caches, std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs, const torch::Tensor& mapping_pairs,
const int element_num_per_block, const int element_num_per_block,
const int layer_num) { const int layer_num) {
...@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl( ...@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
} }
}; // namespace }; // namespace
void copy_blocks(std::vector<torch::Tensor>& key_caches, // Note: the key_caches and value_caches vectors are constant but
std::vector<torch::Tensor>& value_caches, // not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size(); unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
...@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches, ...@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, double kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#define CPU_TYPES_HPP #define CPU_TYPES_HPP
#include <immintrin.h> #include <immintrin.h>
#include <torch/extension.h> #include <torch/all.h>
namespace vec_op { namespace vec_op {
......
...@@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input, ...@@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
} // namespace } // namespace
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ...@@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
} }
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon) { torch::Tensor& weight, double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
...@@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl( ...@@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl(
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
......
#include "cache.h"
#include "ops.h"
#include <torch/extension.h>
// Python module definition for the CPU build. Two submodules are created on
// the extension module: `ops` (attention, activation, layernorm and rotary
// embedding operators) and `cache_ops` (KV-cache block manipulation). Every
// entry binds a Python-visible name to the C++ function of the same name,
// together with its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, root) {
  // ---- `ops` submodule: vLLM custom ops ----
  auto op_mod = root.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  op_mod.def("paged_attention_v1", &paged_attention_v1,
             "Compute the attention between an input query and the "
             "cached keys/values using PagedAttention.");
  op_mod.def("paged_attention_v2", &paged_attention_v2,
             "PagedAttention V2.");

  // Activation ops
  op_mod.def("silu_and_mul", &silu_and_mul,
             "Activation function used in SwiGLU.");
  op_mod.def("gelu_and_mul", &gelu_and_mul,
             "Activation function used in GeGLU "
             "with `none` approximation.");
  op_mod.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
             "Activation function used in GeGLU "
             "with `tanh` approximation.");
  op_mod.def("gelu_new", &gelu_new,
             "GELU implementation used in GPT-2.");
  op_mod.def("gelu_fast", &gelu_fast,
             "Approximate GELU implementation.");

  // Layernorm
  op_mod.def("rms_norm", &rms_norm,
             "Apply Root Mean Square (RMS) Normalization "
             "to the input tensor.");
  op_mod.def("fused_add_rms_norm", &fused_add_rms_norm,
             "In-place fused Add and RMS Normalization");

  // Rotary embedding
  op_mod.def("rotary_embedding", &rotary_embedding,
             "Apply GPT-NeoX or GPT-J style rotary embedding "
             "to query and key");

  // ---- `cache_ops` submodule: cache ops ----
  auto cache_mod = root.def_submodule("cache_ops", "vLLM cache ops");
  cache_mod.def("swap_blocks", &swap_blocks,
                "Swap in (out) the cache blocks from src to dst");
  cache_mod.def("copy_blocks", &copy_blocks,
                "Copy the cache blocks from src to dst");
  cache_mod.def("reshape_and_cache", &reshape_and_cache,
                "Reshape the key and value tensors and cache them");
}
#include "cache.h"
#include "ops.h"
#include "registration.h"
#include <torch/library.h>
// Registers the CPU custom operators in the torch library named by
// TORCH_EXTENSION_NAME.
//
// Each operator is first declared with `ops.def(<schema>)` and then bound to
// its CPU kernel with `ops.impl(<name>, torch::kCPU, &<function>)`. In the
// schemas below `Tensor!` annotates a tensor argument that the operator
// writes to, `Tensor?` an optional tensor, and `-> ()` an operator that
// returns nothing. The schema types `int` and `float` correspond to the
// `int64_t` and `double` parameters of the C++ implementations.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  // Attention ops
  // Compute the attention between an input query and the cached keys/values
  // using PagedAttention.
  ops.def(
      "paged_attention_v1("
      " Tensor! out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, float kv_scale, int tp_rank,"
      " int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);

  // PagedAttention V2.
  // `exp_sums`, `max_logits` and `tmp_out` are caller-provided buffers that
  // the V2 implementation is expected to fill with intermediate results, so
  // they are annotated as mutable alongside `out`. Declaring them read-only
  // would misinform PyTorch about which arguments the operator mutates.
  // NOTE(review): confirm against paged_attention_v2_impl_launcher.
  ops.def(
      "paged_attention_v2("
      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, float kv_scale, int tp_rank,"
      " int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);

  // Activation ops
  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      " Tensor! key, int head_size,"
      " Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
}
// Registers the CPU KV-cache operators in a second torch library whose name
// is TORCH_EXTENSION_NAME with the suffix `_cache_ops` appended. Each
// operator is declared by schema and then bound to its CPU implementation.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops

  // swap_blocks: swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks("
      "Tensor src, "
      "Tensor! dst, "
      "Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);

  // copy_blocks: copy the cache blocks from src to dst.
  cache_ops.def(
      "copy_blocks("
      "Tensor[]! key_caches, "
      "Tensor[]! value_caches, "
      "Tensor block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

  // reshape_and_cache: reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache("
      "Tensor key, Tensor value, "
      "Tensor! key_cache, Tensor! value_cache, "
      "Tensor slot_mapping, "
      "str kv_cache_dtype, "
      "float kv_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}
// NOTE(review): REGISTER_EXTENSION is not defined in this file; presumably it
// is provided by "registration.h" and makes the shared object importable as a
// Python module named TORCH_EXTENSION_NAME — confirm what it expands to.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
#pragma once #pragma once
#include <torch/extension.h> int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int get_device_attribute(int attribute, int device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
int get_max_shared_memory_per_block_device_attribute(int device_id);
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute(int attribute, int device_id) { int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
int device, value; int device, value;
if (device_id < 0) { if (device_id < 0) {
cudaGetDevice(&device); cudaGetDevice(&device);
...@@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) { ...@@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) {
return value; return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) { int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
int attribute; int64_t attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
......
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <torch/extension.h> #include <torch/all.h>
#include "custom_all_reduce.cuh" #include "custom_all_reduce.cuh"
// fake pointer type // fake pointer type, must match fptr_t type in ops.h
using fptr_t = uint64_t; using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int rank, const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
...@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) { ...@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
...@@ -125,7 +125,7 @@ void dispose(fptr_t _fa) { ...@@ -125,7 +125,7 @@ void dispose(fptr_t _fa) {
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
...@@ -134,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, ...@@ -134,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor& t,
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
*/ */
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
......
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
...@@ -291,7 +291,7 @@ fused_add_rms_norm_kernel( ...@@ -291,7 +291,7 @@ fused_add_rms_norm_kernel(
void rms_norm(torch::Tensor& out, // [..., hidden_size] void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] ...@@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment