OpenDAS / ktransformers · Commits

Commit 68c2b2e6
Authored Apr 28, 2025 by djw
Parent: 0da3792b

support qwen3
Showing 4 changed files with 65 additions and 3 deletions:
    csrc/ktransformers_ext/CMakeLists.txt                        +51  -1
    csrc/ktransformers_ext/ext_bindings.cpp                      +12  -0
    ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml     +1  -1
    ktransformers/tests/test_speed.py                             +1  -1
csrc/ktransformers_ext/CMakeLists.txt
@@ -53,6 +53,21 @@ else ()
     set(CMAKE_GENERATOR_PLATFORM_LWR "")
 endif()
+if (NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
+    find_package(Python3 REQUIRED COMPONENTS Interpreter)
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
+        OUTPUT_VARIABLE ABI_FLAG
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
+endif()
+add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
 if (NOT MSVC)
     if (LLAMA_STATIC)
         add_link_options(-static)
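For context, the new block asks the active PyTorch installation which C++ ABI it was built with, so the extension is compiled with a matching _GLIBCXX_USE_CXX11_ABI value. A minimal stand-alone Python equivalent of the probe that CMake runs through ${Python3_EXECUTABLE} (illustrative sketch, not part of this commit):

# Sketch of the ABI probe the CMake block above runs via the Python interpreter.
# torch.compiled_with_cxx11_abi() reports whether libtorch was built with the
# C++11 libstdc++ ABI; CMake caches the printed '1' or '0' as the ABI setting.
import torch

abi_flag = "1" if torch.compiled_with_cxx11_abi() else "0"
print(abi_flag)  # consumed by CMake as _GLIBCXX_USE_CXX11_ABI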
@@ -115,6 +130,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
     (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
         CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
     message(STATUS "x86 detected")
+    set(HOST_IS_X86 TRUE)
+    set(HAS_AVX512 TRUE)
+    set(HAS_AMX TRUE)
+    add_compile_definitions(__x86_64__)
+    # check AVX512
+    execute_process(
+        COMMAND lscpu
+        OUTPUT_VARIABLE LSCPU_OUTPUT
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    # message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
+    string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
+    if (COMPILER_SUPPORTS_AVX512F GREATER -1)
+        message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
+        add_compile_definitions(__HAS_AVX512F__)
+    else ()
+        message(STATUS "Compiler and/or CPU do NOT support AVX512F")
+        set(HAS_AVX512 False)
+    endif()
+    set(HAS_AVX512 False)
+    # check AMX
+    string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
+    if (COMPILER_SUPPORTS_AMX GREATER -1)
+        message(STATUS "Compiler supports AMX")
+        add_compile_definitions(HAS_AMX)
+    else ()
+        message(STATUS "Compiler does NOT support AMX")
+    endif()
 if (MSVC)
     # instruction set detection for MSVC only
     if (LLAMA_NATIVE)
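The detection above is a plain substring search over lscpu output: if it contains "avx512" the build defines __HAS_AVX512F__, and if it contains "amx" it defines HAS_AMX; the HAS_AVX512 and HAS_AMX CMake variables also feed the HOST_IS_X86 AND HAS_AVX512 AND HAS_AMX condition that gates the operators/amx sources later in this file. A rough Python analogue of the same check (illustrative only; it assumes a Linux host where lscpu is on PATH, the same assumption the CMake code makes):

# Rough Python analogue of the lscpu-based feature probe in the hunk above.
import subprocess

lscpu_output = subprocess.run(
    ["lscpu"], capture_output=True, text=True, check=True
).stdout.lower()

has_avx512 = "avx512" in lscpu_output   # CMake: string(FIND ... "avx512" ...)
has_amx = "amx" in lscpu_output         # CMake: string(FIND ... "amx" ...)
print(f"avx512={has_avx512} amx={has_amx}")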
@@ -293,7 +340,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
-aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
+if (HOST_IS_X86 AND HAS_AVX512 AND HAS_AMX)
+    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
+endif()
 set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
csrc/ktransformers_ext/ext_bindings.cpp
@@ -17,7 +17,11 @@
 #include "operators/llamafile/linear.h"
 #include "operators/llamafile/mlp.h"
 #include "operators/llamafile/moe.h"
+#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
 #include "operators/amx/moe.hpp"
+#endif
 #include "pybind11/functional.h"
 #include "pybind11/operators.h"
 #include "pybind11/pybind11.h"
@@ -564,6 +568,8 @@ class MOEBindings {
     };
 };
+#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
 template <class T>
 class AMX_MOEBindings {
   public:
@@ -632,6 +638,7 @@ class AMX_MOEBindings {
     }
   };
 };
+#endif
 PYBIND11_MODULE(cpuinfer_ext, m) {
     py::class_<CPUInfer>(m, "CPUInfer")
@@ -691,6 +698,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
         .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
+#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
     py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
         .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
                          int intermediate_size,
@@ -701,6 +710,7 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
                          max_len, (void *)gate_proj,
                          (void *)up_proj, (void *)down_proj);
         }));
     py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE")
         .def(py::init<AMX_MOEConfig>())
         .def("warm_up", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)
@@ -712,6 +722,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
         .def("load_weights", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)
         .def("forward", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);
+#endif
     auto kvcache_module = m.def_submodule("kvcache");
     py::enum_<AnchorType>(kvcache_module, "AnchorType")
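Because the AMX classes are now registered only when the module was built with __x86_64__, __HAS_AVX512F__ and __HAS_AMX__ defined, Python code cannot assume AMX_MOEConfig or AMXBF16_MOE exist in every cpuinfer_ext build. A defensive probe sketch (illustrative only; the class names follow the registrations visible in this diff, while the "moe" submodule lookup and the MOE fallback name are assumptions about the rest of the module, not code from this commit):

# Probe for the conditionally-registered AMX bindings before using them.
import cpuinfer_ext

moe_module = getattr(cpuinfer_ext, "moe", cpuinfer_ext)  # assumed submodule name
if hasattr(moe_module, "AMXBF16_MOE"):
    # Built with AVX512F + AMX: the AMX kernels were compiled in and registered.
    moe_cls = moe_module.AMXBF16_MOE
else:
    # Non-AMX build: fall back to the always-registered llamafile operator
    # (fallback class name assumed for illustration).
    moe_cls = moe_module.MOE
print(moe_cls)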
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
@@ -56,7 +56,7 @@
       generate_device: "cpu"
       generate_op: "KExpertsCPU"
       out_device: "cuda"
-      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
+      backend: "AMXBF16" # or "AMXBF16" or "llamafile" (default)
       recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
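This rule entry switches the serving backend for the CPU experts from AMXInt8 to AMXBF16. A small sketch of how such a rules file can be inspected from Python (illustrative only; the replace/kwargs nesting is assumed from the usual layout of ktransformers optimize rules, since the full file is not shown in this diff):

# Print the expert backend selected by an optimize-rules YAML file.
# The rule layout (a list of {match, replace: {kwargs: ...}} entries) is an
# assumption for illustration; only the "backend" key appears in this diff.
import yaml

with open("ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml") as f:
    rules = yaml.safe_load(f)

for rule in rules:
    kwargs = (rule.get("replace") or {}).get("kwargs") or {}
    if "backend" in kwargs:
        print(rule.get("match"), "->", kwargs["backend"])  # e.g. AMXBF16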
ktransformers/tests/test_speed.py
@@ -146,7 +146,7 @@ async def main(concurrent_requests , prompt, max_tokens, model):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
-    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
+    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
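This one-line change drops required=True from --model, so the existing default can actually take effect; previously argparse aborted whenever --model was omitted, which made the default value unreachable. A minimal stand-alone repro of the new behaviour (illustrative, not part of the test script):

# With the argument no longer required, omitting --model falls back to the default.
import argparse

parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
args = parser.parse_args([])        # simulate running with no --model flag
print(args.model)                   # -> DeepSeek-V3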