Unverified Commit ffb1f7bf authored by wang jiahao's avatar wang jiahao Committed by GitHub
Browse files

Merge pull request #1210 from kvcache-ai/support-amx-qwen

Support amx and qwen3
parents ba92cf1a 8f76c37d
---
BasedOnStyle: LLVM
ColumnLimit: 120 # 设置最大行宽为 100
IndentWidth: 2
---
...@@ -22,6 +22,9 @@ interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified ...@@ -22,6 +22,9 @@ interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified
Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features. Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.
<h2 id="Updates">🔥 Updates</h2> <h2 id="Updates">🔥 Updates</h2>
* **Apr 29, 2025**: Support AMX-Int8 and AMX-BF16([Tutorial](./doc/en/AMX.md)). Support Qwen3MoE
https://github.com/user-attachments/assets/14992126-5203-4855-acf3-d250acead6b2
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)). * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)). * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)).
......
...@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int* locks // extra global storage for barrier synchronization int* locks // extra global storage for barrier synchronization
) { ) {
int prob_m = *prob_m_ptr; int prob_m = *prob_m_ptr;
prob_m = min(prob_m, 1024);
const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks); const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
if(prob_m > 16 * thread_m_blocks) if(prob_m > 16 * thread_m_blocks)
prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks)); prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
......
...@@ -53,6 +53,21 @@ else () ...@@ -53,6 +53,21 @@ else ()
set(CMAKE_GENERATOR_PLATFORM_LWR "") set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif () endif ()
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
find_package(Python3 REQUIRED COMPONENTS Interpreter)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c
"import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
OUTPUT_VARIABLE ABI_FLAG
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
if (NOT MSVC) if (NOT MSVC)
if (LLAMA_STATIC) if (LLAMA_STATIC)
add_link_options(-static) add_link_options(-static)
...@@ -115,6 +130,37 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW ...@@ -115,6 +130,37 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$")) CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
message(STATUS "x86 detected") message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(__HAS_AMX__ TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(__HAS_AMX__)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
if (MSVC) if (MSVC)
# instruction set detection for MSVC only # instruction set detection for MSVC only
if (LLAMA_NATIVE) if (LLAMA_NATIVE)
...@@ -294,8 +340,12 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3 ...@@ -294,8 +340,12 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
if (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
endif()
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5}) set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h") file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
......
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2025-04-25 18:28:12
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2025-04-25 18:28:12
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
expert_num = 8
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
n_routed_experts = 8
layer_num = 10
qlen = 1024
CPUInfer = cpuinfer_ext.CPUInfer(65)
warm_up_iter = 100
test_iter = 100
def bench_moe(quant_mode: str):
with torch.inference_mode(mode=True):
if quant_mode == "bf16":
bytes_per_elem = 2.000000
elif quant_mode == "int8":
bytes_per_elem = 1.000000
else:
assert(False)
moes = []
gate_projs = []
up_projs = []
down_projs = []
for _ in range(layer_num):
gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr())
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
CPUInfer.submit(moe.load_weights())
CPUInfer.sync()
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
CPUInfer.submit(moe.load_weights())
CPUInfer.sync()
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)
moes.append(moe)
expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
qlen_tensor = torch.tensor([qlen], dtype=torch.int32)
# warm up
for i in range(warm_up_iter):
CPUInfer.submit(
moes[i % layer_num].forward(
qlen,
n_routed_experts,
expert_ids[i % layer_num].data_ptr(),
weights[i % layer_num].data_ptr(),
input[i % layer_num].data_ptr(),
output[i % layer_num].data_ptr(),
qlen_tensor.data_ptr()
)
)
CPUInfer.sync()
# test
start = time.perf_counter()
for i in range(test_iter):
CPUInfer.submit(
moes[i % layer_num].forward(
qlen,
n_routed_experts,
expert_ids[i % layer_num].data_ptr(),
weights[i % layer_num].data_ptr(),
input[i % layer_num].data_ptr(),
output[i % layer_num].data_ptr(),
qlen_tensor.data_ptr()
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time = end - start
print('Quant mode: ', quant_mode)
print('Time(s): ', total_time)
print('Iteration: ', test_iter)
print('Time(us) per iteration: ', total_time / test_iter * 1000000)
print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS')
print('')
bench_moe("bf16")
bench_moe("int8")
...@@ -30,7 +30,8 @@ void SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t ...@@ -30,7 +30,8 @@ void SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t
if (buffer_) { if (buffer_) {
free(buffer_); free(buffer_);
} }
buffer_ = malloc(size); buffer_ = std::aligned_alloc(64, size);
size_ = size; size_ = size;
for (auto& obj_requests : hist_requests_) { for (auto& obj_requests : hist_requests_) {
for (auto& requests : obj_requests.second) { for (auto& requests : obj_requests.second) {
...@@ -52,4 +53,4 @@ void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) ...@@ -52,4 +53,4 @@ void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests)
*(request.first) = (uint8_t*)buffer_ + offset; *(request.first) = (uint8_t*)buffer_ + offset;
offset += request.second; offset += request.second;
} }
} }
\ No newline at end of file
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 06:36:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H
#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>
class SharedMemBuffer {
public:
SharedMemBuffer();
~SharedMemBuffer();
void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
void dealloc(void* object);
private:
void* buffer_;
uint64_t size_;
std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;
void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};
static SharedMemBuffer shared_mem_buffer;
#endif
\ No newline at end of file
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
#include "operators/llamafile/linear.h" #include "operators/llamafile/linear.h"
#include "operators/llamafile/mlp.h" #include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.h" #include "operators/llamafile/moe.h"
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
#include "operators/amx/moe.hpp"
#endif
#include "pybind11/functional.h" #include "pybind11/functional.h"
#include "pybind11/operators.h" #include "pybind11/operators.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -563,6 +568,78 @@ class MOEBindings { ...@@ -563,6 +568,78 @@ class MOEBindings {
}; };
}; };
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
template<class T>
class AMX_MOEBindings {
public:
class WarmUpBindings {
public:
struct Args {
CPUInfer *cpuinfer;
AMX_MOE<T> *moe;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&AMX_MOE<T>::warm_up, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {
Args *args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class LoadWeightsBindings {
public:
struct Args {
CPUInfer *cpuinfer;
AMX_MOE<T> *moe;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&AMX_MOE<T>::load_weights, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {
Args *args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
struct Args {
CPUInfer *cpuinfer;
AMX_MOE<T> *moe;
int qlen;
int k;
const uint64_t *expert_ids;
const float *weights;
const void *input;
void *output;
int *batch_size_tensor;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&AMX_MOE<T>::forward, args_->moe, args_->qlen, args_->k,
args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(AMX_MOE<T> &moe, int qlen, int k, intptr_t expert_ids,
intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {
Args *args = new Args{nullptr,
&moe,
qlen,
k,
(const uint64_t *)expert_ids,
(const float *)weights,
(const void *)input,
(void *)output,
(int *)batch_size_tensor};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
#endif
PYBIND11_MODULE(cpuinfer_ext, m) { PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<CPUInfer>(m, "CPUInfer") py::class_<CPUInfer>(m, "CPUInfer")
.def(py::init<int>()) .def(py::init<int>())
...@@ -621,6 +698,32 @@ PYBIND11_MODULE(cpuinfer_ext, m) { ...@@ -621,6 +698,32 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface) .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface); .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size,
int max_len, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj) {
return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size,
max_len, (void *)gate_proj,
(void *)up_proj, (void *)down_proj);
}));
py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE")
.def(py::init<AMX_MOEConfig>())
.def("warm_up", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)
.def("load_weights", &AMX_MOEBindings<amx::GemmKernel224BF>::LoadWeightsBindings::cpuinfer_interface)
.def("forward", &AMX_MOEBindings<amx::GemmKernel224BF>::ForwardBindings::cpuinfer_interface);
py::class_<AMX_MOE<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE")
.def(py::init<AMX_MOEConfig>())
.def("warm_up", &AMX_MOEBindings<amx::GemmKernel224Int8>::WarmUpBindings::cpuinfer_interface)
.def("load_weights", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)
.def("forward", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);
#endif
auto kvcache_module = m.def_submodule("kvcache"); auto kvcache_module = m.def_submodule("kvcache");
py::enum_<AnchorType>(kvcache_module, "AnchorType") py::enum_<AnchorType>(kvcache_module, "AnchorType")
......
This diff is collapsed.
/**
* @Description :
* @Author : chenht2022
* @Date : 2025-04-25 18:28:12
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2025-04-25 18:28:12
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
#include <cstdint>
template <typename T>
T* offset_pointer(T* ptr, std::size_t byte_offset) {
return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + byte_offset);
}
template <typename T>
const T* offset_pointer(const T* ptr, std::size_t byte_offset) {
return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + byte_offset);
}
template <typename T>
T* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) {
return offset_pointer(t, row * ld) + col;
}
template <typename T>
T* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) {
return offset_pointer(t, col * ld) + row;
}
static inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) {
_mm512_storeu_si512(dst, _mm512_loadu_si512(src));
}
static inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) {
_mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0)));
}
static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) {
_mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16)));
_mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16)));
}
\ No newline at end of file
/**
* @Description :
* @Author : chenht2022
* @Date : 2025-04-25 18:28:12
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2025-04-25 18:28:12
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_AMX_MOE_H
#define CPUINFER_OPERATOR_AMX_MOE_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "la/amx.hpp"
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
void *numa_alloc_aligned(size_t size, int node, size_t alignment) {
void *ptr = numa_alloc_onnode(size, node);
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
return ptr;
}
#endif
static inline __m512 exp_avx512(__m512 x) {
const __m512 log2e = _mm512_set1_ps(1.44269504089f);
const __m512 c1 = _mm512_set1_ps(0.69314718056f);
__m512 y = _mm512_mul_ps(x, log2e);
__m512i int_part = _mm512_cvtps_epi32(y);
__m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));
const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);
const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);
const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);
const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);
const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);
const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);
__m512 frac_exp = _mm512_fmadd_ps(
frac_part, poly_6,
_mm512_fmadd_ps(frac_part, poly_5,
_mm512_fmadd_ps(frac_part, poly_4,
_mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1)))));
__m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));
return _mm512_mul_ps(two_pow_i, frac_exp);
}
static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
__m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
__m512 exp_neg_gate = exp_avx512(neg_gate_val);
__m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);
__m512 act_val = _mm512_div_ps(gate_val, denom);
return _mm512_mul_ps(act_val, up_val);
}
struct AMX_MOEConfig {
int expert_num;
int routed_expert_num;
int hidden_size;
int intermediate_size;
int max_len;
void *gate_proj;
void *up_proj;
void *down_proj;
AMX_MOEConfig() {}
AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
void *gate_proj, void *up_proj, void *down_proj)
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),
intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),
down_proj(down_proj) {}
};
template <class T> class AMX_MOE {
private:
AMX_MOEConfig config_;
void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void *up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
ggml_bf16_t *m_local_input_; // [routed_expert_num * max_len * hidden_size]
ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]
ggml_bf16_t *m_local_up_output_; // [routed_expert_num * max_len * intermediate_size]
ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size]
std::vector<std::vector<int>> m_local_pos_; // [max_len, routed_expert_num]
std::vector<int> m_local_num_; // [expert_num]
std::vector<int> m_expert_id_map_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_input_ptr_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_up_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_down_output_ptr_; // [expert_num]
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef USE_NUMA
std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;
std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;
std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;
#else
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
#endif
public:
AMX_MOE(AMX_MOEConfig config) {
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
std::vector<std::pair<void **, uint64_t>> m_mem_requests;
m_mem_requests.push_back({(void **)&m_local_input_,
sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});
m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
config_.max_len * config_.intermediate_size});
m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
config_.max_len * config_.intermediate_size});
m_mem_requests.push_back({(void **)&m_local_down_output_,
sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});
std::vector<void *> gate_up_ba_ptr(config_.expert_num);
std::vector<void *> gate_bc_ptr(config_.expert_num);
std::vector<void *> up_bc_ptr(config_.expert_num);
std::vector<void *> down_ba_ptr(config_.expert_num);
std::vector<void *> down_bc_ptr(config_.expert_num);
for (int i = 0; i < config_.expert_num; i++) {
m_mem_requests.push_back(
{(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});
m_mem_requests.push_back(
{(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
m_mem_requests.push_back(
{(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
m_mem_requests.push_back(
{(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});
m_mem_requests.push_back(
{(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});
}
shared_mem_buffer.alloc(this, m_mem_requests);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.routed_expert_num);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (uint64_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));
gate_bc_.push_back(
std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));
down_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
gate_bb_numa_.resize(numa_nodes);
up_bb_numa_.resize(numa_nodes);
down_bb_numa_.resize(numa_nodes);
for (int j = 0; j < numa_nodes; j++) {
void *gate_bb_ptr =
numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
gate_bb_numa_[j].push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void *up_bb_ptr =
numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
up_bb_numa_[j].push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void *down_bb_ptr =
numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);
down_bb_numa_[j].push_back(
std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
}
#else
void *gate_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
gate_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void *up_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
up_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void *down_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
down_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
#endif
}
}
~AMX_MOE() { shared_mem_buffer.dealloc(this); }
void load_weights(Backend *backend) {
int nth = T::recommended_nth(config_.intermediate_size);
backend->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[&](int task_id) {
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int j = 0; j < numa_nodes; j++) {
gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +
expert_idx * config_.intermediate_size * config_.hidden_size,
ith, nth);
up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +
expert_idx * config_.intermediate_size * config_.hidden_size,
ith, nth);
}
#else
gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +
expert_idx * config_.intermediate_size * config_.hidden_size,
ith, nth);
up_bb_[expert_idx]->from_mat(
(ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);
#endif
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
backend->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[&](int task_id) {
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int j = 0; j < numa_nodes; j++) {
down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +
expert_idx * config_.hidden_size * config_.intermediate_size,
ith, nth);
}
#else
down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +
expert_idx * config_.hidden_size * config_.intermediate_size,
ith, nth);
#endif
},
nullptr);
}
void warm_up(Backend *backend) {}
void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,
int *batch_size_tensor, Backend *backend) {
bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);
qlen = batch_size_tensor[0];
int activated_expert = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}
uint64_t offset = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
backend->do_work_stealing_job(
qlen, nullptr,
[&](int i) {
for (int j = 0; j < k; j++) {
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
},
nullptr);
backend->do_work_stealing_job(
activated_expert, nullptr,
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
},
nullptr);
int nth = T::recommended_nth(config_.intermediate_size);
backend->do_work_stealing_job(
nth * activated_expert, [&](int _) { T::config(); },
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
#ifdef USE_NUMA
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],
ith, nth, use_amx);
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,
nth, use_amx);
#else
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);
#endif
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = act_fn(gate_val0, up_val0);
__m512 result1 = act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
}
}
},
nullptr);
backend->do_work_stealing_job(
activated_expert, nullptr,
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
backend->do_work_stealing_job(
nth * activated_expert, [&](int _) { T::config(); },
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
#ifdef USE_NUMA
amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],
down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#else
amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#endif
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
backend->do_work_stealing_job(
qlen, nullptr,
[&](int i) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));
}
},
nullptr);
}
};
#endif
\ No newline at end of file
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
#include <vector> #include <vector>
#include "../../cpu_backend/backend.h" #include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h" #include "conversion.h"
#include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h" #include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h" #include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct LinearConfig { struct LinearConfig {
int input_size; int input_size;
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
#include <vector> #include <vector>
#include "../../cpu_backend/backend.h" #include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h" #include "conversion.h"
#include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h" #include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h" #include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct MLPConfig { struct MLPConfig {
int hidden_size; int hidden_size;
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
#include <vector> #include <vector>
#include "../../cpu_backend/backend.h" #include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h" #include "conversion.h"
#include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h" #include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h" #include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct MOEConfig { struct MOEConfig {
int expert_num; int expert_num;
......
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 06:36:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H
#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>
class SharedMemBuffer {
public:
SharedMemBuffer();
~SharedMemBuffer();
void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
void dealloc(void* object);
private:
void* buffer_;
uint64_t size_;
std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;
void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};
static SharedMemBuffer shared_mem_buffer;
#endif
\ No newline at end of file
# Qwen 3 + KTransformers 0.3 (+AMX) = AI 工作站/PC
Following DeepSeek-V3/R1, LLaMa-4, and Kimi-VL, Qwen has also released an impressive MoE model—undoubtedly, this year belongs to MoE. As a low-barrier inference system for running MoE models in local heterogeneous environments, KTransformers naturally joins the party. Thanks to the support of the Qwen team, we completed Day 0 support for the entire Qwen 3 series of MoE models. At the same time, we took this opportunity to open-source the long-awaited preliminary version of our AMX high-performance operators (BF16, Int8; an int4 variant is coming soon), officially advancing to version 0.3.
What excites me most about Qwen3MoE is that, unlike the 671 B “giant” model, its two configurations—235 B-A22 and 30 B-A3B—hit the performance sweet spots for both local workstations and consumer-grade PCs. Accordingly, we ran benchmarks in two typical setups:
Server CPU (Xeon 4) + RTX 4090
Consumer-grade CPU (Core i9-14900KF + dual-channel DDR4-4000 MT/s) + RTX 4090
The results are as follows:
https://github.com/user-attachments/assets/14992126-5203-4855-acf3-d250acead6b2
Machine | Model | GPU Memory | RAM Usage | Prefill (tokens/s) | Decode (tokens/s)
Workstation (Xeon 4 + RTX 4090) | Qwen3-30B-A3B (8-bit) | 8.6 GB | 44 GB | 313 | 33 (single) → 50 (4-way)
Workstation (Xeon 4 + RTX 4090) | Qwen3-30B-A3B (4-bit) | 8.6 GB | 20 GB | 347.7 | 49.8 (single) → 98.8 (4-way)
Workstation (Xeon 4 + RTX 4090) | Qwen3-235B-A22B (4-bit) | 13 GB | 160 GB | 114.9 | 13.8 (single) → 24.4 (4-way)
Personal PC (Core i9-14900KF + RTX 4090) | Qwen3-30B-A3B (4-bit) | 8.6 GB | 20 GB | 240.0 | 12.0 (single) → 26.4 (4-way)
Personal PC (Core i9-14900KF + RTX 4090) | Qwen3-235B-A22B (4-bit) | 13 GB | 160 GB | 45 | 2.5 (single) → 6.0 (4-way)
You can see that, thanks to the AMX instruction optimizations, we achieve up to 347 tokens/s prefill performance in the workstation scenario. On consumer-grade CPUs, we’re able to run the large model (235B-A22) and deliver smooth performance on the smaller 30B-A3B. Even in terms of resource overhead, it appears that a high-end gaming laptop can handle 30B-A3B smoothly. After talking about the concept of AIPC for so long, we can finally see its feasibility.
To make it easier for everyone to understand the AMX optimizations we’ve open-sourced, we’ve prepared a brief document. We also extend our gratitude to Intel for their assistance.
# Introduction to AMX Instruction Set
Intel Advanced Matrix Extensions (AMX) are a set of specialized instruction extensions introduced for the x86 architecture starting with Sapphire Rapids (4th generation Xeon Scalable processors) and onward. AMX accelerates large-scale matrix computations at the hardware level, particularly for the compute-intensive parts of deep learning inference and machine learning workloads. By introducing the concept of Tile registers, it loads 2D sub-matrices into dedicated Tile registers and performs matrix multiply-accumulate operations at the register level, significantly improving throughput and energy efficiency.
Each CPU core contains 8 dedicated registers (tmm0–tmm7), with each register capable of holding up to 16 rows × 64 bytes of data to store 2D sub-matrices. Additionally, there is a 64-byte configuration register (TILECFG) used to describe each tmm register's number of rows, columns, and row stride.
The main AMX instructions are summarized as follows:
| Instruction Category | Instruction Names | Description |
|:---|:---|:---|
| Configuration Instructions | LDTILECFG, STTILECFG, TILERELEASE, TILEZERO | Configure/reset Tile registers and metadata |
| Load/Store Instructions | TILELOADD, TILELOADDT1, TILESTORED | Transfer data between memory and Tile registers |
| INT8 Computation Instructions | TDPBSSD, TDPBUSD, TDPBUUD, TDPBSUD | Perform multiply and accumulate operations on int8 sub-matrices within Tiles |
| BF16 Computation Instructions | TDPBF16PS | Perform multiply and accumulate operations on bfloat16 sub-matrices within Tiles |
To simplify development, Intel provides corresponding intrinsics, allowing C/C++ developers to leverage AMX's performance benefits without writing lengthy assembly code. For example:
```C++
#include <immintrin.h>
_tile_loadconfig(cfg_ptr);
_tile_loadd(tmm0, A_ptr, lda);
_tile_loadd(tmm1, B_ptr, ldb);
_tile_zero(tmm2)
_tile_dpbf16ps(tmm2, tmm0, tmm1);
_tile_stored(tmm2, C_ptr, ldc);
_tile_release();
```
The above code copies sub-matrices from memory (A_ptr, B_ptr) to Tile registers, calls the AMX BF16 compute instruction to multiply two sub-matrices, and then copies the result to memory (C_ptr).
Taking INT8 as an example, AMX can perform the multiplication of two 16×64 sub-matrices (32,768 multiply/add operations) with a single instruction in 16 CPU cycles, enabling each core to complete 2048 multiply/add operations per cycle — 8 times the performance of AVX-512. On an Intel Xeon 4 CPU, a single core can theoretically provide 4 TOPS of compute power, making it highly suitable for compute-intensive tasks on the CPU.
<p align="center">
<picture>
<img alt="amx_intro" src="../assets/amx_intro.png" width=60%>
</picture>
</p>
# AMX Kernel in KTransformers
Before version v0.3, KTransformers performed CPU matrix multiplications based on operators provided by llamafile. Unfortunately, llamafile's implementation had not yet been optimized for the AMX instruction set. This resulted in performance bottlenecks, even in strong hardware environments (such as Xeon 4th Gen + 4090), where inference speeds for large models like DeepSeek-V3 reached only 91 tokens/s during the prefill phase. The CPU thus remained a significant bottleneck. In long prompt scenarios, such performance is clearly unsatisfactory. To fully unleash CPU potential, we introduced a brand-new AMX optimization path along with multiple technical improvements in v0.3.
## 1. AMX Tiling-aware Memory Layout
AMX provides a high-throughput Tile register computation model, reducing instruction count and boosting theoretical throughput through coarse-grained matrix operations. However, to truly exploit AMX's potential, memory access efficiency is critical: because AMX transfers entire Tiles at once, misaligned Tiles and chaotic access patterns can cause severe cache misses, nullifying throughput gains.
Thus, in v0.3, we stopped directly memory-mapping GGUF-format files and introduced AMX Tiling-aware memory preprocessing during model loading. Specifically, expert weight matrices in MoE models are pre-rearranged into Tile-friendly sub-matrices whose shapes precisely match AMX Tile register dimensions, eliminating dynamic transposition overhead during inference. During rearrangement, we strictly align each sub-matrix's start address to 64 bytes to avoid cache line splits, and arrange sub-matrices sequentially according to computation access patterns, maximizing L1/L2 cache hit rates using compiler and hardware sequential prefetch capabilities.
For Int8 quantized formats, we adopted Symmetric Group-wise Quantization, with each column forming a group sharing a scale factor stored separately to maintain memory alignment for Tile data.
This AMX Tiling-aware memory layout design reduces memory latency while providing optimal input conditions for downstream computation kernels.
## 2. Cache-friendly AMX Kernel
During inference, we designed around the CPU’s multi-level cache hierarchy to perform computations in-place in high-speed caches, minimizing DRAM access frequency and overhead.
<p align="center">
<picture>
<img alt="amx" src="../assets/amx.png" width=60%>
</picture>
</p>
As shown in the figure,
- ① Expert weight matrices are first column-wise partitioned into multiple tasks dynamically scheduled across threads. Input activations are shared among tasks and typically reside in the shared L3 cache due to locality.
- ② Within each task, expert weights are row-wise partitioned into blocks, with block sizes finely tuned to ensure input activations, weights, and intermediate results stay within L2 cache, avoiding DRAM access.
- ③ ④ ⑤ Each block is treated as a set of sub-matrices matching AMX Tile registers, and during Tile-level computation, input Tiles (tmm0–tmm1) and expert Tiles (tmm2–tmm3) are loaded, and four AMX multiplication instructions directly generate and accumulate products into Tile registers (tmm4–tmm7), with output activations accumulated in Tile registers or L1 cache, avoiding additional data movement.
In short, we leveraged the cache hierarchy: every data element of expert weights and output activations accesses DRAM only once, with the other accesses hitting L2 or higher caches; input activations are accessed from DRAM only once and later hit in L3 or higher caches. This significantly reduces main memory traffic and improves overall execution efficiency.
## 3. AVX-512 Kernel Adaptation for Low Arithmetic Intensity Scenarios
Although AMX is highly efficient for large-scale matrix multiplication, it performs poorly under low arithmetic intensity, such as vector-matrix operations in the decode phase. This is because dispatching AMX Tiles involves fixed instruction overhead, which becomes wasteful when the data volume is insufficient to fill a Tile, causing reduced throughput.
<p align="center">
<picture>
<img alt="amx_avx" src="../assets/amx_avx.png" width=60%>
</picture>
</p>
To address this, we introduced a lightweight AVX-512 kernel as a complement. This kernel follows the same memory layout as the AMX kernel but replaces heavy AMX matrix-matrix multiplications with fine-grained AVX-512 vector-matrix multiplications, lowering latency for small matrices.
KTransformers dynamically selects between AMX and AVX-512 kernels at runtime based on arithmetic intensity: AMX kernels are automatically selected during long prompt prefill phases (where each expert handles more than 4 tokens on average), while short prompt prefill and decode phases dynamically switch to AVX-512 kernels. This ensures optimal efficiency under different arithmetic intensity conditions.
## 4. MoE Operator Fusion and Dynamic Scheduling
MoE models have many experts per layer, each requiring three matrix multiplications (Gate, Up, Down projections), leading to many small matrix multiplication tasks. Independently scheduling each small task would cause massive synchronization overhead between threads, dragging down overall inference speed.
Thus, we fused the same type of matrix computations for all experts in a layer into large unified tasks. Furthermore, as there are no data dependencies between Gate and Up projections, their computations can also be fused, ultimately consolidating a layer’s matrix multiplications into two major tasks, greatly reducing scheduling overhead.
To address load imbalance — especially during the prefill phase where expert activations can be highly skewed — we introduced a dynamic task scheduling strategy. Each matrix multiplication task is further split into multiple fine-grained sub-tasks, evenly distributed among CPU threads initially. Once a thread completes its assigned tasks, it atomically "steals" tasks from others, greatly mitigating load imbalance and achieving near-optimal CPU resource utilization.
Thanks to these optimizations, our kernel can achieve 21 TFLOPS of BF16 throughput and 35 TOPS of Int8 throughput on Xeon4 CPUs — about 4× faster than PyTorch’s general AMX kernel. For DeepSeek-V3, pairing a Xeon4 CPU with a single RTX 4090 GPU achieves 418 tokens/s end-to-end throughput, close to the performance of multi-machine, multi-GPU setups. KTransformers’ AMX kernel is the first AMX kernel specifically designed for MoE inference scenarios, significantly lowering the hardware barrier for large model deployment and enabling more developers to enjoy GPU cluster level inference experiences at lower cost.
<p align="center">
<picture>
<img alt="onednn_1" src="../assets/onednn_1.png" width=60%>
</picture>
</p>
# Usage
## Checking AMX Support
Before enabling the AMX-optimized kernels, it is important to verify whether your CPU supports the AMX instruction set. You can check AMX availability with the following command:
```bash
lscpu | grep -i amx
```
If your system supports AMX, you should see output similar to:
```bash
Flags: ... amx-bf16 amx-int8 amx-tile ...
```
If no amx-related flags are found, your CPU may not support AMX, or AMX may be disabled in BIOS settings. In that case, please ensure that:
- You are using a Sapphire Rapids (Xeon 4th Gen) or newer CPU.
- AMX support is enabled in your system BIOS under CPU feature settings.
## Enabling AMX in KTransformers
KTransformers allows users to easily switch between different backends through simple YAML configuration modifications. To enable AMX, modify the injection configuration of your experts by specifying backend as AMXInt8 or AMXBF16:
```YAML
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
```
**Note:** Currently, using AMXInt8 requires reading weights from a BF16 GGUF file and performing online quantization during model loading. This may cause slightly slower load times. Future versions will provide pre-quantized weights to eliminate this overhead.
![Image](https://github.com/user-attachments/assets/7c33c410-3af9-456f-aa67-5b24e19ba680)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment