Commit 74bb7fdc authored by qiyuxinlin's avatar qiyuxinlin
Browse files

Merge remote-tracking branch 'dev/support-amx-2'

parents ba92cf1a be4b27e8
---
BasedOnStyle: LLVM
ColumnLimit: 120 # 设置最大行宽为 100
IndentWidth: 2
---
......@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int* locks // extra global storage for barrier synchronization
) {
int prob_m = *prob_m_ptr;
prob_m = min(prob_m, 1024);
const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
if(prob_m > 16 * thread_m_blocks)
prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
......
......@@ -53,6 +53,21 @@ else ()
set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif ()
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
find_package(Python3 REQUIRED COMPONENTS Interpreter)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c
"import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
OUTPUT_VARIABLE ABI_FLAG
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
if (NOT MSVC)
if (LLAMA_STATIC)
add_link_options(-static)
......@@ -115,6 +130,37 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(HAS_AMX TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(HAS_AMX)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
if (MSVC)
# instruction set detection for MSVC only
if (LLAMA_NATIVE)
......@@ -294,8 +340,12 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
if (HOST_IS_X86 AND HAS_AVX512 AND HAS_AMX)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
endif()
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5})
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
......
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2025-04-25 18:28:12
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2025-04-25 18:28:12
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
expert_num = 8
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
n_routed_experts = 8
layer_num = 10
qlen = 1024
CPUInfer = cpuinfer_ext.CPUInfer(65)
warm_up_iter = 100
test_iter = 100
def bench_moe(quant_mode: str):
with torch.inference_mode(mode=True):
if quant_mode == "bf16":
bytes_per_elem = 2.000000
elif quant_mode == "int8":
bytes_per_elem = 1.000000
else:
assert(False)
moes = []
gate_projs = []
up_projs = []
down_projs = []
for _ in range(layer_num):
gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr())
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
CPUInfer.submit(moe.load_weights())
CPUInfer.sync()
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
CPUInfer.submit(moe.load_weights())
CPUInfer.sync()
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)
moes.append(moe)
expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
qlen_tensor = torch.tensor([qlen], dtype=torch.int32)
# warm up
for i in range(warm_up_iter):
CPUInfer.submit(
moes[i % layer_num].forward(
qlen,
n_routed_experts,
expert_ids[i % layer_num].data_ptr(),
weights[i % layer_num].data_ptr(),
input[i % layer_num].data_ptr(),
output[i % layer_num].data_ptr(),
qlen_tensor.data_ptr()
)
)
CPUInfer.sync()
# test
start = time.perf_counter()
for i in range(test_iter):
CPUInfer.submit(
moes[i % layer_num].forward(
qlen,
n_routed_experts,
expert_ids[i % layer_num].data_ptr(),
weights[i % layer_num].data_ptr(),
input[i % layer_num].data_ptr(),
output[i % layer_num].data_ptr(),
qlen_tensor.data_ptr()
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time = end - start
print('Quant mode: ', quant_mode)
print('Time(s): ', total_time)
print('Iteration: ', test_iter)
print('Time(us) per iteration: ', total_time / test_iter * 1000000)
print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS')
print('')
bench_moe("bf16")
bench_moe("int8")
......@@ -30,7 +30,8 @@ void SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t
if (buffer_) {
free(buffer_);
}
buffer_ = malloc(size);
buffer_ = std::aligned_alloc(64, size);
size_ = size;
for (auto& obj_requests : hist_requests_) {
for (auto& requests : obj_requests.second) {
......@@ -52,4 +53,4 @@ void SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests)
*(request.first) = (uint8_t*)buffer_ + offset;
offset += request.second;
}
}
}
\ No newline at end of file
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 06:36:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H
#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>
class SharedMemBuffer {
public:
SharedMemBuffer();
~SharedMemBuffer();
void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
void dealloc(void* object);
private:
void* buffer_;
uint64_t size_;
std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;
void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};
static SharedMemBuffer shared_mem_buffer;
#endif
\ No newline at end of file
......@@ -17,6 +17,11 @@
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.h"
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
#include "operators/amx/moe.hpp"
#endif
#include "pybind11/functional.h"
#include "pybind11/operators.h"
#include "pybind11/pybind11.h"
......@@ -563,6 +568,78 @@ class MOEBindings {
};
};
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
template<class T>
class AMX_MOEBindings {
public:
class WarmUpBindings {
public:
struct Args {
CPUInfer *cpuinfer;
AMX_MOE<T> *moe;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&AMX_MOE<T>::warm_up, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {
Args *args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class LoadWeightsBindings {
public:
struct Args {
CPUInfer *cpuinfer;
AMX_MOE<T> *moe;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(&AMX_MOE<T>::load_weights, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {
Args *args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
struct Args {
CPUInfer *cpuinfer;
AMX_MOE<T> *moe;
int qlen;
int k;
const uint64_t *expert_ids;
const float *weights;
const void *input;
void *output;
int *batch_size_tensor;
};
static void inner(void *args) {
Args *args_ = (Args *)args;
args_->cpuinfer->enqueue(
&AMX_MOE<T>::forward, args_->moe, args_->qlen, args_->k,
args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);
}
static std::pair<intptr_t, intptr_t>
cpuinfer_interface(AMX_MOE<T> &moe, int qlen, int k, intptr_t expert_ids,
intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {
Args *args = new Args{nullptr,
&moe,
qlen,
k,
(const uint64_t *)expert_ids,
(const float *)weights,
(const void *)input,
(void *)output,
(int *)batch_size_tensor};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
#endif
PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<CPUInfer>(m, "CPUInfer")
.def(py::init<int>())
......@@ -621,6 +698,32 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)
py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
int intermediate_size,
int max_len, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj) {
return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,
intermediate_size,
max_len, (void *)gate_proj,
(void *)up_proj, (void *)down_proj);
}));
py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE")
.def(py::init<AMX_MOEConfig>())
.def("warm_up", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)
.def("load_weights", &AMX_MOEBindings<amx::GemmKernel224BF>::LoadWeightsBindings::cpuinfer_interface)
.def("forward", &AMX_MOEBindings<amx::GemmKernel224BF>::ForwardBindings::cpuinfer_interface);
py::class_<AMX_MOE<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE")
.def(py::init<AMX_MOEConfig>())
.def("warm_up", &AMX_MOEBindings<amx::GemmKernel224Int8>::WarmUpBindings::cpuinfer_interface)
.def("load_weights", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)
.def("forward", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);
#endif
auto kvcache_module = m.def_submodule("kvcache");
py::enum_<AnchorType>(kvcache_module, "AnchorType")
......
This diff is collapsed.
/**
* @Description :
* @Author : chenht2022
* @Date : 2025-04-25 18:28:12
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2025-04-25 18:28:12
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
#include <cstdint>
template <typename T>
T* offset_pointer(T* ptr, std::size_t byte_offset) {
return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + byte_offset);
}
template <typename T>
const T* offset_pointer(const T* ptr, std::size_t byte_offset) {
return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + byte_offset);
}
template <typename T>
T* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) {
return offset_pointer(t, row * ld) + col;
}
template <typename T>
T* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) {
return offset_pointer(t, col * ld) + row;
}
static inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) {
_mm512_storeu_si512(dst, _mm512_loadu_si512(src));
}
static inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) {
_mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0)));
}
static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) {
_mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16)));
_mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16)));
}
\ No newline at end of file
/**
* @Description :
* @Author : chenht2022
* @Date : 2025-04-25 18:28:12
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2025-04-25 18:28:12
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_AMX_MOE_H
#define CPUINFER_OPERATOR_AMX_MOE_H
#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>
#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "la/amx.hpp"
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
void *numa_alloc_aligned(size_t size, int node, size_t alignment) {
void *ptr = numa_alloc_onnode(size, node);
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
return ptr;
}
#endif
static inline __m512 exp_avx512(__m512 x) {
const __m512 log2e = _mm512_set1_ps(1.44269504089f);
const __m512 c1 = _mm512_set1_ps(0.69314718056f);
__m512 y = _mm512_mul_ps(x, log2e);
__m512i int_part = _mm512_cvtps_epi32(y);
__m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));
const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);
const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);
const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);
const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);
const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);
const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);
__m512 frac_exp = _mm512_fmadd_ps(
frac_part, poly_6,
_mm512_fmadd_ps(frac_part, poly_5,
_mm512_fmadd_ps(frac_part, poly_4,
_mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1)))));
__m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));
return _mm512_mul_ps(two_pow_i, frac_exp);
}
static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
__m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
__m512 exp_neg_gate = exp_avx512(neg_gate_val);
__m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);
__m512 act_val = _mm512_div_ps(gate_val, denom);
return _mm512_mul_ps(act_val, up_val);
}
struct AMX_MOEConfig {
int expert_num;
int routed_expert_num;
int hidden_size;
int intermediate_size;
int max_len;
void *gate_proj;
void *up_proj;
void *down_proj;
AMX_MOEConfig() {}
AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
void *gate_proj, void *up_proj, void *down_proj)
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),
intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),
down_proj(down_proj) {}
};
template <class T> class AMX_MOE {
private:
AMX_MOEConfig config_;
void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void *up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
ggml_bf16_t *m_local_input_; // [routed_expert_num * max_len * hidden_size]
ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]
ggml_bf16_t *m_local_up_output_; // [routed_expert_num * max_len * intermediate_size]
ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size]
std::vector<std::vector<int>> m_local_pos_; // [max_len, routed_expert_num]
std::vector<int> m_local_num_; // [expert_num]
std::vector<int> m_expert_id_map_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_input_ptr_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_up_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t *> m_local_down_output_ptr_; // [expert_num]
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef USE_NUMA
std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;
std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;
std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;
#else
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
#endif
public:
AMX_MOE(AMX_MOEConfig config) {
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
std::vector<std::pair<void **, uint64_t>> m_mem_requests;
m_mem_requests.push_back({(void **)&m_local_input_,
sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});
m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
config_.max_len * config_.intermediate_size});
m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
config_.max_len * config_.intermediate_size});
m_mem_requests.push_back({(void **)&m_local_down_output_,
sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});
std::vector<void *> gate_up_ba_ptr(config_.expert_num);
std::vector<void *> gate_bc_ptr(config_.expert_num);
std::vector<void *> up_bc_ptr(config_.expert_num);
std::vector<void *> down_ba_ptr(config_.expert_num);
std::vector<void *> down_bc_ptr(config_.expert_num);
for (int i = 0; i < config_.expert_num; i++) {
m_mem_requests.push_back(
{(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});
m_mem_requests.push_back(
{(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
m_mem_requests.push_back(
{(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
m_mem_requests.push_back(
{(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});
m_mem_requests.push_back(
{(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});
}
shared_mem_buffer.alloc(this, m_mem_requests);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.routed_expert_num);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (uint64_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));
gate_bc_.push_back(
std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));
down_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
gate_bb_numa_.resize(numa_nodes);
up_bb_numa_.resize(numa_nodes);
down_bb_numa_.resize(numa_nodes);
for (int j = 0; j < numa_nodes; j++) {
void *gate_bb_ptr =
numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
gate_bb_numa_[j].push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void *up_bb_ptr =
numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
up_bb_numa_[j].push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void *down_bb_ptr =
numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);
down_bb_numa_[j].push_back(
std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
}
#else
void *gate_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
gate_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void *up_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
up_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void *down_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
down_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
#endif
}
}
~AMX_MOE() { shared_mem_buffer.dealloc(this); }
void load_weights(Backend *backend) {
int nth = T::recommended_nth(config_.intermediate_size);
backend->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[&](int task_id) {
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int j = 0; j < numa_nodes; j++) {
gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +
expert_idx * config_.intermediate_size * config_.hidden_size,
ith, nth);
up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +
expert_idx * config_.intermediate_size * config_.hidden_size,
ith, nth);
}
#else
gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +
expert_idx * config_.intermediate_size * config_.hidden_size,
ith, nth);
up_bb_[expert_idx]->from_mat(
(ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);
#endif
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
backend->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[&](int task_id) {
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int j = 0; j < numa_nodes; j++) {
down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +
expert_idx * config_.hidden_size * config_.intermediate_size,
ith, nth);
}
#else
down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +
expert_idx * config_.hidden_size * config_.intermediate_size,
ith, nth);
#endif
},
nullptr);
}
void warm_up(Backend *backend) {}
void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,
int *batch_size_tensor, Backend *backend) {
bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);
qlen = batch_size_tensor[0];
int activated_expert = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}
uint64_t offset = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
backend->do_work_stealing_job(
qlen, nullptr,
[&](int i) {
for (int j = 0; j < k; j++) {
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
},
nullptr);
backend->do_work_stealing_job(
activated_expert, nullptr,
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
},
nullptr);
int nth = T::recommended_nth(config_.intermediate_size);
backend->do_work_stealing_job(
nth * activated_expert, [&](int _) { T::config(); },
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
#ifdef USE_NUMA
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],
ith, nth, use_amx);
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,
nth, use_amx);
#else
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);
#endif
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = act_fn(gate_val0, up_val0);
__m512 result1 = act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
}
}
},
nullptr);
backend->do_work_stealing_job(
activated_expert, nullptr,
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
backend->do_work_stealing_job(
nth * activated_expert, [&](int _) { T::config(); },
[&](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
#ifdef USE_NUMA
amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],
down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#else
amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#endif
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
backend->do_work_stealing_job(
qlen, nullptr,
[&](int i) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));
}
},
nullptr);
}
};
#endif
\ No newline at end of file
......@@ -17,12 +17,12 @@
#include <vector>
#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct LinearConfig {
int input_size;
......
......@@ -17,12 +17,12 @@
#include <vector>
#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct MLPConfig {
int hidden_size;
......
......@@ -17,12 +17,12 @@
#include <vector>
#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct MOEConfig {
int expert_num;
......
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-08-05 04:49:08
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-05 06:36:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H
#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>
class SharedMemBuffer {
public:
SharedMemBuffer();
~SharedMemBuffer();
void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
void dealloc(void* object);
private:
void* buffer_;
uint64_t size_;
std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;
void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};
static SharedMemBuffer shared_mem_buffer;
#endif
\ No newline at end of file
# Introduction to AMX Instruction Set
Intel Advanced Matrix Extensions (AMX) are a set of specialized instruction extensions introduced for the x86 architecture starting with Sapphire Rapids (4th generation Xeon Scalable processors) and onward. AMX accelerates large-scale matrix computations at the hardware level, particularly for the compute-intensive parts of deep learning inference and machine learning workloads. By introducing the concept of Tile registers, it loads 2D sub-matrices into dedicated Tile registers and performs matrix multiply-accumulate operations at the register level, significantly improving throughput and energy efficiency.
Each CPU core contains 8 dedicated registers (tmm0–tmm7), with each register capable of holding up to 16 rows × 64 bytes of data to store 2D sub-matrices. Additionally, there is a 64-byte configuration register (TILECFG) used to describe each tmm register's number of rows, columns, and row stride.
The main AMX instructions are summarized as follows:
| Instruction Category | Instruction Names | Description |
|:---|:---|:---|
| Configuration Instructions | LDTILECFG, STTILECFG, TILERELEASE, TILEZERO | Configure/reset Tile registers and metadata |
| Load/Store Instructions | TILELOADD, TILELOADDT1, TILESTORED | Transfer data between memory and Tile registers |
| INT8 Computation Instructions | TDPBSSD, TDPBUSD, TDPBUUD, TDPBSUD | Perform multiply and accumulate operations on int8 sub-matrices within Tiles |
| BF16 Computation Instructions | TDPBF16PS | Perform multiply and accumulate operations on bfloat16 sub-matrices within Tiles |
To simplify development, Intel provides corresponding intrinsics, allowing C/C++ developers to leverage AMX's performance benefits without writing lengthy assembly code. For example:
```C++
#include <immintrin.h>
_tile_loadconfig(cfg_ptr);
_tile_loadd(tmm0, A_ptr, lda);
_tile_loadd(tmm1, B_ptr, ldb);
_tile_zero(tmm2)
_tile_dpbf16ps(tmm2, tmm0, tmm1);
_tile_stored(tmm2, C_ptr, ldc);
_tile_release();
```
The above code copies sub-matrices from memory (A_ptr, B_ptr) to Tile registers, calls the AMX BF16 compute instruction to multiply two sub-matrices, and then copies the result to memory (C_ptr).
Taking INT8 as an example, AMX can perform the multiplication of two 16×64 sub-matrices (32,768 multiply/add operations) with a single instruction in 16 CPU cycles, enabling each core to complete 2048 multiply/add operations per cycle — 8 times the performance of AVX-512. On an Intel Xeon 4 CPU, a single core can theoretically provide 4 TOPS of compute power, making it highly suitable for compute-intensive tasks on the CPU.
<p align="center">
<picture>
<img alt="amx_intro" src="../assets/amx_intro.png" width=60%>
</picture>
</p>
# AMX Kernel in KTransformers
Before version v0.3, KTransformers performed CPU matrix multiplications based on operators provided by llamafile. Unfortunately, llamafile's implementation had not yet been optimized for the AMX instruction set. This resulted in performance bottlenecks, even in strong hardware environments (such as Xeon 4th Gen + 4090), where inference speeds for large models like DeepSeek-V3 reached only 91 tokens/s during the prefill phase. The CPU thus remained a significant bottleneck. In long prompt scenarios, such performance is clearly unsatisfactory. To fully unleash CPU potential, we introduced a brand-new AMX optimization path along with multiple technical improvements in v0.3.
## 1. AMX Tiling-aware Memory Layout
AMX provides a high-throughput Tile register computation model, reducing instruction count and boosting theoretical throughput through coarse-grained matrix operations. However, to truly exploit AMX's potential, memory access efficiency is critical: because AMX transfers entire Tiles at once, misaligned Tiles and chaotic access patterns can cause severe cache misses, nullifying throughput gains.
Thus, in v0.3, we stopped directly memory-mapping GGUF-format files and introduced AMX Tiling-aware memory preprocessing during model loading. Specifically, expert weight matrices in MoE models are pre-rearranged into Tile-friendly sub-matrices whose shapes precisely match AMX Tile register dimensions, eliminating dynamic transposition overhead during inference. During rearrangement, we strictly align each sub-matrix's start address to 64 bytes to avoid cache line splits, and arrange sub-matrices sequentially according to computation access patterns, maximizing L1/L2 cache hit rates using compiler and hardware sequential prefetch capabilities.
For Int8 quantized formats, we adopted Symmetric Group-wise Quantization, with each column forming a group sharing a scale factor stored separately to maintain memory alignment for Tile data.
This AMX Tiling-aware memory layout design reduces memory latency while providing optimal input conditions for downstream computation kernels.
## 2. Cache-friendly AMX Kernel
During inference, we designed around the CPU’s multi-level cache hierarchy to perform computations in-place in high-speed caches, minimizing DRAM access frequency and overhead.
<p align="center">
<picture>
<img alt="amx" src="../assets/amx.png" width=60%>
</picture>
</p>
As shown in the figure,
- ① Expert weight matrices are first column-wise partitioned into multiple tasks dynamically scheduled across threads. Input activations are shared among tasks and typically reside in the shared L3 cache due to locality.
- ② Within each task, expert weights are row-wise partitioned into blocks, with block sizes finely tuned to ensure input activations, weights, and intermediate results stay within L2 cache, avoiding DRAM access.
- ③ ④ ⑤ Each block is treated as a set of sub-matrices matching AMX Tile registers, and during Tile-level computation, input Tiles (tmm0–tmm1) and expert Tiles (tmm2–tmm3) are loaded, and four AMX multiplication instructions directly generate and accumulate products into Tile registers (tmm4–tmm7), with output activations accumulated in Tile registers or L1 cache, avoiding additional data movement.
In short, we leveraged the cache hierarchy: every data element of expert weights and output activations accesses DRAM only once, with the other accesses hitting L2 or higher caches; input activations are accessed from DRAM only once and later hit in L3 or higher caches. This significantly reduces main memory traffic and improves overall execution efficiency.
## 3. AVX-512 Kernel Adaptation for Low Arithmetic Intensity Scenarios
Although AMX is highly efficient for large-scale matrix multiplication, it performs poorly under low arithmetic intensity, such as vector-matrix operations in the decode phase. This is because dispatching AMX Tiles involves fixed instruction overhead, which becomes wasteful when the data volume is insufficient to fill a Tile, causing reduced throughput.
<p align="center">
<picture>
<img alt="amx_avx" src="../assets/amx_avx.png" width=60%>
</picture>
</p>
To address this, we introduced a lightweight AVX-512 kernel as a complement. This kernel follows the same memory layout as the AMX kernel but replaces heavy AMX matrix-matrix multiplications with fine-grained AVX-512 vector-matrix multiplications, lowering latency for small matrices.
KTransformers dynamically selects between AMX and AVX-512 kernels at runtime based on arithmetic intensity: AMX kernels are automatically selected during long prompt prefill phases (where each expert handles more than 4 tokens on average), while short prompt prefill and decode phases dynamically switch to AVX-512 kernels. This ensures optimal efficiency under different arithmetic intensity conditions.
## 4. MoE Operator Fusion and Dynamic Scheduling
MoE models have many experts per layer, each requiring three matrix multiplications (Gate, Up, Down projections), leading to many small matrix multiplication tasks. Independently scheduling each small task would cause massive synchronization overhead between threads, dragging down overall inference speed.
Thus, we fused the same type of matrix computations for all experts in a layer into large unified tasks. Furthermore, as there are no data dependencies between Gate and Up projections, their computations can also be fused, ultimately consolidating a layer’s matrix multiplications into two major tasks, greatly reducing scheduling overhead.
To address load imbalance — especially during the prefill phase where expert activations can be highly skewed — we introduced a dynamic task scheduling strategy. Each matrix multiplication task is further split into multiple fine-grained sub-tasks, evenly distributed among CPU threads initially. Once a thread completes its assigned tasks, it atomically "steals" tasks from others, greatly mitigating load imbalance and achieving near-optimal CPU resource utilization.
Thanks to these optimizations, our kernel can achieve 21 TFLOPS of BF16 throughput and 35 TOPS of Int8 throughput on Xeon4 CPUs — about 4× faster than PyTorch’s general AMX kernel. For DeepSeek-V3, pairing a Xeon4 CPU with a single RTX 4090 GPU achieves 418 tokens/s end-to-end throughput, close to the performance of multi-machine, multi-GPU setups. KTransformers’ AMX kernel is the first AMX kernel specifically designed for MoE inference scenarios, significantly lowering the hardware barrier for large model deployment and enabling more developers to enjoy GPU cluster level inference experiences at lower cost.
<p align="center">
<picture>
<img alt="onednn_1" src="../assets/onednn_1.png" width=60%>
</picture>
</p>
# Usage
## Checking AMX Support
Before enabling the AMX-optimized kernels, it is important to verify whether your CPU supports the AMX instruction set. You can check AMX availability with the following command:
```bash
lscpu | grep -i amx
```
If your system supports AMX, you should see output similar to:
```bash
Flags: ... amx-bf16 amx-int8 amx-tile ...
```
If no amx-related flags are found, your CPU may not support AMX, or AMX may be disabled in BIOS settings. In that case, please ensure that:
- You are using a Sapphire Rapids (Xeon 4th Gen) or newer CPU.
- AMX support is enabled in your system BIOS under CPU feature settings.
## Enabling AMX in KTransformers
KTransformers allows users to easily switch between different backends through simple YAML configuration modifications. To enable AMX, modify the injection configuration of your experts by specifying backend as AMXInt8 or AMXBF16:
```YAML
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
```
**Note:** Currently, using AMXInt8 requires reading weights from a BF16 GGUF file and performing online quantization during model loading. This may cause slightly slower load times. Future versions will provide pre-quantized weights to eliminate this overhead.
\ No newline at end of file
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2MoE model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen2MoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a
Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B").
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 60):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `False`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
```python
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
>>> # Initializing a Qwen2MoE style configuration
>>> configuration = Qwen2MoeConfig()
>>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration
>>> model = Qwen2MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_moe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=5632,
num_hidden_layers=24,
num_attention_heads=16,
num_key_value_heads=16,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
decoder_sparse_step=1,
moe_intermediate_size=1408,
shared_expert_intermediate_size=5632,
num_experts_per_tok=4,
num_experts=60,
norm_topk_prob=False,
output_router_logits=False,
router_aux_loss_coef=0.001,
mlp_only_layers=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
# MoE arguments
self.decoder_sparse_step = decoder_sparse_step
self.moe_intermediate_size = moe_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment