Unverified commit 6b45a21d authored by Lianmin Zheng, committed by GitHub

Reorganize c++ source files in sgl-kernel with multiple folders (#4025)

parent a7000a76
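
For orientation, here is a small Python sketch (hypothetical, not part of this commit) that enumerates the reorganized csrc/ layout the new source lists assume. The subfolder names activation/, allreduce/, gemm/, moe/ and speculative/ come from the paths in the diff below; everything else is illustrative, since setup.py in this commit lists the files explicitly rather than globbing them.

    # Hypothetical illustration of the new per-domain layout; not how setup.py builds its list.
    from pathlib import Path

    CSRC = Path("src/sgl-kernel/csrc")
    SUBDIRS = ["activation", "allreduce", "gemm", "moe", "speculative"]

    sources = ["src/sgl-kernel/torch_extension.cc"]
    for sub in SUBDIRS:
        # after this commit, each domain folder holds its own *.cu kernels
        sources += sorted(str(p) for p in (CSRC / sub).glob("*.cu"))
    # a few kernels stay at the top level, e.g. lightning_attention_decode_kernel.cu
    sources += sorted(str(p) for p in CSRC.glob("*.cu"))
    print("\n".join(sources))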
@@ -80,6 +80,12 @@ nvcc_flags = [
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
+    "-DCUTLASS_VERSIONS_GENERATED",
+    "-DCUTE_USE_PACKED_TUPLE=1",
+    "-DCUTLASS_TEST_LEVEL=0",
+    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
+    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+    "--ptxas-options=-v",
     "-Xcompiler=-Wconversion",
     "-Xcompiler=-fno-strict-aliasing",
 ]
@@ -91,18 +97,18 @@ nvcc_flags_fp8 = [
 sources = [
     "src/sgl-kernel/torch_extension.cc",
-    "src/sgl-kernel/csrc/trt_reduce_internal.cu",
-    "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
-    "src/sgl-kernel/csrc/moe_align_kernel.cu",
-    "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
+    "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
+    "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
+    "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
+    "src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
+    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
+    "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
+    "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
-    "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
-    "src/sgl-kernel/csrc/eagle_utils.cu",
-    "src/sgl-kernel/csrc/speculative_sampling.cu",
-    "src/sgl-kernel/csrc/per_token_group_quant_fp8.cu",
-    "src/sgl-kernel/csrc/cublas_grouped_gemm.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
......
@@ -43,8 +43,8 @@ include_dirs = [
 sources = [
     "src/sgl-kernel/torch_extension_rocm.cc",
-    "src/sgl-kernel/csrc/moe_align_kernel.cu",
-    "src/sgl-kernel/csrc/custom_all_reduce.hip",
+    "src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
+    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
 ]
 cxx_flags = ["-O3"]
......
@@ -14,9 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <speculative_sampling.cuh>
+
 #include "pytorch_extension_utils.h"
-#include "speculative_sampling.cuh"
 
 using namespace flashinfer;
......
@@ -35,7 +35,24 @@ limitations under the License.
 }
 
 using fptr_t = int64_t;
 
+/*
+ * From csrc/activation
+ */
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
+                             int64_t cuda_stream);
+void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
+/*
+ * From csrc/allreduce
+ */
 #ifdef USE_ROCM
+// ROCM custom allreduce
 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
@@ -50,7 +67,7 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
 torch::Tensor allocate_meta_buffer(int64_t size);
 torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
 #else
-// trt_reduce
+// TRTLLM custom allreduce
 fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
                       const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
                       const std::vector<fptr_t>& barrier_out);
@@ -61,115 +78,84 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>&
                            const std::vector<std::vector<int64_t>>& offsets);
 #endif
-// moe_align_block_size
-void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
-                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
-                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
-// int8_scaled_mm
+/*
+ * From csrc/gemm
+ */
 torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias);
-// fp8_scaled_mm
+
 torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);
-// fp8_blockwise_scaled_mm
+
 torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
                                       const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                                       const torch::Dtype& out_dtype);
+
+void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
+                                   double eps, double fp8_min, double fp8_max);
+
+void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
+                         const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
+                         int64_t cublas_handle, int64_t cuda_stream);
+
+/*
+ * From csrc/moe
+ */
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
+                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
+                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
+
+/*
+ * From csrc/speculative
+ */
+void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
+                                           at::Tensor accept_token_num,  // mutable
+                                           at::Tensor candidates, at::Tensor retrive_index,
+                                           at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
+                                           at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
+                                           bool deterministic = true, int64_t cuda_stream = 0);
+
+void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
+                                 at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
+                                 at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
+                                 int64_t depth, int64_t draft_token_num);
+
+void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
+                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
+                       int64_t depth, int64_t draft_token_num);
+
+/*
+ * From FlashInfer
+ */
-// lightning_attention_decode
-void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
-                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
-                                torch::Tensor new_kv);
-// rms norm
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-// fused rms norm
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-// gemma rms norm
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-// fused gemma rms norm
-void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
-                             int64_t cuda_stream);
-// silu and mul
-void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-// gelu tanh and mul
-void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-// gelu and mul
-void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
-// bmm fp8
 void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
              at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
-// min p sampling from probs
+
 void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                                std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
                                int64_t cuda_stream);
+
 // top k renorm probs
 // patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
 void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
                         unsigned int top_k_val, int64_t cuda_stream);
+
 // patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
-// wrapper for binding
 inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
                                        std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
                                        int64_t cuda_stream) {
   top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
 }
-// top p renorm probs
+
 void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
                         double top_p_val, int64_t cuda_stream);
-// top k top p sampling from probs
+
 void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                                      at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
                                      std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                                      int64_t cuda_stream);
-// top p sampling from probs
+
 void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
                                std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                                int64_t cuda_stream);
+
 void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
                                       at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
                                       int64_t cuda_stream);
-void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
-                                           at::Tensor accept_token_num,  // mutable
-                                           at::Tensor candidates, at::Tensor retrive_index,
-                                           at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
-                                           at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
-                                           bool deterministic = true, int64_t cuda_stream = 0);
-void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
-                                 at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
-                                 at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
-                                 int64_t depth, int64_t draft_token_num);
-void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
-                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
-                       int64_t depth, int64_t draft_token_num);
-// sgl_per_token_group_quant_fp8
-void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
-                                   double eps, double fp8_min, double fp8_max);
-// cublas grouped gemm
-void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
-                         const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
-                         int64_t cublas_handle, int64_t cuda_stream);
+
+/*
+ * Other
+ */
+void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
+                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
+                                torch::Tensor new_kv);
@@ -19,7 +19,33 @@ limitations under the License.
 #include "sgl_kernels_ops.h"
 
 TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
-  // trt_reduce
+  /*
+   * From csrc/activation
+   */
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
+
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
+
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
+
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
+
+  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+
+  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+
+  /*
+   * From csrc/allreduce
+   */
   m.def(
       "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
      "barrier_in, int[] barrier_out) -> int");
@@ -36,141 +62,112 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
   m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
   m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);
 
-  // moe_align_block_size
-  m.def(
-      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
-  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
-
-  // int8_scaled_mm
+  /*
+   * From csrc/gemm
+   */
   m.def(
       "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
      "bias) -> Tensor");
   m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
 
-  // fp8_scaled_mm
   m.def(
       "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
      "bias) -> Tensor");
   m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
 
-  // fp8_blockwise_scaled_mm
   m.def(
      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
      "Tensor");
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
 
+  m.def(
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+
+  m.def(
+      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
+      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
+  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
+
+  /*
+   * From csrc/moe
+   */
+  m.def(
+      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+
+  m.def(
+      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
+      "new_kv) -> ()");
+
+  /*
+   * From csrc/speculative
+   */
+  m.def(
+      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
+      "bool deterministic, int cuda_stream) -> ()");
+  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
+
+  m.def(
+      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
+      "retrive_next_sibling, "
+      "int topk, int depth, int draft_token_num) -> ()");
+  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
+
+  m.def(
+      "build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
+      "int topk, int depth, int draft_token_num) -> ()");
+  m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
+
+  /*
+   * From FlashInfer
+   */
-  // lightning_attention_decode
-  m.def(
-      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
-      "new_kv) -> ()");
-  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
-
-  // rms norm
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
-
-  // fused rms norm
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
-
-  // gemma rms norm
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
-
-  // fused gemma rms norm
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-
-  // silu and mul
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-
-  // gelu tanh and mul
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-
-  // gelu and mul
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
-
-  // bmm fp8
   m.def(
      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
      "cublas_handle, int cuda_stream) -> ()");
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 
-  // min p sampling from probs
   m.def(
      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
      "min_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
 
-  // top k renorm probs
   m.def(
      "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
      "cuda_stream) -> ()");
   m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
 
-  // top p renorm probs
   m.def(
      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
      "cuda_stream) -> ()");
   m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
 
-  // top k top p sampling from probs
   m.def(
      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
      "cuda_stream) -> ()");
   m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
 
-  // top p sampling from probs
   m.def(
      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
 
-  // apply rope with cos sin cache
   m.def(
      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
 
-  // tree spec decode
-  m.def(
-      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
-      "bool deterministic, int cuda_stream) -> ()");
-  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
-
-  // eagle build tree
-  m.def(
-      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
-      "retrive_next_sibling, "
-      "int topk, int depth, int draft_token_num) -> ()");
-  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
-
-  // eagle build tree
-  m.def(
-      "build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
-      "int topk, int depth, int draft_token_num) -> ()");
-  m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
-
-  // per_token_group_quant_fp8
-  m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  // cublas grouped gemm
-  m.def(
-      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
-      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
-  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
+  /*
+   * Other
+   */
+  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
 }
 
 REGISTER_EXTENSION(_kernels)
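
Since the op schemas above are registered through TORCH_LIBRARY_EXPAND, a quick way to sanity-check one of them from Python is to call it via torch.ops. The snippet below is a hedged sketch and is not part of this commit: it assumes the compiled extension has already been imported so the sgl_kernels library is registered, and that TORCH_LIBRARY_EXPAND resolves to a plain TORCH_LIBRARY under the sgl_kernels namespace; SGLang's own Python wrappers may expose these ops differently.

    # Hedged sketch: call the registered rmsnorm op directly through torch.ops.
    import torch
    import sgl_kernel  # noqa: F401  # hypothetical import that loads the built extension

    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    weight = torch.ones(4096, device="cuda", dtype=torch.float16)
    out = torch.empty_like(x)
    stream = torch.cuda.current_stream().cuda_stream  # int, matching the schema's cuda_stream arg

    # Schema: rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()
    torch.ops.sgl_kernels.rmsnorm(out, x, weight, 1e-6, stream)
    print(out.shape)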
@@ -19,7 +19,9 @@ limitations under the License.
 #include "sgl_kernels_ops.h"
 
 TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
-  // Custom all-reduce kernels
+  /*
+   * From csrc/allreduce
+   */
   m.def(
      "init_custom_ar(Tensor meta, Tensor rank_data, "
      "str[] handles, int[] offsets, int rank, "
@@ -45,12 +47,16 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
   m.def("register_graph_buffers", &register_graph_buffers);
 
   m.def("allocate_meta_buffer", &allocate_meta_buffer);
   m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
 
   m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
   m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
 
-  // moe_align_block_size
+  /*
+   * From csrc/moe
+   */
   m.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
......
 import ctypes
 import logging
-import os
 import random
 import socket
 import time
 import unittest
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
 
 import ray
 import torch
@@ -115,7 +114,7 @@ class TestCustomAllReduce(unittest.TestCase):
         self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
+            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
         )
 
         self.custom_ptr = custom_ops.init_custom_reduce(
@@ -148,7 +147,7 @@ class TestCustomAllReduce(unittest.TestCase):
             self.vllm_max_size, group=group
         )
         self.vllm_rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
+            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
         )
         self.vllm_ptr = vllm_ops.init_custom_ar(
             self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
@@ -171,8 +170,7 @@ class TestCustomAllReduce(unittest.TestCase):
     @staticmethod
     def init_distributed_env(world_size, rank, distributed_init_port):
-        del os.environ["CUDA_VISIBLE_DEVICES"]
-        device = torch.device(f"cuda:{rank}")
+        device = torch.device("cuda:0")
         torch.cuda.set_device(device)
         ranks = [i for i in range(world_size)]
         distributed_init_method = f"tcp://localhost:{distributed_init_port}"
@@ -234,8 +232,8 @@ class TestCustomAllReduce(unittest.TestCase):
         if rank == 0:
             logger.warning(
                 f"test_size = {sz}, world_size = {world_size}, "
-                f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms,"
-                f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms"
+                f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
+                f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
             )
 
         self.free_custom_allreduce(group)
......
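
The test changes above drop the per-rank device index in favor of "cuda:0" and stop deleting CUDA_VISIBLE_DEVICES. A hedged illustration of why that works, assuming the launcher (e.g. Ray) pins one GPU per worker via CUDA_VISIBLE_DEVICES, which this diff does not show directly:

    # Hedged sketch: with CUDA_VISIBLE_DEVICES restricted to a single physical GPU,
    # that GPU is always visible as local index 0, so "cuda:0" is the right handle
    # in every worker regardless of its rank.
    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # hypothetical: this worker owns physical GPU 3

    import torch  # imported after the env var so CUDA sees only one device

    if torch.cuda.is_available():
        device = torch.device("cuda:0")  # local index 0 maps to physical GPU 3 here
        torch.cuda.set_device(device)
        print(torch.cuda.device_count())  # prints 1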