Commit 68c2b2e6 authored by djw's avatar djw
Browse files

support qwen3

parent 0da3792b
......@@ -53,6 +53,21 @@ else ()
set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif ()
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
find_package(Python3 REQUIRED COMPONENTS Interpreter)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c
"import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
OUTPUT_VARIABLE ABI_FLAG
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
if (NOT MSVC)
if (LLAMA_STATIC)
add_link_options(-static)
......@@ -115,6 +130,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(HAS_AMX TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
set(HAS_AVX512 False)
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(HAS_AMX)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
if (MSVC)
# instruction set detection for MSVC only
if (LLAMA_NATIVE)
......@@ -293,7 +340,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
if (HOST_IS_X86 AND HAS_AVX512 AND HAS_AMX)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
endif()
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
......
......@@ -17,7 +17,11 @@
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.h"
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
#include "operators/amx/moe.hpp"
#endif
#include "pybind11/functional.h"
#include "pybind11/operators.h"
#include "pybind11/pybind11.h"
......@@ -564,6 +568,8 @@ class MOEBindings {
};
};
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
template<class T>
class AMX_MOEBindings {
public:
......@@ -632,6 +638,7 @@ class AMX_MOEBindings {
}
};
};
#endif
PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<CPUInfer>(m, "CPUInfer")
......@@ -691,6 +698,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size,
......@@ -701,6 +710,7 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
max_len, (void *)gate_proj,
(void *)up_proj, (void *)down_proj);
}));
py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE")
.def(py::init<AMX_MOEConfig>())
.def("warm_up", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)
......@@ -712,6 +722,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def("load_weights", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)
.def("forward", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);
#endif
auto kvcache_module = m.def_submodule("kvcache");
py::enum_<AnchorType>(kvcache_module, "AnchorType")
......
......@@ -56,7 +56,7 @@
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
backend: "AMXBF16" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
......
......@@ -146,7 +146,7 @@ async def main(concurrent_requests , prompt, max_tokens, model):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment