Unverified Commit 5467ac31 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

parent 5d7e3d01
#include "moe_ops.h"
#include <torch/extension.h>
// Python binding for the MoE extension module: exposes the C++ function
// topk_softmax (made visible through moe_ops.h) as a module-level Python
// function of the same name, using the string below as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
}
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h" #include "../cuda_compat.h"
......
#include "registration.h"
#include "moe_ops.h"
// Registers the MoE custom op through the torch library API instead of a
// pybind11 module.
// NOTE(review): TORCH_LIBRARY_EXPAND and REGISTER_EXTENSION come from
// registration.h, which is not shown here — confirm their exact expansion.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
// Schema: the three `Tensor!` arguments are annotated as mutable, only
// gating_output is a plain input, and the op returns nothing (`-> ()`).
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
// Bind the C++ function topk_softmax as the implementation for the CUDA
// dispatch key.
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
}
// NOTE(review): presumably provides the Python module entry point for this
// shared library — verify against registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
...@@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ...@@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
} }
} // namespace vllm } // namespace vllm
void moe_align_block_size(torch::Tensor topk_ids, int num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
#pragma once #pragma once
#include <torch/extension.h> #include <torch/library.h>
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int64_t blocksparse_local_blocks,
const int blocksparse_block_size, const int blocksparse_head_sliding_step); const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int64_t blocksparse_local_blocks,
const int blocksparse_block_size, const int blocksparse_head_sliding_step); const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon); double epsilon);
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon); torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox, torch::Tensor& cos_sin_cache, bool is_neox,
int rot_dim, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
...@@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, ...@@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, torch::Tensor _scaling_factors, torch::Tensor _zeros,
int split_k_iters); int64_t split_k_iters);
torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _scaling_factors,
torch::Tensor _zeros, int split_k_iters, int thx, torch::Tensor _zeros, int64_t split_k_iters,
int thy); int64_t thx, int64_t thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace, torch::Tensor& b_scales, torch::Tensor& workspace,
...@@ -88,7 +90,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -88,7 +90,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits); int64_t num_bits);
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales); torch::Tensor const& b_scales);
...@@ -106,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, ...@@ -106,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int bit); bool use_exllama, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
...@@ -116,28 +118,28 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, ...@@ -116,28 +118,28 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
void moe_align_block_size(torch::Tensor topk_ids, int num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = uint64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int rank, const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink); bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink); bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out); torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int meta_size(); int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets); const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa); fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets); const std::vector<std::vector<int64_t>>& offsets);
......
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
...@@ -127,7 +127,7 @@ void rotary_embedding( ...@@ -127,7 +127,7 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] // [num_tokens, num_kv_heads * head_size]
int head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); int64_t num_tokens = query.numel() / query.size(-1);
...@@ -138,7 +138,7 @@ void rotary_embedding( ...@@ -138,7 +138,7 @@ void rotary_embedding(
int64_t key_stride = key.stride(-2); int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
...@@ -168,9 +168,9 @@ void batched_rotary_embedding( ...@@ -168,9 +168,9 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] // [num_tokens, num_kv_heads * head_size]
int head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int rot_dim, bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens] torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) { ) {
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
...@@ -180,7 +180,7 @@ void batched_rotary_embedding( ...@@ -180,7 +180,7 @@ void batched_rotary_embedding(
int64_t key_stride = key.stride(-2); int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
......
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cstdint> #include <cstdint>
...@@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, ...@@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
} }
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) { torch::Tensor indicies, int64_t layer_idx, double scale) {
CHECK_INPUT(y); CHECK_INPUT(y);
CHECK_INPUT(x); CHECK_INPUT(x);
CHECK_INPUT(w); CHECK_INPUT(w);
...@@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, ...@@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out, double scale, int64_t h_in, int64_t h_out,
int64_t y_offset) { int64_t y_offset) {
CHECK_INPUT(y); CHECK_INPUT(y);
CHECK_INPUT(x); CHECK_INPUT(x);
......
#pragma once #pragma once
#include <torch/extension.h> #include <torch/all.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale); torch::Tensor indicies, int64_t layer_idx, double scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out, double scale, int64_t h_in, int64_t h_out,
int64_t y_offset); int64_t y_offset);
#include <torch/extension.h>
#include "punica_ops.h"
//====== pybind ======
// Convenience macro that binds a function under its own name and reuses that
// name as the docstring. It is not used by the visible bindings below, which
// call m.def directly.
#define DEFINE_pybind(name) m.def(#name, &name, #name);
// Python bindings for the two functions made visible through punica_ops.h;
// each is exposed under its C++ name with the name doubling as the docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}
#include "registration.h"
#include "punica_ops.h"
// Registers the punica ops through the torch library API instead of a
// pybind11 module.
// NOTE(review): TORCH_LIBRARY_EXPAND and REGISTER_EXTENSION come from
// registration.h, which is not shown here — confirm their exact expansion.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Schema for dispatch_bgmv: only `y` is annotated as mutable (`Tensor!`) and
// the op returns nothing. The schema types `int` / `float` line up with the
// int64_t / double parameters of the C++ function.
m.def(
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()");
// Bind the C++ function as the implementation for the CUDA dispatch key.
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
// Schema for dispatch_bgmv_low_level: same leading arguments plus h_in,
// h_out and y_offset. The adjacent string literals are concatenated by the
// compiler into a single schema string.
m.def(
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
"Tensor indicies, int layer_idx,"
"float scale, int h_in, int h_out,"
"int y_offset) -> ()");
// Bind the C++ function as the implementation for the CUDA dispatch key.
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}
// NOTE(review): presumably provides the Python module entry point for this
// shared library — verify against registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
// Top-level pybind11 module for the vLLM extension. Bindings are grouped into
// submodules: `ops` (attention, activation, layernorm, rotary embedding and
// quantization functions), `cache_ops`, `cuda_utils`, and — only when USE_ROCM
// is not defined — `custom_ar`. Each entry maps a Python name to the C++
// function of the same name together with a docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def("paged_attention_v1", &paged_attention_v1,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
// Activation ops
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
ops.def("gelu_and_mul", &gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
// Layernorm
ops.def("rms_norm", &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def("rotary_embedding", &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def("batched_rotary_embedding", &batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
"(supports multiple loras)");
// Quantization ops
// The bindings inside this guard are compiled out when USE_ROCM is defined.
#ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
"gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization.");
#endif
// The remaining quantization bindings are registered unconditionally.
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
"Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
"Compute FP8 quantized tensor and scaling factor");
ops.def("moe_align_block_size", &moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
"Compute int8 quantized tensor and scaling factor");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def("convert_fp8", &convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils
pybind11::module cuda_utils =
m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def("get_device_attribute", &get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
// The custom_ar submodule exists only when USE_ROCM is not defined.
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
custom_ar.def("dispose", &dispose, "dispose");
custom_ar.def("meta_size", &meta_size, "meta_size");
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
}
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
......
...@@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} ...@@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
} }
*/ */
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh" #include "dequantize.cuh"
...@@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64) ...@@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64)
torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _scaling_factors,
torch::Tensor _zeros, int split_k_iters, int thx, torch::Tensor _zeros, int64_t split_k_iters,
int thy) { int64_t thx, int64_t thy) {
int in_c = _kernel.size(0); int in_c = _kernel.size(0);
int qout_c = _kernel.size(1); int qout_c = _kernel.size(1);
int out_c = qout_c * 8; int out_c = qout_c * 8;
...@@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, ...@@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, torch::Tensor _scaling_factors, torch::Tensor _zeros,
int split_k_iters) { int64_t split_k_iters) {
int num_in_feats = _in_feats.size(0); int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1); int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
......
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/all.h>
#include <cmath> #include <cmath>
#include "../../dispatch_utils.h" #include "../../dispatch_utils.h"
......
#include <stddef.h> #include <stddef.h>
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/extension.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
......
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h> #include <torch/all.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
......
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cmath> #include <cmath>
......
...@@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa ...@@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include <torch/extension.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, ...@@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int bit) { bool use_exllama, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
...@@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c; return c;
} }
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight( vllm::gptq::shuffle_exllama_weight(
(uint32_t*)q_weight.data_ptr(), (uint32_t*)q_weight.data_ptr(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment