Commit 68c2b2e6 authored by djw

support qwen3

parent 0da3792b
@@ -53,6 +53,21 @@ else ()
set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif ()
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
find_package(Python3 REQUIRED COMPONENTS Interpreter)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c
"import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
OUTPUT_VARIABLE ABI_FLAG
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
if (NOT MSVC)
if (LLAMA_STATIC)
add_link_options(-static)
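The `_GLIBCXX_USE_CXX11_ABI` block added above asks the installed PyTorch which C++ ABI it was built with, so the extension links cleanly against libtorch. A minimal sketch of the same check from plain Python, using the `torch.compiled_with_cxx11_abi()` call that the CMake code invokes:

```python
# Prints the value exactly as the CMake block consumes it:
# "1" if PyTorch was built with the C++11 ABI, "0" otherwise.
import torch

print("1" if torch.compiled_with_cxx11_abi() else "0")
```

Passing `-D_GLIBCXX_USE_CXX11_ABI=0|1` on the CMake command line still takes precedence, since the detection only runs when the variable is not already defined.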
@@ -115,6 +130,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(HAS_AMX TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(HAS_AMX)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
if (MSVC)
# instruction set detection for MSVC only
if (LLAMA_NATIVE)
@@ -293,7 +340,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
-aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
+if (HOST_IS_X86 AND HAS_AVX512 AND HAS_AMX)
+    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
+endif()
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
......
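The detection added above simply searches the `lscpu` output for the substrings `avx512` and `amx`, then uses the results (`HAS_AVX512`, `HAS_AMX`, plus `HOST_IS_X86`) to decide whether the `operators/amx` sources are compiled in. A rough standalone Python sketch of the same probe, assuming a Linux host where `lscpu` is available (the same assumption the CMake code makes):

```python
# Mirror of the lscpu-based feature probe used in the CMake changes above:
# the CPU is treated as capable if the substring appears in the flag list.
import subprocess

lscpu_output = subprocess.run(["lscpu"], capture_output=True, text=True).stdout.lower()

has_avx512 = "avx512" in lscpu_output  # matches avx512f, avx512bw, ...
has_amx = "amx" in lscpu_output        # matches amx_tile, amx_bf16, amx_int8

print(f"AVX-512: {has_avx512}  AMX: {has_amx}")
```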
@@ -17,7 +17,11 @@
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.h"
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
#include "operators/amx/moe.hpp" #include "operators/amx/moe.hpp"
#endif
#include "pybind11/functional.h" #include "pybind11/functional.h"
#include "pybind11/operators.h" #include "pybind11/operators.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
@@ -564,6 +568,8 @@ class MOEBindings {
};
};
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
template<class T>
class AMX_MOEBindings {
public:
@@ -632,6 +638,7 @@ class AMX_MOEBindings {
}
};
};
#endif
PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<CPUInfer>(m, "CPUInfer")
@@ -691,6 +698,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size,
@@ -701,6 +710,7 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
max_len, (void *)gate_proj,
(void *)up_proj, (void *)down_proj);
}));
py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE")
.def(py::init<AMX_MOEConfig>())
.def("warm_up", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)
@@ -712,6 +722,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def("load_weights", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)
.def("forward", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);
#endif
auto kvcache_module = m.def_submodule("kvcache");
py::enum_<AnchorType>(kvcache_module, "AnchorType")
......
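Because the AMX classes are now registered only when the build defines `__x86_64__`, `__HAS_AVX512F__`, and `__HAS_AMX__`, Python callers should not assume they exist. A hedged sketch for probing a given `cpuinfer_ext` build; the submodule name `moe` is an assumption, since only the C++ variable `moe_module` is visible in this diff:

```python
# Check whether this cpuinfer_ext build includes the AMX MoE bindings.
# NOTE: the submodule name "moe" is assumed; the diff only shows the
# moe_module variable on the C++ side.
import cpuinfer_ext

moe = getattr(cpuinfer_ext, "moe", None)
has_amx_moe = moe is not None and hasattr(moe, "AMXBF16_MOE")

print("AMX MoE kernels available:", has_amx_moe)
```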
@@ -56,7 +56,7 @@
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default) backend: "AMXBF16" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
......
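The rule above now selects the `AMXBF16` expert backend (the inline comment lists `AMXInt8` and the default `llamafile` as alternatives), and modules are matched by regex over their dotted names. A small illustration of how the `self_attn` pattern shown just above matches, using typical (illustrative) module names:

```python
# Injection rules match modules by regex over their dotted names.
# The module names below are illustrative examples of the usual layout.
import re

pattern = re.compile(r"^model\.layers\..*\.self_attn$")

for name in ["model.layers.0.self_attn",
             "model.layers.27.self_attn",
             "model.layers.0.mlp.experts"]:
    print(f"{name}: {bool(pattern.match(name))}")
```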
@@ -146,7 +146,7 @@ async def main(concurrent_requests , prompt, max_tokens, model):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True) parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
......
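Dropping `required=True` means the stress tester can be launched without `--model`, falling back to the `DeepSeek-V3` default. A minimal sketch of the resulting argparse behaviour (only the relevant arguments are reproduced; the Qwen3 model name in the second call is illustrative):

```python
# Reduced copy of the tester's CLI showing that --model is now optional.
import argparse

parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")

print(parser.parse_args([]).model)                         # -> DeepSeek-V3
print(parser.parse_args(["--model", "Qwen3-235B"]).model)  # illustrative name
```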