Unverified Commit 7f829be7 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU attention backend (#27954)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent e1710393
...@@ -132,7 +132,7 @@ steps: ...@@ -132,7 +132,7 @@ steps:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env: env:
......
...@@ -49,6 +49,7 @@ function cpu_tests() { ...@@ -49,6 +49,7 @@ function cpu_tests() {
# Run kernel tests # Run kernel tests
docker exec cpu-test-"$NUMA_NODE" bash -c " docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e set -e
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/test_onednn.py" pytest -x -v -s tests/kernels/test_onednn.py"
# Run basic model test # Run basic model test
...@@ -116,4 +117,4 @@ function cpu_tests() { ...@@ -116,4 +117,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 40 mins. # All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests export -f cpu_tests
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
...@@ -15,6 +15,7 @@ endif() ...@@ -15,6 +15,7 @@ endif()
# #
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc") include_directories("${CMAKE_SOURCE_DIR}/csrc")
...@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ...@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(ENABLE_AVX512VNNI OFF) set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif() endif()
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
set(ENABLE_AMXBF16 ON)
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND) elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2") list(APPEND CXX_COMPILE_FLAGS "-mavx2")
...@@ -275,7 +292,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -275,7 +292,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_VERBOSE "OFF") set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE})
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
target_include_directories( target_include_directories(
dnnl_ext dnnl_ext
...@@ -305,14 +325,14 @@ endif() ...@@ -305,14 +325,14 @@ endif()
# #
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp" "csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp" "csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp" "csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp" "csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp") "csrc/cpu/cpu_attn.cpp"
"csrc/cpu/scratchpad_manager.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
......
This diff is collapsed.
#include <map>
#include <vector>
#include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
}
}
}
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx >= 0) {
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx =
token_idx * value_stride + head_idx * head_size;
const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
target_key_head_ptr[target_offset + i] =
src_key_head_ptr[src_key_idx + i];
}
}
for (int src_value_idx = 0; src_value_idx < head_size;
++src_value_idx) {
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
src_value_head_ptr[src_value_idx];
}
}
}
}
}
}; // namespace
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_tokens, //
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size //
) {
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
continue;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src,
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
int size, int offset) {
for (int i = 0; i < size; i++) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
dst[dst_idx] = src[src_idx];
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
}
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
const int element_num_per_block = key_caches[0][0].numel();
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
}
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
TORCH_CHECK(kv_cache_dtype != "fp8");
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
VLLM_DISPATCH_FLOATING_TYPES(
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
concat_and_cache_mla_cpu_impl<scalar_t>(
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
kv_lora_rank, pe_dim, block_size);
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
});
}
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}
#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_attention::ISA::AMX: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
return __VA_ARGS__(); \
}
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
[&] { \
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
default: { \
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
std::to_string(HEAD_DIM)); \
} \
} \
}()
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
case cpu_attention::ISA::VEC16: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
} \
} \
}()
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split) {
cpu_attention::ISA isa;
if (isa_hint == "amx") {
isa = cpu_attention::ISA::AMX;
} else if (isa_hint == "vec") {
isa = cpu_attention::ISA::VEC;
} else if (isa_hint == "vec16") {
isa = cpu_attention::ISA::VEC16;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
cpu_attention::AttentionScheduler::ScheduleInput input;
input.num_reqs = num_req;
input.num_heads_q = num_heads_q;
input.num_heads_kv = num_heads_kv;
input.head_dim = head_dim;
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
if (window_size != -1) {
input.left_sliding_window_size = window_size - 1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = window_size - 1;
}
} else {
input.left_sliding_window_size = -1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = -1;
}
}
input.casual = casual;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
TORCH_CHECK(casual, "Only supports casual mask for now.");
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
input.output_buffer_elem_size =
sizeof(attn_impl::partial_output_buffer_t);
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
});
});
});
cpu_attention::AttentionScheduler scheduler;
torch::Tensor metadata = scheduler.schedule(input);
return metadata;
}
void cpu_attn_reshape_and_cache(
const torch::Tensor& key, // [token_num, head_num, head_size]
const torch::Tensor& value, // [token_num, head_num, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor& slot_mapping, const std::string& isa) {
TORCH_CHECK_EQ(key.dim(), 3);
TORCH_CHECK_EQ(value.dim(), 3);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
TORCH_CHECK_EQ(key.stride(2), 1);
TORCH_CHECK_EQ(value.stride(2), 1);
const int64_t token_num = key.size(0);
const int64_t key_token_num_stride = key.stride(0);
const int64_t value_token_num_stride = value.stride(0);
const int64_t head_num = value.size(1);
const int64_t key_head_num_stride = key.stride(1);
const int64_t value_head_num_stride = value.stride(1);
const int64_t num_blocks = key_cache.size(0);
const int64_t num_blocks_stride = key_cache.stride(0);
const int64_t cache_head_num_stride = key_cache.stride(1);
const int64_t block_size = key_cache.size(2);
const int64_t block_size_stride = key_cache.stride(2);
const int64_t head_dim = key.size(-1);
cpu_attention::ISA isa_tag = [&]() {
if (isa == "amx") {
return cpu_attention::ISA::AMX;
} else if (isa == "vec") {
return cpu_attention::ISA::VEC;
} else if (isa == "vec16") {
return cpu_attention::ISA::VEC16;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}
}();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num,
key_token_num_stride, value_token_num_stride, head_num,
key_head_num_stride, value_head_num_stride, num_blocks,
num_blocks_stride, cache_head_num_stride, block_size,
block_size_stride);
});
});
});
}
void cpu_attention_with_kv_cache(
const torch::Tensor& query, // [num_tokens, num_heads, head_size]
const torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor& output, // [num_tokens, num_heads, head_size]
const torch::Tensor& query_start_loc, // [num_tokens + 1]
const torch::Tensor& seq_lens, // [num_tokens]
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes, // [num_heads]
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, // [num_tokens, max_block_num]
const double softcap, const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux // [num_heads]
) {
TORCH_CHECK_EQ(query.dim(), 3);
TORCH_CHECK_EQ(query.stride(2), 1);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
cpu_attention::AttentionInput input;
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
scheduler_metadata.data_ptr());
input.num_tokens = query.size(0);
input.num_heads = query.size(1);
input.num_kv_heads = key_cache.size(1);
input.block_size = key_cache.size(2);
input.query = query.data_ptr();
input.query_num_tokens_stride = query.stride(0);
input.query_num_heads_stride = query.stride(1);
input.cache_num_blocks_stride = key_cache.stride(0);
input.cache_num_kv_heads_stride = key_cache.stride(1);
input.blt_num_tokens_stride = block_table.stride(0);
input.key_cache = key_cache.data_ptr();
input.value_cache = value_cache.data_ptr();
input.output = output.data_ptr();
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
input.block_table = block_table.data_ptr<int32_t>();
input.alibi_slopes =
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
// For now sink must be bf16
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
input.scale = scale;
input.causal = causal;
input.sliding_window_left = sliding_window_left;
input.sliding_window_right = sliding_window_right;
if (input.causal) {
// to make boundary calculation easier
input.sliding_window_right = 0;
}
float softcap_fp32 = softcap;
input.softcap = softcap_fp32;
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
});
});
}
This diff is collapsed.
This diff is collapsed.
#ifndef CPU_ATTN_MACROS_H
#define CPU_ATTN_MACROS_H
// x86_64
#ifdef __x86_64__
#define FAST_SPINNING _mm_pause();
#ifdef __AVX512F__
#define DEFINE_FAST_EXP \
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \
const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \
const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \
const __m512 vec_exp_log2ef = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \
const __m512 vec_half = _mm512_set1_ps(0.5f); \
const __m512 vec_one = _mm512_set1_ps(1.f); \
const __m512 vec_zero = _mm512_set1_ps(0.f); \
const __m512 vec_two = _mm512_set1_ps(2.f); \
const __m512 vec_ln2f = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \
const __m512 vec_ln_flt_min = \
_mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \
const __m512 vec_ln_flt_max = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \
auto vec_fx_i = _mm512_cvt_roundps_epi32( \
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \
vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \
auto vec_res = \
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \
vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \
vec_two_pow_n, vec_zero); \
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \
vec_res = _mm512_mul_ps(vec_res, vec_two); \
vec_op::FP32Vec16 res(vec_res); \
return res; \
};
#endif
#endif
#endif
\ No newline at end of file
#ifndef CPU_ATTN_VEC_HPP
#define CPU_ATTN_VEC_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
template <typename kv_cache_t>
class TileGemm82 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M <= 8);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
float* __restrict__ curr_c_0 = c_tile;
float* __restrict__ curr_c_1 = c_tile + 16;
vec_op::FP32Vec16 c_regs[M * 2];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
float* __restrict__ curr_m_c_1 = curr_c_1;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
// update
curr_m_c_0 += ldc;
curr_m_c_1 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
load_vec_t b_1_reg(curr_b_1);
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
curr_b_1 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2].save(curr_c_0);
c_regs[i * 2 + 1].save(curr_c_1);
// update
curr_c_0 += ldc;
curr_c_1 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
32; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
32; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 8;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm82<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
// Copy q to q_buffer and cast it to fp32
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
float* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
static_assert(head_dim % 16 == 0);
constexpr int32_t unroll_size = head_dim / 16;
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
scalar_t* __restrict__ curr_q =
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
float* __restrict__ curr_q_buffer =
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
q_head_idx * head_dim;
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
load_vec_t vec(curr_q);
vec_op::FP32Vec16 fp32_vec(vec);
fp32_vec = fp32_vec * scale_vec;
fp32_vec.save(curr_q_buffer);
curr_q += 16;
curr_q_buffer += 16;
});
}
}
}
// reshape K as column-major and V as row-major
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key as column-major
const scalar_t* key_start_ptr = key +
token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_cache_start_ptr =
key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_cache_start_ptr[j] = key_start_ptr[i];
}
}
{
// Write Value as row-major
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset * head_dim;
std::memcpy(value_cache_start_ptr, value_start_ptr,
sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#endif
#ifndef CPU_ATTN_VEC16_HPP
#define CPU_ATTN_VEC16_HPP
#include "cpu_attn_vec.hpp"
namespace cpu_attention {
namespace {
// 16-1-16 pattern, 16 regs for A, 1 regs for B, 16 regs for C, [16, K] @ [k,
// 16]
template <typename kv_cache_t>
class TileGemm161 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 9:
case 10:
case 11:
case 12:
gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 13:
case 14:
case 15:
case 16:
gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M <= 16);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
float* __restrict__ curr_c_0 = c_tile;
vec_op::FP32Vec16 c_regs[M];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
// update
curr_m_c_0 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i].save(curr_c_0);
// update
curr_c_0 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
16; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
16; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 16;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC16;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm161<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
};
} // namespace cpu_attention
#endif
...@@ -40,6 +40,23 @@ namespace vec_op { ...@@ -40,6 +40,23 @@ namespace vec_op {
#define FORCE_INLINE __attribute__((always_inline)) inline #define FORCE_INLINE __attribute__((always_inline)) inline
// Function to get the timestamp using RDTSCP
FORCE_INLINE uint64_t bench_timestamp() {
unsigned int cycles_low, cycles_high;
asm volatile(
".intel_syntax noprefix\n\t"
"CPUID\n\t" // Serialize instruction stream to ensure previous
// instructions complete
"RDTSCP\n\t" // Read TSC and core ID
"mov %0, edx\n\t" // Store high 32 bits of TSC
"mov %1, eax\n\t" // Store low 32 bits of TSC
".att_syntax"
: "=r"(cycles_high), "=r"(cycles_low)::"rax", "rbx", "rcx",
"rdx" // Clobbered registers
);
return (uint64_t)cycles_high << 32 | cycles_low;
}
namespace { namespace {
template <typename T, T... indexes, typename F> template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) { constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
...@@ -407,6 +424,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -407,6 +424,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float reduce_min() const { return _mm512_reduce_min_ps(reg); } float reduce_min() const { return _mm512_reduce_min_ps(reg); }
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
template <int group_size> template <int group_size>
float reduce_sub_sum(int idx) { float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0); static_assert(VEC_ELEM_NUM % group_size == 0);
...@@ -446,9 +465,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -446,9 +465,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
explicit FP32Vec16(const FP32Vec16& data)
: reg_low(data.reg_low), reg_high(data.reg_high) {}
explicit FP32Vec16(const FP32Vec4& data) explicit FP32Vec16(const FP32Vec4& data)
: reg_low((__m256)_mm256_inserti128_si256( : reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)), _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
...@@ -504,6 +520,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -504,6 +520,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
_mm256_div_ps(reg_high, b.reg_high)); _mm256_div_ps(reg_high, b.reg_high));
} }
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(_mm256_max_ps(reg_low, b.reg_low),
_mm256_max_ps(reg_high, b.reg_high));
}
float reduce_max() const {
__m256 v = _mm256_max_ps(reg_low, reg_high);
// Permute to compare elements within 128-bit lanes
__m256 v_shuffled = _mm256_permute_ps(
v, 0b00001011); // Swap halves within each 128-bit lane
__m256 v_max = _mm256_max_ps(v, v_shuffled);
v_shuffled = _mm256_permute_ps(
v_max, 0b00000001); // Shuffle elements within each 128-bit lane
v_max = _mm256_max_ps(v_max, v_shuffled);
// Permute to compare elements between 128-bit lanes
v_shuffled =
_mm256_permute2f128_ps(v_max, v_max, 0b00000001); // Swap 128-bit lanes
v_max = _mm256_max_ps(v_max, v_shuffled);
// At this point, the maximum value is present in all elements of v_max.
// Extract the first element for the scalar result.
return _mm256_cvtss_f32(v_max); // Extract the lowest 32-bit float
}
float reduce_sum() const { float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low); FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high); FP32Vec8 high = FP32Vec8(reg_high);
...@@ -642,7 +684,7 @@ inline FP16Vec16::FP16Vec16(const FP32Vec16& v) ...@@ -642,7 +684,7 @@ inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
: reg(_mm256_insertf128_si256( : reg(_mm256_insertf128_si256(
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {} FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
#endif #endif
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "common/memory.hpp" #include "common/memory.hpp"
#include "dnnl_helper.h" #include "dnnl_helper.h"
#include "scratchpad_manager.h"
static dnnl::engine& default_engine() { static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0); static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
...@@ -22,23 +23,6 @@ void release_dnnl_matmul_handler(int64_t handler) { ...@@ -22,23 +23,6 @@ void release_dnnl_matmul_handler(int64_t handler) {
delete ptr; delete ptr;
} }
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
template <typename KT, typename VT> template <typename KT, typename VT>
class DNNLPrimitiveCache { class DNNLPrimitiveCache {
public: public:
......
...@@ -59,30 +59,6 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() { ...@@ -59,30 +59,6 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type; return DNNLType<std::decay_t<T>>::type;
} }
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
class DNNLMatMulPrimitiveHandler { class DNNLMatMulPrimitiveHandler {
public: public:
virtual ~DNNLMatMulPrimitiveHandler() = default; virtual ~DNNLMatMulPrimitiveHandler() = default;
......
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif
...@@ -192,7 +192,7 @@ class SHMManager { ...@@ -192,7 +192,7 @@ class SHMManager {
const int group_size) const int group_size)
: _rank(rank), : _rank(rank),
_group_size(group_size), _group_size(group_size),
_thread_num(torch::get_num_threads()), _thread_num(omp_get_max_threads()),
_shm_names({""}), _shm_names({""}),
_shared_mem_ptrs({nullptr}), _shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) { _shm_ctx(nullptr) {
......
...@@ -74,25 +74,35 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, ...@@ -74,25 +74,35 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni); at::ScalarType out_dtype, bool is_vnni);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split);
void cpu_attn_reshape_and_cache(const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
const torch::Tensor& slot_mapping,
const std::string& isa);
void cpu_attention_with_kv_cache(
const torch::Tensor& query, const torch::Tensor& key_cache,
const torch::Tensor& value_cache, torch::Tensor& output,
const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes,
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, const double softcap,
const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
ops.def( ops.def(
"dynamic_4bit_int_moe(" "dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights," "Tensor x, Tensor topk_ids, Tensor topk_weights,"
...@@ -102,20 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -102,20 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu); ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
...@@ -259,37 +255,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -259,37 +255,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("int8_scaled_mm_with_quant", torch::kCPU, ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant); &int8_scaled_mm_with_quant);
#endif #endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // CPU attention kernels
// Cache ops ops.def(
// Swap in (out) the cache blocks from src to dst. "get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
cache_ops.def( "int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); "query_start_loc, bool casual, int window_size, str isa_hint, bool "
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); "enable_kv_split) -> Tensor",
&get_scheduler_metadata);
// Copy the cache blocks from src to dst. ops.def(
cache_ops.def( "cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " "key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
"Tensor block_mapping) -> ()"); "isa) -> ()",
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks); &cpu_attn_reshape_and_cache);
ops.def(
// Reshape the key and value tensors and cache them. "cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
cache_ops.def( "value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
"reshape_and_cache(Tensor key, Tensor value," "seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
" Tensor! key_cache, Tensor! value_cache," "sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
" Tensor slot_mapping," "float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
" str kv_cache_dtype," &cpu_attention_with_kv_cache);
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# VLLM_CPU_DISABLE_AVX512=false (default)|true # VLLM_CPU_DISABLE_AVX512=false (default)|true
# VLLM_CPU_AVX512BF16=false (default)|true # VLLM_CPU_AVX512BF16=false (default)|true
# VLLM_CPU_AVX512VNNI=false (default)|true # VLLM_CPU_AVX512VNNI=false (default)|true
# VLLM_CPU_AMXBF16=false (default)|true
# #
######################### COMMON BASE IMAGE ######################### ######################### COMMON BASE IMAGE #########################
...@@ -92,6 +93,9 @@ ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16} ...@@ -92,6 +93,9 @@ ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ... # Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
ARG VLLM_CPU_AVX512VNNI=0 ARG VLLM_CPU_AVX512VNNI=0
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI} ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
# Support for building with AMXBF16 ISA: docker build --build-arg VLLM_CPU_AMXBF16="true" ...
ARG VLLM_CPU_AMXBF16=0
ENV VLLM_CPU_AMXBF16=${VLLM_CPU_AMXBF16}
WORKDIR /workspace/vllm WORKDIR /workspace/vllm
......
...@@ -171,6 +171,8 @@ This value is 4GB by default. Larger space can support more concurrent requests, ...@@ -171,6 +171,8 @@ This value is 4GB by default. Larger space can support more concurrent requests,
First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`.
Use multiples of 32 as `--block-size`, which is 128 by default.
Inference batch size is an important parameter for the performance. A larger batch usually provides higher throughput, a smaller batch provides lower latency. Tuning the max batch size starting from the default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: Inference batch size is an important parameter for the performance. A larger batch usually provides higher throughput, a smaller batch provides lower latency. Tuning the max batch size starting from the default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment