Unverified Commit e3ab93c8 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU fused MOE (#30531)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent fc2ae6d6
......@@ -50,6 +50,7 @@ function cpu_tests() {
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py"
# Run basic model test
......
......@@ -330,7 +330,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
PUBLIC ${oneDNN_BINARY_DIR}/include
PRIVATE ${oneDNN_SOURCE_DIR}/src
)
target_link_libraries(dnnl_ext dnnl)
target_link_libraries(dnnl_ext dnnl torch)
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
list(APPEND LIBS dnnl_ext)
set(USE_ONEDNN ON)
......@@ -358,13 +358,13 @@ set(VLLM_EXT_SRC
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/scratchpad_manager.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC
......
#ifndef CPU_ATTN_MACROS_H
#define CPU_ATTN_MACROS_H
#ifndef CPU_ARCH_MACROS_H
#define CPU_ARCH_MACROS_H
// x86_64
#ifdef __x86_64__
......@@ -26,7 +26,7 @@
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
......@@ -98,7 +98,7 @@
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) \
auto fast_exp = [&](const vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
......@@ -110,4 +110,4 @@
#endif // __aarch64__
#endif
\ No newline at end of file
#endif
......@@ -8,10 +8,8 @@
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
#include "utils.hpp"
#include "cpu/cpu_arch_macros.h"
#include "cpu/utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON };
......@@ -378,12 +376,13 @@ class AttentionScheduler {
static constexpr int32_t MaxQTileIterNum = 128;
AttentionScheduler() : available_cache_size_(get_available_l2_size()) {}
AttentionScheduler()
: available_cache_size_(cpu_utils::get_available_l2_size()) {}
torch::Tensor schedule(const ScheduleInput& input) const {
const bool casual = input.casual;
const int32_t thread_num = omp_get_max_threads();
const int64_t cache_size = get_available_l2_size();
const int64_t cache_size = cpu_utils::get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment;
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
......@@ -659,7 +658,7 @@ class AttentionScheduler {
metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q);
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
scratchpad_size);
// metadata_ptr->print();
......@@ -667,7 +666,7 @@ class AttentionScheduler {
// test out of boundary access
// {
// float* cache_ptr =
// DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>();
// cpu_utils::ScratchPadManager::getl_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// }
......@@ -749,27 +748,6 @@ class AttentionScheduler {
return std::max(rounded_tile_size, round_size);
}
static int64_t get_available_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
private:
int64_t available_cache_size_;
};
......@@ -1402,7 +1380,7 @@ class AttentionMainLoop {
// init buffers
void* scratchpad_ptr =
DNNLScratchPadManager::get_dnnl_scratchpad_manager()
cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<void>();
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
......@@ -1422,8 +1400,7 @@ class AttentionMainLoop {
}
}
const int64_t available_cache_size =
AttentionScheduler::get_available_l2_size();
const int64_t available_cache_size = cpu_utils::get_available_l2_size();
const int32_t default_tile_size =
AttentionScheduler::calcu_default_tile_size(
available_cache_size, head_dim, sizeof(kv_cache_t),
......
This diff is collapsed.
......@@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(bool, void* ptr)
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
// strided load
explicit FP32Vec16(const float* ptr, INT32Vec16 idx)
: reg(_mm512_i32gather_ps(idx.reg, ptr, 4)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
// de-pack 4 bit values
......@@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
FP32Vec16 operator-() const {
return FP32Vec16(_mm512_xor_ps(reg, _mm512_set1_ps(-0.0f)));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
......
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "utils.hpp"
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
......@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const int32_t n_partition_size = [&]() {
const int64_t cache_size = cpu_utils::get_l2_size();
const int64_t cache_size = cpu_utils::get_available_l2_size();
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
int64_t ps_thread_limit = n_size / thread_num;
ps_cache_limit =
......@@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl(
const int64_t b_buffer_offset = 0;
const int64_t c_buffer_offset = b_buffer_size;
const int64_t buffer_size = b_buffer_size + c_buffer_size;
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
thread_num);
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size *
thread_num);
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
......@@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl(
scalar_t* __restrict__ b_buffer = nullptr;
float* __restrict__ c_buffer = nullptr;
{
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
->get_data<uint8_t>() +
thread_id * buffer_size;
uint8_t* buffer_ptr =
cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<uint8_t>() +
thread_id * buffer_size;
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
}
......
......@@ -4,8 +4,8 @@
#include "common/memory_desc.hpp"
#include "common/memory.hpp"
#include "dnnl_helper.h"
#include "scratchpad_manager.h"
#include "cpu/utils.hpp"
#include "cpu/dnnl_helper.h"
static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
......@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
......@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
......@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
......@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
......
......@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
}
}
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t);
TORCH_CHECK_EQ(output_size % 16, 0);
TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0);
const int32_t output_group_num = output_size / 16;
const int32_t input_32b_num = input_size / elem_num_per_group;
for (int32_t output_group_idx = 0; output_group_idx < output_group_num;
++output_group_idx) {
const int32_t* __restrict__ weight_32b =
reinterpret_cast<const int32_t*>(weight);
int32_t* __restrict__ packed_weight_32b =
reinterpret_cast<int32_t*>(packed_weight);
for (int32_t output_idx = 0; output_idx < 16; ++output_idx) {
for (int32_t weight_offset = 0, packed_offset = 0;
weight_offset < input_32b_num;
++weight_offset, packed_offset += 16) {
packed_weight_32b[packed_offset] = weight_32b[weight_offset];
}
// update
weight_32b += input_32b_num;
packed_weight_32b += 1;
}
// update
weight += 16 * input_size;
packed_weight += 16 * input_size;
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t curr_m_;
......
......@@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous
// blocks, means the logical shape of blocks is [16, input_size]. And the actual
// layout of blocks can be ISA-specific.
template <cpu_utils::ISA isa, typename scalar_t>
class MicroGemm {
public:
......@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
curr_d += ldd;
}
}
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr,
scalar_t* __restrict__ bias_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
constexpr int32_t n_group_num = n_size / 16;
static_assert(n_group_num <= 16);
vec_op::FP32Vec16 bias_vecs[n_group_num];
scalar_t* __restrict__ curr_bias = bias_ptr;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
scalar_vec_t vec(curr_bias);
bias_vecs[i] = vec_op::FP32Vec16(vec);
curr_bias += 16;
});
float* curr_c = c_ptr;
float* curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* curr_c_iter = curr_c;
float* curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
c_vec_fp32.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
} // namespace cpu_micro_gemm
#endif
......@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
// Note: pack contiguous weight [output_size, input_size] as contiguous
// packed weight [output_size / 16, input_size, 16]
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
TORCH_CHECK_EQ(output_size % 16, 0);
for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) {
const scalar_t* __restrict__ curr_weight = weight + o_idx * input_size;
scalar_t* __restrict__ curr_packed_weight =
packed_weight + (o_idx / 16) * (16 * input_size) + o_idx % 16;
for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) {
*curr_packed_weight = *curr_weight;
curr_packed_weight += 16;
++curr_weight;
}
}
}
};
} // namespace cpu_micro_gemm
......
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif
......@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
void prepack_moe_weight(const torch::Tensor& weight,
torch::Tensor& packed_weight, const std::string& isa);
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const torch::Tensor& w13, const torch::Tensor& w2,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
......@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
// fused moe
#if defined(__AVX512F__)
ops.def(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()");
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
......
......@@ -10,7 +10,7 @@
#define gettid() syscall(SYS_gettid)
#endif
#include "cpu_types.hpp"
#include "cpu/utils.hpp"
#ifdef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
......@@ -138,4 +138,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
return ss.str();
}
namespace cpu_utils {
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void ScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
static ScratchPadManager manager;
return &manager;
}
} // namespace cpu_utils
#endif
......@@ -2,19 +2,24 @@
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#include <ATen/cpu/Utils.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "cpu/cpu_types.hpp"
namespace cpu_utils {
enum class ISA { AMX, VEC };
inline ISA get_isa(const std::string& isa) {
if (isa == "amx") {
return ISA::AMX;
} else if (isa == "vec") {
return ISA::VEC;
} else {
TORCH_CHECK(false, "Invalid isa type: " + isa);
}
}
template <typename T>
struct VecTypeTrait {
using vec_t = void;
......@@ -48,26 +53,66 @@ struct Counter {
int64_t acquire_counter() { return counter++; }
};
inline int64_t get_l2_size() {
inline int64_t get_available_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
assert(l2_cache_size != -1);
const uint32_t l2_cache_size = at::cpu::L2_cache_size();
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
template <int32_t alignment_v, typename T>
inline T round_up(T size) {
T alignment = alignment_v;
return (((size + alignment - 1) / alignment) * alignment);
}
template <int32_t alignment_v, typename T>
inline T round_down(T size) {
T alignment = alignment_v;
return (size / alignment) * alignment;
}
template <typename T>
inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
class ScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static ScratchPadManager* get_scratchpad_manager();
ScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
} // namespace cpu_utils
#endif
......@@ -147,7 +147,9 @@ WORKDIR /workspace/vllm
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get install -y --no-install-recommends vim numactl xz-utils
apt-get install -y --no-install-recommends vim numactl xz-utils make clangd-14
RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
......
cmake>=3.26.1
ninja
packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools==77.0.3 # this version can reuse CMake build dir
setuptools-scm>=8
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
......
# Common dependencies
-r common.txt
setuptools==77.0.3 # this version can reuse CMake build dir
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
# Dependencies for CPUs
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.platforms import current_platform
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
EXPERT_NUM = [
8,
]
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = ["silu", "swigluoai"]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def ref_fused_moe(
input: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
) -> torch.Tensor:
len_experts = w13.size(0)
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = input[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx].float()
curr_w13 = w13[i].float()
curr_w2 = w2[i].float()
curr_w13_bias = None
if w13_bias is not None:
curr_w13_bias = w13_bias[i].float()
curr_w2_bias = None
if w2_bias is not None:
curr_w2_bias = w2_bias[i].float()
gate_up = torch.nn.functional.linear(
tokens_for_this_expert, curr_w13, curr_w13_bias
)
# Note: to simulate the kernel implementation
gate_up = (
_CPU_MOE_ACT[activation]
.forward_native(gate_up)
.to(dtype=input.dtype)
.float()
)
expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(input.dtype)
)
return final_out
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("expert_num", EXPERT_NUM)
@pytest.mark.parametrize("hidden_size", HIDDEN_DIM)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_DIM)
@pytest.mark.parametrize("use_bias", USE_BIAS)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("act", ACT)
@pytest.mark.parametrize("isa", ISA)
def test_cpu_fused_moe(
batch_size: int,
expert_num: int,
hidden_size: int,
intermediate_size: int,
use_bias: bool,
dtype: torch.dtype,
act: str,
isa: str,
):
current_platform.seed_everything(0)
topk_num = max(expert_num // 2, 1)
up_dim = 2 * intermediate_size
input = torch.randn((batch_size, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / (
0.5 * intermediate_size**0.5
)
router_logits = torch.randn((batch_size, expert_num), dtype=dtype)
w13_bias = None
w2_bias = None
if use_bias:
w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk_num)
topk_ids = topk_ids.to(torch.int32)
ref_output = ref_fused_moe(
input,
w13,
w2,
w13_bias,
w2_bias,
topk_weight,
topk_ids,
act,
)
packed_w13 = cpu_prepack_moe_weight(w13, isa)
packed_w2 = cpu_prepack_moe_weight(w2, isa)
output = cpu_fused_moe(
input,
packed_w13,
packed_w2,
w13_bias,
w2_bias,
topk_weight,
topk_ids,
act,
isa,
)
atol, rtol = get_default_atol(output), get_default_rtol(output)
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment