Unverified Commit cfe48c59 authored by blzheng's avatar blzheng Committed by GitHub
Browse files

[CPU] Fix build issue (#6419)

parent d4c038da
...@@ -5,9 +5,7 @@ set(CMAKE_CXX_STANDARD 17) ...@@ -5,9 +5,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
# Torch find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
find_package(Torch REQUIRED)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
execute_process( execute_process(
COMMAND ${Python_EXECUTABLE} COMMAND ${Python_EXECUTABLE}
...@@ -23,8 +21,9 @@ find_package(Torch REQUIRED) ...@@ -23,8 +21,9 @@ find_package(Torch REQUIRED)
include_directories( include_directories(
${TORCH_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}
${TORCH_INSTALL_PREFIX}/include ${TORCH_INSTALL_PREFIX}/include
${Python3_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}
${CMAKE_SOURCE_DIR}/csrc ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
${CMAKE_CURRENT_SOURCE_DIR}/../../include
) )
# Platform-specific library directory # Platform-specific library directory
...@@ -39,23 +38,7 @@ else() ...@@ -39,23 +38,7 @@ else()
endif() endif()
link_directories(${PLAT_LIB_DIR}) link_directories(${PLAT_LIB_DIR})
set(SOURCES file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
${CMAKE_CURRENT_SOURCE_DIR}/activation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/decode.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extend.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm_int8.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moe.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moe_int8.cpp
${CMAKE_CURRENT_SOURCE_DIR}/norm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qkv_proj.cpp
${CMAKE_CURRENT_SOURCE_DIR}/topk.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/interface.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/torch_extension_cpu.cpp
)
add_compile_options( add_compile_options(
-O3 -O3
...@@ -64,24 +47,10 @@ add_compile_options( ...@@ -64,24 +47,10 @@ add_compile_options(
-fopenmp -fopenmp
) )
add_library(sgl_kernel_common_ops SHARED ${SOURCES}) Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(sgl_kernel_common_ops target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
PRIVATE
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
c10
)
set_target_properties(sgl_kernel_common_ops PROPERTIES
INSTALL_RPATH "$ORIGIN/../../torch/lib"
PREFIX ""
OUTPUT_NAME "sgl_kernel.common_ops"
)
target_compile_definitions(sgl_kernel_common_ops PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
# Install install(TARGETS common_ops
install(TARGETS sgl_kernel_common_ops LIBRARY DESTINATION sgl_kernel
LIBRARY DESTINATION ${Python3_SITEARCH}
) )
...@@ -74,7 +74,8 @@ void bmm_kernel_impl( ...@@ -74,7 +74,8 @@ void bmm_kernel_impl(
// out : [B, M, N] // out : [B, M, N]
// scale: [] 0-dim tensor for per tensor quant // scale: [] 0-dim tensor for per tensor quant
// //
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) { void bmm_cpu(
at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2})); RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
......
...@@ -463,7 +463,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { ...@@ -463,7 +463,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
// bias : [N] // bias : [N]
// out : [M, N] // out : [M, N]
// //
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) { at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias})); RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
......
...@@ -482,7 +482,7 @@ at::Tensor fp8_scaled_mm_cpu( ...@@ -482,7 +482,7 @@ at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales2, at::Tensor& scales2,
std::vector<int64_t> block_size, std::vector<int64_t> block_size,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni) { bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias})); RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
......
...@@ -366,7 +366,7 @@ at::Tensor int8_scaled_mm_cpu( ...@@ -366,7 +366,7 @@ at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales1, at::Tensor& scales1,
at::Tensor& scales2, at::Tensor& scales2,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni) { bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias})); RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
...@@ -424,7 +424,7 @@ at::Tensor int8_scaled_mm_with_quant( ...@@ -424,7 +424,7 @@ at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1, at::Tensor& mat1,
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales2, at::Tensor& scales2,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni) { bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias})); RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
......
...@@ -11,7 +11,7 @@ static bool is_initialized = false; ...@@ -11,7 +11,7 @@ static bool is_initialized = false;
static bool all_ranks_local_p = false; static bool all_ranks_local_p = false;
void initialize(int size, int rank) { void initialize(int64_t size, int64_t rank) {
if (is_initialized) { if (is_initialized) {
return; return;
} }
...@@ -47,12 +47,11 @@ void initialize(int size, int rank) { ...@@ -47,12 +47,11 @@ void initialize(int size, int rank) {
} }
} }
void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) { void shm_allreduce(
torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data})); RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp"); TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
auto numel = data.numel(); auto numel = data.numel();
...@@ -81,7 +80,7 @@ void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> p ...@@ -81,7 +80,7 @@ void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> p
return; return;
} }
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) { torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data})); RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
auto numel = data.numel(); auto numel = data.numel();
......
...@@ -946,10 +946,10 @@ at::Tensor fused_experts_cpu( ...@@ -946,10 +946,10 @@ at::Tensor fused_experts_cpu(
at::Tensor& topk_ids, at::Tensor& topk_ids,
bool inplace, bool inplace,
bool use_int8_w8a8, bool use_int8_w8a8,
std::optional<at::Tensor>& w1_scale, const std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale, const std::optional<at::Tensor>& w2_scale,
std::optional<at::Tensor>& a1_scale, const std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale, const std::optional<at::Tensor>& a2_scale,
bool is_vnni) { bool is_vnni) {
RECORD_FUNCTION( RECORD_FUNCTION(
"sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids})); "sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
...@@ -1138,11 +1138,11 @@ at::Tensor shared_expert_cpu( ...@@ -1138,11 +1138,11 @@ at::Tensor shared_expert_cpu(
bool inplace, bool inplace,
bool use_int8_w8a8, bool use_int8_w8a8,
bool use_fp8_w8a16, bool use_fp8_w8a16,
std::optional<at::Tensor>& w1_scale, const std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale, const std::optional<at::Tensor>& w2_scale,
std::optional<std::vector<int64_t>> block_size, const std::optional<std::vector<int64_t>> block_size,
std::optional<at::Tensor>& a1_scale, const std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale, const std::optional<at::Tensor>& a2_scale,
bool is_vnni) { bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2})); RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
......
...@@ -308,18 +308,18 @@ void rotary_emb_kernel_impl( ...@@ -308,18 +308,18 @@ void rotary_emb_kernel_impl(
} // anonymous namespace } // anonymous namespace
extern at::Tensor extern at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni); weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
extern at::Tensor int8_scaled_mm_with_quant( extern at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1, at::Tensor& mat1,
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales2, at::Tensor& scales2,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni); bool is_vnni);
extern void extern void
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale); bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
// NB: shapes in DeepDeek R1 // NB: shapes in DeepDeek R1
// //
...@@ -343,9 +343,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( ...@@ -343,9 +343,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& cos_sin_cache, at::Tensor& cos_sin_cache,
double eps, double eps,
bool use_int8_w8a8, bool use_int8_w8a8,
std::optional<at::Tensor>& q_a_proj_scale, std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor>& q_b_proj_scale, std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor>& kv_a_proj_scale, std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni) { bool is_vnni) {
RECORD_FUNCTION( RECORD_FUNCTION(
"sgl-kernel::qkv_proj_with_rope", "sgl-kernel::qkv_proj_with_rope",
......
#include <torch/torch.h> #include <torch/all.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp> #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
......
...@@ -17,6 +17,7 @@ limitations under the License. ...@@ -17,6 +17,7 @@ limitations under the License.
#include <torch/all.h> #include <torch/all.h>
#include <torch/library.h> #include <torch/library.h>
#include "sgl_kernel_ops.h"
#include "shm.h" #include "shm.h"
// silu_and_mul // silu_and_mul
...@@ -85,7 +86,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight); ...@@ -85,7 +86,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight);
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A); std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
// gemm // gemm
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni); at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
// igemm // igemm
at::Tensor int8_scaled_mm_cpu( at::Tensor int8_scaled_mm_cpu(
...@@ -93,7 +95,7 @@ at::Tensor int8_scaled_mm_cpu( ...@@ -93,7 +95,7 @@ at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales1, at::Tensor& scales1,
at::Tensor& scales2, at::Tensor& scales2,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni); bool is_vnni);
...@@ -103,7 +105,7 @@ at::Tensor fp8_scaled_mm_cpu( ...@@ -103,7 +105,7 @@ at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales2, at::Tensor& scales2,
std::vector<int64_t> block_size, std::vector<int64_t> block_size,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni); bool is_vnni);
...@@ -112,12 +114,12 @@ at::Tensor int8_scaled_mm_with_quant( ...@@ -112,12 +114,12 @@ at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1, at::Tensor& mat1,
at::Tensor& mat2, at::Tensor& mat2,
at::Tensor& scales2, at::Tensor& scales2,
std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, at::ScalarType out_dtype,
bool is_vnni); bool is_vnni);
// bmm // bmm
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale); void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
// fused moe // fused moe
at::Tensor fused_experts_cpu( at::Tensor fused_experts_cpu(
...@@ -128,10 +130,10 @@ at::Tensor fused_experts_cpu( ...@@ -128,10 +130,10 @@ at::Tensor fused_experts_cpu(
at::Tensor& topk_ids, at::Tensor& topk_ids,
bool inplace, bool inplace,
bool use_int8_w8a8, bool use_int8_w8a8,
std::optional<at::Tensor>& w1_scale, const std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale, const std::optional<at::Tensor>& w2_scale,
std::optional<at::Tensor>& a1_scale, const std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale, const std::optional<at::Tensor>& a2_scale,
bool is_vnni); bool is_vnni);
at::Tensor shared_expert_cpu( at::Tensor shared_expert_cpu(
...@@ -143,11 +145,11 @@ at::Tensor shared_expert_cpu( ...@@ -143,11 +145,11 @@ at::Tensor shared_expert_cpu(
bool inplace, bool inplace,
bool use_int8_w8a8, bool use_int8_w8a8,
bool use_fp8_w8a16, bool use_fp8_w8a16,
std::optional<at::Tensor>& w1_scale, const std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale, const std::optional<at::Tensor>& w2_scale,
std::optional<std::vector<int64_t>> block_size, const std::optional<std::vector<int64_t>> block_size,
std::optional<at::Tensor>& a1_scale, const std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale, const std::optional<at::Tensor>& a2_scale,
bool is_vnni); bool is_vnni);
// weight absorption // weight absorption
...@@ -163,80 +165,130 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( ...@@ -163,80 +165,130 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& cos_sin_cache, at::Tensor& cos_sin_cache,
double eps, double eps,
bool use_int8_w8a8, bool use_int8_w8a8,
std::optional<at::Tensor>& q_a_proj_scale, std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor>& q_b_proj_scale, std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor>& kv_a_proj_scale, std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni); bool is_vnni);
// shared memory init // shared memory init
void initialize(int size, int rank); void initialize(int64_t size, int64_t rank);
// shared mmeory all_reduce // shared mmeory all_reduce
void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op); void shm_allreduce(
at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
// shared memory all_gather // shared memory all_gather
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim); at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
// rope // rope
std::tuple<at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos); rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation // activation
m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU"); m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
// norm // norm
m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU"); m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU"); m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
// topk // topk
m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU"); m.def(
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
"int topk_group) -> (Tensor, Tensor)");
m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
// biased group topk // biased group topk
m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU"); m.def(
"biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
"renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)");
m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
// decode // decode
m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU"); m.def(
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
// extend // extend
m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU"); m.def(
"extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
"Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
"extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
// weight prepack // weight prepack
m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX"); m.def("convert_weight_packed(Tensor weight) -> Tensor");
m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
// quant // quant
m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU"); m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);
// gemm // gemm
m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX"); m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
// igemm // igemm
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX"); m.def(
"int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
"out_dtype, bool is_vnni) -> Tensor");
m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);
// fp8 gemm // fp8 gemm
m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX"); m.def(
"fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
"out_dtype, bool is_vnni) -> Tensor");
m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
// quant + igemm // quant + igemm
m.def( m.def(
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX"); "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
"is_vnni) -> Tensor");
m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
// bmm // bmm
m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX"); m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
// moe // moe
m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU"); m.def(
"fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
"inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
"is_vnni) -> Tensor");
m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
// weight absorption // weight absorption
m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX"); m.def(
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
// shared expert // shared expert
m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU"); m.def(
"shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
"routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
"w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);
// all reduce // all reduce
m.def("initialize", &initialize, "shared memory initialization for CPU"); m.def("initialize(int size, int rank) -> ()");
m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU"); m.impl("initialize", torch::kCPU, &initialize);
m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU"); m.def(
"shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
"__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
// rope // rope
m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU"); m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
} }
REGISTER_EXTENSION(common_ops)
...@@ -34,7 +34,3 @@ exclude = [ ...@@ -34,7 +34,3 @@ exclude = [
cmake.source-dir = "csrc/cpu" cmake.source-dir = "csrc/cpu"
cmake.build-type = "Release" cmake.build-type = "Release"
minimum-version = "build-system.requires" minimum-version = "build-system.requires"
wheel.py-api = "cp39"
wheel.license-files = []
wheel.packages = ["python/sgl_kernel"]
...@@ -50,7 +50,9 @@ def _get_version(): ...@@ -50,7 +50,9 @@ def _get_version():
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1" cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
operator_namespace = "sgl_kernel" operator_namespace = "sgl_kernel"
include_dirs = [] include_dirs = [
"../../include",
]
sources = [ sources = [
"csrc/cpu/activation.cpp", "csrc/cpu/activation.cpp",
...@@ -99,7 +101,7 @@ ext_modules = [ ...@@ -99,7 +101,7 @@ ext_modules = [
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
libraries=libraries, libraries=libraries,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
py_limited_api=True, py_limited_api=False,
), ),
] ]
......
import itertools import itertools
import unittest import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch import torch
import torch.nn as nn import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import (
convert_weight_packed,
fp8_scaled_mm_cpu,
int8_scaled_mm_cpu,
int8_scaled_mm_with_quant,
per_token_quant_int8_cpu,
weight_packed_linear,
)
from utils import ( from utils import (
convert_weight, convert_weight,
native_w8a8_per_token_matmul, native_w8a8_per_token_matmul,
...@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase): ...@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase):
ref = ref.bfloat16() ref = ref.bfloat16()
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False) out = torch.ops.sgl_kernel.weight_packed_linear(
mat1, mat2, bias if has_bias else None, False
)
packed_mat2 = convert_weight_packed(mat2) packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True) out2 = torch.ops.sgl_kernel.weight_packed_linear(
mat1, packed_mat2, bias if has_bias else None, True
)
atol = rtol = precision[ref.dtype] atol = rtol = precision[ref.dtype]
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
...@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase): ...@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase):
atol = rtol = precision[ref_out.dtype] atol = rtol = precision[ref_out.dtype]
Aq2, As2 = per_token_quant_int8_cpu(A) Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
out = int8_scaled_mm_cpu( out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
) )
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
# test the fused version # test the fused version
fused_out = int8_scaled_mm_with_quant( fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
) )
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
...@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase): ...@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase):
ref = torch.matmul(data.to(A_dtype), dq_weight.T) ref = torch.matmul(data.to(A_dtype), dq_weight.T)
if prepack: if prepack:
fp8_weight = convert_weight_packed(fp8_weight) fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
opt = fp8_scaled_mm_cpu( opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
data, data,
fp8_weight, fp8_weight,
scales, scales,
......
...@@ -2,12 +2,10 @@ import itertools ...@@ -2,12 +2,10 @@ import itertools
import math import math
import unittest import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch import torch
import torch.nn as nn import torch.nn as nn
# TODO: use interface in cpu.py
from sgl_kernel.common_ops import convert_weight_packed
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
from utils import ( from utils import (
BLOCK_K, BLOCK_K,
BLOCK_N, BLOCK_N,
...@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase): ...@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
fused_output.float(), fused_output.float(),
routed_scaling_factor, routed_scaling_factor,
).to(dtype=dtype) ).to(dtype=dtype)
res = shared_expert( res = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states, hidden_states,
w1, w1,
w2, w2,
...@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase): ...@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
fused_output.float(), fused_output.float(),
routed_scaling_factor, routed_scaling_factor,
).to(dtype=dtype) ).to(dtype=dtype)
res2 = shared_expert( res2 = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states2, hidden_states2,
w1_q, w1_q,
w2_q, w2_q,
...@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase): ...@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
ref_out = shared_out + fused_out.float() * routed_scaling_factor ref_out = shared_out + fused_out.float() * routed_scaling_factor
ref_out = ref_out.to(dtype=dtype) ref_out = ref_out.to(dtype=dtype)
w1 = convert_weight_packed(w1) # [2N, K] w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
w2 = convert_weight_packed(w2) # [K, N] w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
out = shared_expert( out = torch.ops.sgl_kernel.shared_expert_cpu(
a2, a2,
w1, w1,
w2, w2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment