Unverified Commit 7f829be7 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU attention backend (#27954)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent e1710393
......@@ -132,7 +132,7 @@ steps:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env:
......
......@@ -49,6 +49,7 @@ function cpu_tests() {
# Run kernel tests
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/test_onednn.py"
# Run basic model test
......@@ -116,4 +117,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
......@@ -15,6 +15,7 @@ endif()
#
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc")
......@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif()
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
set(ENABLE_AMXBF16 ON)
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
......@@ -275,7 +292,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN)
set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE})
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
target_include_directories(
dnnl_ext
......@@ -305,14 +325,14 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/scratchpad_manager.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
......
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void;
using q_vec_type = void;
using k_load_vec_type = void;
using k_vec_type = void;
using qk_acc_vec_type = void;
using v_load_vec_type = void;
};
template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
// Power and s390x architecture-specific vector types
using q_load_vec_type = vec_op::FP32Vec8;
using k_load_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
#else
// Fallback for other architectures, including x86
using q_load_vec_type = vec_op::FP16Vec8;
using k_load_vec_type = vec_op::FP16Vec16;
using v_load_vec_type = vec_op::FP16Vec16;
#endif
using q_vec_type = vec_op::FP32Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
};
#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32;
using k_vec_type = vec_op::BF16Vec32;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
#ifdef __aarch64__
#ifndef ARM_BF16_SUPPORT
// pass
#else
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
#else
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
#endif
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
const int capacity) {
T max = data[0];
for (int i = 1; i < size; ++i) {
max = max >= data[i] ? max : data[i];
}
T sum = 0;
for (int i = 0; i < size; ++i) {
data[i] = std::exp(data[i] - max);
sum += data[i];
}
int i = 0;
for (; i < size; ++i) {
data[i] /= sum;
}
for (; i < capacity; ++i) {
data[i] = 0;
}
return {max, sum};
}
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
const int capacity,
const float alibi_slope,
const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk;
max = max >= qk ? max : qk;
}
T sum = 0;
for (int i = 0; i < size; ++i) {
data[i] = std::exp(data[i] - max);
sum += data[i];
}
int i = 0;
for (; i < size; ++i) {
data[i] /= sum;
}
for (; i < capacity; ++i) {
data[i] = 0;
}
return {max, sum};
}
template <typename T>
FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data,
const int size) {
T max = max_data[0];
for (int i = 1; i < size; ++i) {
max = max >= max_data[i] ? max : max_data[i];
}
T rescaled_sum = 0;
for (int i = 0; i < size; ++i) {
T rescale_factor = std::exp(max_data[i] - max);
rescaled_sum += rescale_factor * sum_data[i];
sum_data[i] *= rescale_factor;
}
for (int i = 0; i < size; ++i) {
sum_data[i] /= rescaled_sum + 1e-8;
}
}
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
struct reduceQKBlockKernel {
using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t* __restrict__ k_block,
float* __restrict__ logits, float scale,
const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
qk_acc_vec_type group_accums[MAX_GROUP_NUM];
if (token_num == BLOCK_SIZE) {
for (int q_offset = 0; q_offset < HEAD_SIZE;
q_offset += x, k_block += x * BLOCK_SIZE) {
q_load_vec_type q_load_group_vec(q + q_offset);
q_vec_type q_group_vec(q_load_group_vec);
vec_op::unroll_loop<int, MAX_GROUP_NUM>(
[k_block, &q_group_vec, &group_accums](int token_group_idx) {
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
TOKEN_PER_GROUP);
k_vec_type k_group_vec(k_load_group_vec);
vec_op::fma(group_accums[token_group_idx], q_group_vec,
k_group_vec);
vec_op::prefetch(k_block + x * BLOCK_SIZE +
token_group_idx * x * TOKEN_PER_GROUP);
});
}
} else {
for (int q_offset = 0; q_offset < HEAD_SIZE;
q_offset += x, k_block += x * BLOCK_SIZE) {
q_load_vec_type q_load_group_vec(q + q_offset);
q_vec_type q_group_vec(q_load_group_vec);
for (int token_group_start = 0; token_group_start < group_num;
token_group_start += UNROLL_GROUP_NUM) {
vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
[token_group_start, k_block, &q_group_vec,
&group_accums](int token_group_idx) {
token_group_idx += token_group_start;
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
TOKEN_PER_GROUP);
k_vec_type k_group_vec(k_load_group_vec);
vec_op::fma(group_accums[token_group_idx], q_group_vec,
k_group_vec);
vec_op::prefetch(k_block + x * BLOCK_SIZE +
token_group_idx * x * TOKEN_PER_GROUP);
});
}
}
}
for (int token_group_idx = 0; token_group_idx < group_num;
++token_group_idx) {
vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
[&group_accums, logits, scale, token_group_idx](int token_idx) {
float dot_v =
group_accums[token_group_idx]
.template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
TOKEN_PER_GROUP>(token_idx);
logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
dot_v * scale;
});
}
}
};
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t&& acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM);
vec_op::FP32Vec16 prob_vec(prob);
vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
vec_op::FP32Vec16 fp32_v_vec(v_vec);
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
});
}
}; // namespace
// Paged attention v1
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl {
static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs,
// max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float);
float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx];
const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
float* __restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float* __restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
}
// Compute softmax
if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len);
} else {
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
}
// Compute value
constexpr int head_elem_num_per_partition = 16;
constexpr int head_partition_num =
HEAD_SIZE / head_elem_num_per_partition;
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t* __restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float* __restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
head_elem_num_per_partition>(
prob_vec_ptr, v_block_cache_ptr, accums);
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
if (head_elem_idx % 2 == 0) {
vec_op::prefetch(next_v_block_cache_ptr +
BLOCK_SIZE * head_elem_idx);
}
});
}
}
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
float value = accums[head_elem_idx].reduce_sum();
vec_op::storeFP32(value, out_ptr + head_elem_idx);
});
}
}
}
std::free(logits);
}
};
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 32:
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 192:
LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
});
}
// Paged attention v2
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs,
// max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len) continue;
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1);
const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE;
const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float* __restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
}
std::pair<float, float> max_and_sum;
if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi(
logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len);
} else {
max_and_sum =
reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
}
auto&& [max_logit, exp_sum] = max_and_sum;
scalar_t* __restrict__ output_buffer = nullptr;
if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
max_logits[idx] = max_logit;
exp_sums[idx] = exp_sum;
output_buffer =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
} else {
output_buffer =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
}
// Compute value
constexpr int head_elem_num_per_partition = 16;
constexpr int head_partition_num =
HEAD_SIZE / head_elem_num_per_partition;
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t* __restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float* __restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE;
const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
head_elem_num_per_partition>(
prob_vec_ptr, v_block_cache_ptr, accums);
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
if (head_elem_idx % 2 == 0) {
vec_op::prefetch(next_v_block_cache_ptr +
BLOCK_SIZE * head_elem_idx);
}
});
}
}
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
float value = accums[head_elem_idx].reduce_sum();
vec_op::storeFP32(value, out_ptr + head_elem_idx);
});
}
}
}
}
// Rescale partition softmax and store the factors to exp_sums
#pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) continue;
reducePartitionSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions,
exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions,
partition_num);
}
}
// Reduce values
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some
// HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) continue;
const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
const scalar_t* __restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group;
scalar_t* __restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group;
vec_op::FP32Vec16 acc;
for (int i = 0; i < partition_num; ++i) {
vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
vec_op::FP32Vec16 fp32_value(value);
acc = acc + fp32_value * rescale_factor;
}
v_load_vec_type cast_acc(acc);
cast_acc.save(seq_head_output);
}
}
}
}
};
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 32:
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 192:
LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
});
}
\ No newline at end of file
#include <map>
#include <vector>
#include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
}
}
}
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx >= 0) {
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx =
token_idx * value_stride + head_idx * head_size;
const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
target_key_head_ptr[target_offset + i] =
src_key_head_ptr[src_key_idx + i];
}
}
for (int src_value_idx = 0; src_value_idx < head_size;
++src_value_idx) {
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
src_value_head_ptr[src_value_idx];
}
}
}
}
}
}; // namespace
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_tokens, //
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size //
) {
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
continue;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src,
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
int size, int offset) {
for (int i = 0; i < size; i++) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
dst[dst_idx] = src[src_idx];
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
}
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
const int element_num_per_block = key_caches[0][0].numel();
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
}
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
TORCH_CHECK(kv_cache_dtype != "fp8");
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
VLLM_DISPATCH_FLOATING_TYPES(
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
concat_and_cache_mla_cpu_impl<scalar_t>(
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
kv_lora_rank, pe_dim, block_size);
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
});
}
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}
#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_attention::ISA::AMX: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
return __VA_ARGS__(); \
}
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
[&] { \
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
default: { \
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
std::to_string(HEAD_DIM)); \
} \
} \
}()
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
case cpu_attention::ISA::VEC16: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
} \
} \
}()
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split) {
cpu_attention::ISA isa;
if (isa_hint == "amx") {
isa = cpu_attention::ISA::AMX;
} else if (isa_hint == "vec") {
isa = cpu_attention::ISA::VEC;
} else if (isa_hint == "vec16") {
isa = cpu_attention::ISA::VEC16;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
cpu_attention::AttentionScheduler::ScheduleInput input;
input.num_reqs = num_req;
input.num_heads_q = num_heads_q;
input.num_heads_kv = num_heads_kv;
input.head_dim = head_dim;
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
if (window_size != -1) {
input.left_sliding_window_size = window_size - 1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = window_size - 1;
}
} else {
input.left_sliding_window_size = -1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = -1;
}
}
input.casual = casual;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
TORCH_CHECK(casual, "Only supports casual mask for now.");
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
input.output_buffer_elem_size =
sizeof(attn_impl::partial_output_buffer_t);
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
});
});
});
cpu_attention::AttentionScheduler scheduler;
torch::Tensor metadata = scheduler.schedule(input);
return metadata;
}
void cpu_attn_reshape_and_cache(
const torch::Tensor& key, // [token_num, head_num, head_size]
const torch::Tensor& value, // [token_num, head_num, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor& slot_mapping, const std::string& isa) {
TORCH_CHECK_EQ(key.dim(), 3);
TORCH_CHECK_EQ(value.dim(), 3);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
TORCH_CHECK_EQ(key.stride(2), 1);
TORCH_CHECK_EQ(value.stride(2), 1);
const int64_t token_num = key.size(0);
const int64_t key_token_num_stride = key.stride(0);
const int64_t value_token_num_stride = value.stride(0);
const int64_t head_num = value.size(1);
const int64_t key_head_num_stride = key.stride(1);
const int64_t value_head_num_stride = value.stride(1);
const int64_t num_blocks = key_cache.size(0);
const int64_t num_blocks_stride = key_cache.stride(0);
const int64_t cache_head_num_stride = key_cache.stride(1);
const int64_t block_size = key_cache.size(2);
const int64_t block_size_stride = key_cache.stride(2);
const int64_t head_dim = key.size(-1);
cpu_attention::ISA isa_tag = [&]() {
if (isa == "amx") {
return cpu_attention::ISA::AMX;
} else if (isa == "vec") {
return cpu_attention::ISA::VEC;
} else if (isa == "vec16") {
return cpu_attention::ISA::VEC16;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}
}();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num,
key_token_num_stride, value_token_num_stride, head_num,
key_head_num_stride, value_head_num_stride, num_blocks,
num_blocks_stride, cache_head_num_stride, block_size,
block_size_stride);
});
});
});
}
void cpu_attention_with_kv_cache(
const torch::Tensor& query, // [num_tokens, num_heads, head_size]
const torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor& output, // [num_tokens, num_heads, head_size]
const torch::Tensor& query_start_loc, // [num_tokens + 1]
const torch::Tensor& seq_lens, // [num_tokens]
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes, // [num_heads]
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, // [num_tokens, max_block_num]
const double softcap, const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux // [num_heads]
) {
TORCH_CHECK_EQ(query.dim(), 3);
TORCH_CHECK_EQ(query.stride(2), 1);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
cpu_attention::AttentionInput input;
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
scheduler_metadata.data_ptr());
input.num_tokens = query.size(0);
input.num_heads = query.size(1);
input.num_kv_heads = key_cache.size(1);
input.block_size = key_cache.size(2);
input.query = query.data_ptr();
input.query_num_tokens_stride = query.stride(0);
input.query_num_heads_stride = query.stride(1);
input.cache_num_blocks_stride = key_cache.stride(0);
input.cache_num_kv_heads_stride = key_cache.stride(1);
input.blt_num_tokens_stride = block_table.stride(0);
input.key_cache = key_cache.data_ptr();
input.value_cache = value_cache.data_ptr();
input.output = output.data_ptr();
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
input.block_table = block_table.data_ptr<int32_t>();
input.alibi_slopes =
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
// For now sink must be bf16
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
input.scale = scale;
input.causal = causal;
input.sliding_window_left = sliding_window_left;
input.sliding_window_right = sliding_window_right;
if (input.causal) {
// to make boundary calculation easier
input.sliding_window_right = 0;
}
float softcap_fp32 = softcap;
input.softcap = softcap_fp32;
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
});
});
}
#ifndef CPU_ATTN_AMX_HPP
#define CPU_ATTN_AMX_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// AMX specific
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
typedef struct __tile_config {
uint8_t palette_id = 1;
uint8_t start_row = 0;
uint8_t reserved_0[14] = {0};
uint16_t colsb[16] = {0};
uint8_t rows[16] = {0};
} __tilecfg;
// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
// - 16
template <typename kv_cache_t>
class TileGemm224 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
void* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
}
};
template <>
class TileGemm224<c10::BFloat16> {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
c10::BFloat16* __restrict__ a_tile,
c10::BFloat16* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
const int32_t k_times =
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
const int64_t a_tile_stride = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return AMX_TILE_ROW_BYTES;
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return lda * sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// k_cache is prepacked
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// v_cache is prepacked
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
// k_cache, v_cache are prepacked
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
// logits_buffer, output_buffer are not prepacked
float* __restrict__ c_tile_4 = c_tile;
float* __restrict__ c_tile_5 =
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc;
float* __restrict__ c_tile_7 =
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
const int32_t c_tile_stride = ldc * sizeof(float);
if (accum_c) {
_tile_loadd(4, c_tile_4, c_tile_stride);
_tile_loadd(5, c_tile_5, c_tile_stride);
_tile_loadd(6, c_tile_6, c_tile_stride);
_tile_loadd(7, c_tile_7, c_tile_stride);
} else {
_tile_zero(4);
_tile_zero(5);
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
_tile_dpbf16ps(4, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
_tile_dpbf16ps(5, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_dpbf16ps(6, 1, 2);
_tile_dpbf16ps(7, 1, 3);
// update ptrs
if constexpr (phase == AttentionGemmPhase::QK) {
// Q buffer is prepacked
a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// P buffer is not prepacked
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
_tile_stored(4, c_tile_4, c_tile_stride);
_tile_stored(5, c_tile_5, c_tile_stride);
_tile_stored(6, c_tile_6, c_tile_stride);
_tile_stored(7, c_tile_7, c_tile_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
const int32_t m_0 = AMX_TILE_ROW_NUM;
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
config.rows[0] = m_0;
config.rows[1] = m_1;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = m_0;
config.rows[5] = m_0;
config.rows[6] = m_1;
config.rows[7] = m_1;
_tile_loadconfig(&config);
}
};
// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
// m, m
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
// num should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be
// m
template <typename kv_cache_t>
class TileGemm122 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
void* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
}
};
template <>
class TileGemm122<c10::BFloat16> {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
c10::BFloat16* __restrict__ a_tile,
c10::BFloat16* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
c10::BFloat16* __restrict__ a_tile_1 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
const int64_t a_tile_stride = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return AMX_TILE_ROW_BYTES;
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return lda * sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// k_cache is prepacked
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// v_cache is prepacked
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_4 =
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_5 =
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
int64_t b_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_6 = c_tile;
float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float);
int64_t c_stride = ldc * sizeof(float);
const int32_t k_times =
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
const int32_t k_group_times = k_times / 2;
const bool has_tail = (k_times % 2 == 1);
if (accum_c) {
_tile_loadd(6, c_tile_6, c_stride);
_tile_loadd(7, c_tile_7, c_stride);
} else {
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_group_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_stream_loadd(4, b_tile_4, b_stride);
_tile_dpbf16ps(6, 1, 4);
_tile_stream_loadd(5, b_tile_5, b_stride);
_tile_dpbf16ps(7, 1, 5);
// update ptrs
if constexpr (phase == AttentionGemmPhase::QK) {
// Q buffer is prepacked
a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// P buffer is not prepacked
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
}
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
if (has_tail) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
}
_tile_stored(6, c_tile_6, c_stride);
_tile_stored(7, c_tile_7, c_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
config.rows[0] = m;
config.rows[1] = m;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = AMX_TILE_ROW_NUM;
config.rows[5] = AMX_TILE_ROW_NUM;
config.rows[6] = m;
config.rows[7] = m;
_tile_loadconfig(&config);
}
};
} // namespace
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = scalar_t;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = scalar_t;
constexpr static int64_t BlockSizeAlignment =
AMX_TILE_ROW_BYTES /
sizeof(kv_cache_t); // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 32;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::AMX;
constexpr static bool scale_on_logits = true;
public:
AttentionImpl() : current_q_head_num_(0) {
// Use all columns in AMX tiles
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
}
~AttentionImpl() { _tile_release(); }
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
if (q_head_num > AMX_TILE_ROW_NUM) {
if (q_head_num != current_q_head_num_) {
current_q_head_num_ = q_head_num;
TileGemm224<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
}
attention<TileGemm224<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
} else {
if (q_head_num != current_q_head_num_) {
current_q_head_num_ = q_head_num;
TileGemm122<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
}
attention<TileGemm122<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment * head_dim;
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment * (AMX_TILE_ROW_BYTES / 4);
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return block_size * HeadDimAlignment;
}
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
scalar_t* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, const float scale) {
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
constexpr int64_t head_elem_num_pre_block =
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
int32_t idx = 0;
int8_t* __restrict__ q_buffer_iter = reinterpret_cast<int8_t*>(q_buffer);
for (int32_t q_num_idx = 0; q_num_idx < q_num;
++q_num_idx, src += q_num_stride) {
scalar_t* __restrict__ src_iter = src;
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv;
++q_head_idx, src_iter += q_head_stride) {
vec_op::unroll_loop<int32_t, head_size_block_num>(
[&](int32_t head_size_block_idx) {
// Use INT8Vec64 for 64 bytes block
vec_op::INT8Vec64 vec(src_iter + head_size_block_idx *
head_elem_num_pre_block);
vec.save(q_buffer_iter + head_size_block_idx * AMX_TILE_BYTES);
});
++idx;
q_buffer_iter += AMX_TILE_ROW_BYTES;
if ((idx & (AMX_TILE_ROW_NUM - 1)) == 0) {
// head is in another amx tile
q_buffer_iter -= AMX_TILE_ROW_NUM * AMX_TILE_ROW_BYTES;
q_buffer_iter += head_size_block_num * AMX_TILE_BYTES;
}
}
}
}
// reshape KV to AMX friendly layout
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
// For AMX 2D tiles, size of each line is 64 bytes
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
// For AMX B martix, N always is 16
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
// For now suppose block_size is divisible by amx_tile_column_num
TORCH_CHECK_EQ(block_size % amx_b_tile_k_size, 0);
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key
// Head elements should be packed as quand-words and stored in token
// groups with (quadword_stride/4) tokens
constexpr int64_t token_num_per_group = amx_tile_row_size / 4;
static_assert(head_dim % (4 / sizeof(scalar_t)) == 0);
constexpr int64_t quadword_num = head_dim / (4 / sizeof(scalar_t));
const int32_t* key_start_quadword_ptr =
reinterpret_cast<const int32_t*>(
key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride);
const int64_t group_idx = block_offset / token_num_per_group;
const int64_t group_offset = block_offset % token_num_per_group;
constexpr int64_t quadword_num_per_group =
token_num_per_group * quadword_num;
int32_t* key_cache_start_ptr =
reinterpret_cast<int32_t*>(key_cache +
block_idx * num_blocks_stride +
head_idx * cache_head_num_stride) +
group_idx * quadword_num_per_group + group_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; j < quadword_num;
i += token_num_per_group, ++j) {
key_cache_start_ptr[i] = key_start_quadword_ptr[j];
}
}
{
// Write Value
// Different from Key, block_size dimension is packed rather than
// head_size dimension block_size dimension is packed as quand-words;
constexpr int64_t token_num_per_sub_group = 4 / sizeof(scalar_t);
const int64_t token_num_per_group = block_size;
constexpr int64_t head_elems_per_group = amx_b_tile_n_size;
const int64_t group_size = token_num_per_group * head_elems_per_group;
// For now suppose head_dim is divisible by amx_b_tile_n_size
static_assert(head_dim % head_elems_per_group == 0);
constexpr int64_t group_num = head_dim / head_elems_per_group;
const int64_t sub_group_idx = block_offset / token_num_per_sub_group;
const int64_t sub_group_offset =
block_offset % token_num_per_sub_group;
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride +
sub_group_idx * token_num_per_sub_group * amx_b_tile_n_size +
sub_group_offset;
for (int64_t i = 0; i < group_num; ++i) {
#pragma GCC unroll head_elems_per_group
for (int64_t j = 0, k = 0; j < head_elems_per_group;
++j, k += token_num_per_sub_group) {
value_cache_start_ptr[k] = value_start_ptr[j];
}
value_start_ptr += head_elems_per_group;
value_cache_start_ptr += group_size;
}
}
}
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t current_q_head_num_;
};
} // namespace cpu_attention
#endif
#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP
#include <unistd.h>
#include <type_traits>
#include <cstddef>
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16 };
template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};
struct AttentionWorkItemGroup {
int32_t req_id;
int32_t q_token_id_start;
int32_t q_token_num;
int32_t kv_split_pos_start;
int32_t kv_split_pos_end;
int64_t total_kv_len;
int32_t split_id;
int32_t local_split_id;
AttentionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
const int32_t kv_split_pos_start,
const int32_t kv_split_pos_end)
: req_id(req_id),
q_token_id_start(q_token_id_start),
q_token_num(0),
kv_split_pos_start(kv_split_pos_start),
kv_split_pos_end(kv_split_pos_end),
total_kv_len(0),
split_id(-1),
local_split_id(0) {}
std::string to_string() const {
std::stringstream ss;
ss << '[' << "req_id: " << req_id << ",\n";
ss << "q_token_id_start: " << q_token_id_start << ",\n";
ss << "q_token_num: " << q_token_num << ",\n";
ss << "kv_split_pos_start: " << kv_split_pos_start << ",\n";
ss << "kv_split_pos_end: " << kv_split_pos_end << ",\n";
ss << "total_kv_len: " << total_kv_len << ",\n";
ss << "split_id: " << split_id << ",\n";
ss << "local_split_id: " << local_split_id << ",\n";
ss << ']';
return ss.str();
}
};
struct ReductionWorkItemGroup {
int32_t req_id;
int32_t q_token_id_start;
int32_t q_token_id_num;
int32_t split_start_id;
int32_t split_num;
ReductionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
const int32_t q_token_id_num,
const int32_t split_start_id)
: req_id(req_id),
q_token_id_start(q_token_id_start),
q_token_id_num(q_token_id_num),
split_start_id(split_start_id),
split_num(0) {}
std::string to_string() const {
std::stringstream ss;
ss << '[' << "req_id: " << req_id << ",\n";
ss << "q_token_id_start: " << q_token_id_start << ",\n";
ss << "q_token_id_num: " << q_token_id_num << ",\n";
ss << "split_start_id: " << split_start_id << ",\n";
ss << "split_num: " << split_num << ",\n";
ss << ']';
return ss.str();
}
};
struct AttentionMetadata {
std::atomic_int64_t counter;
char _padding1[56];
ISA isa;
int32_t workitem_group_num;
int32_t reduction_item_num;
int32_t reduction_split_num;
int32_t thread_num;
int32_t effective_thread_num; // non-zero item num in workitem_num_per_thread
int32_t split_kv_q_token_num_threshold;
int64_t attention_scratchpad_size_per_thread;
int64_t reduction_scratchpad_size_per_kv_head;
AttentionWorkItemGroup* workitem_groups_ptr;
ReductionWorkItemGroup* reduction_items_ptr;
int32_t cu_workitem_num_per_thread[1025] = {
0}; // prefix sum of workitem_num_per_thread
char _padding2[56];
AttentionMetadata(ISA isa, int32_t workitem_group_num,
int32_t reduction_item_num, int32_t reduction_split_num,
int32_t split_kv_q_token_num_threshold)
: isa(isa),
workitem_group_num(workitem_group_num),
reduction_item_num(reduction_item_num),
reduction_split_num(reduction_split_num),
thread_num(omp_get_max_threads()),
effective_thread_num(thread_num),
split_kv_q_token_num_threshold(split_kv_q_token_num_threshold),
attention_scratchpad_size_per_thread(0),
reduction_scratchpad_size_per_kv_head(0),
workitem_groups_ptr(
(AttentionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata))),
reduction_items_ptr(
(ReductionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata) +
workitem_group_num *
sizeof(AttentionWorkItemGroup))),
counter(0) {
TORCH_CHECK_LE(thread_num, 1024);
static_assert(sizeof(AttentionMetadata) % 64 == 0);
TORCH_CHECK(reinterpret_cast<size_t>(this) % 64 == 0);
}
void reset_counter() { counter.store(0); }
int64_t acquire_counter() { return counter++; }
void print() const {
std::stringstream ss;
ss << "ISA: ";
switch (isa) {
case ISA::AMX:
ss << "AMX, ";
break;
case ISA::VEC:
ss << "VEC, ";
break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
<< ", reduction_split_num: " << reduction_split_num
<< ", thread_num: " << thread_num
<< ", effective_thread_num: " << effective_thread_num
<< ", attention_scratchpad_size_per_thread: "
<< attention_scratchpad_size_per_thread
<< ", reduction_scratchpad_size_per_kv_head: "
<< reduction_scratchpad_size_per_kv_head << ", workitem groups:\n";
for (int32_t i = 0; i < workitem_group_num; ++i) {
ss << (workitem_groups_ptr + i)->to_string() << ",\n";
}
ss << "cu_workitem_num_per_thread: [";
for (int32_t i = 0; i < thread_num + 1; ++i) {
ss << cu_workitem_num_per_thread[i] << ", ";
}
ss << "]\n";
ss << "reduction items: \n";
for (int32_t i = 0; i < reduction_item_num; ++i) {
ss << (reduction_items_ptr + i)->to_string() << ",\n";
}
std::printf("%s", ss.str().c_str());
}
};
// Thread attention scratchpad contains:
// - Q: q_tile_size * head_dim * q_buffer_elem_size, gather Q heads, especially
// for GQA
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size, logits
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate wether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
class AttentionScratchPad {
public:
AttentionScratchPad(int64_t thread_id,
const AttentionMetadata& attention_metadata,
void* scratchpad_ptr)
: thread_scratchpad_ptr(
static_cast<int8_t*>(scratchpad_ptr) +
thread_id *
attention_metadata.attention_scratchpad_size_per_thread),
reduction_scratchpad_ptr(
static_cast<int8_t*>(scratchpad_ptr) +
attention_metadata.thread_num *
attention_metadata.attention_scratchpad_size_per_thread),
reduction_scratchpad_size_per_kv_head(
attention_metadata.reduction_scratchpad_size_per_kv_head) {}
// for attention
void update(const int64_t head_dim, const int64_t q_buffer_elem_size,
const int64_t logits_buffer_elem_size,
const int64_t output_buffer_elem_size,
const int64_t max_num_q_per_iter, const int64_t q_head_tile_size,
const int64_t kv_tile_size) {
int64_t buffer_offset = 0;
q_buffer_offset_ = buffer_offset;
buffer_offset +=
calcu_q_buffer_size(q_head_tile_size, head_dim, q_buffer_elem_size);
logits_buffer_offset_ = buffer_offset;
buffer_offset += calcu_logits_buffer_size(max_num_q_per_iter, kv_tile_size,
logits_buffer_elem_size);
output_buffer_offset_ = buffer_offset;
buffer_offset += calcu_partial_output_buffer_size(
q_head_tile_size, head_dim, output_buffer_elem_size);
max_buffer_offset_ = buffer_offset;
buffer_offset += calcu_partial_output_max_sum_buffer_size(q_head_tile_size);
sum_buffer_offset_ = buffer_offset;
}
// for reduction
void update(const int32_t kv_head_idx, const int32_t total_split_num,
const int64_t head_dim, const int64_t q_head_tile_size,
const int64_t output_buffer_elem_size) {
int64_t buffer_offset = kv_head_idx * reduction_scratchpad_size_per_kv_head;
reduce_flag_buffer_offset_ = buffer_offset;
buffer_offset += calcu_reduce_flag_buffer_size(total_split_num);
reduce_output_buffer_offset_ = buffer_offset;
buffer_offset += calcu_reduce_output_buffer_size(
total_split_num, q_head_tile_size, head_dim, output_buffer_elem_size);
reduce_max_buffer_offset_ = buffer_offset;
buffer_offset +=
calcu_reduce_max_sum_buffer_size(total_split_num, q_head_tile_size);
reduce_sum_buffer_offset_ = buffer_offset;
}
template <typename T>
T* get_q_buffer() {
return reinterpret_cast<T*>(thread_scratchpad_ptr + q_buffer_offset_);
}
float* get_logits_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr +
logits_buffer_offset_);
}
float* get_output_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr +
output_buffer_offset_);
}
float* get_max_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr + max_buffer_offset_);
}
float* get_sum_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr + sum_buffer_offset_);
}
volatile bool* get_reduce_flag_buffer() {
return reinterpret_cast<volatile bool*>(reduction_scratchpad_ptr +
reduce_flag_buffer_offset_);
}
float* get_reduce_output_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_output_buffer_offset_);
}
float* get_reduce_max_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_max_buffer_offset_);
}
float* get_reduce_sum_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_sum_buffer_offset_);
}
int64_t get_thread_scratchpad_size() const {
return 2 * sum_buffer_offset_ - max_buffer_offset_;
}
int64_t get_reduction_scratchpad_size() const {
return 2 * reduce_sum_buffer_offset_ - reduce_max_buffer_offset_;
}
private:
static int64_t round_to_64(const int64_t num) {
return ((num + 63) >> 6) << 6;
}
static int64_t calcu_q_buffer_size(const int64_t q_tile_size,
const int64_t head_dim,
const int64_t elem_size) {
return round_to_64(q_tile_size * head_dim * elem_size);
}
static int64_t calcu_logits_buffer_size(const int64_t max_num_q_per_iter,
const int64_t k_tile_size,
const int64_t elem_size) {
return round_to_64(elem_size * max_num_q_per_iter * k_tile_size);
}
static int64_t calcu_partial_output_buffer_size(const int64_t q_tile_size,
const int64_t head_dim,
const int64_t elem_size) {
return round_to_64(q_tile_size * head_dim * elem_size);
}
static int64_t calcu_partial_output_max_sum_buffer_size(
const int64_t q_tile_size) {
return round_to_64(q_tile_size * sizeof(float));
}
static int64_t calcu_reduce_flag_buffer_size(const int64_t total_split_num) {
return round_to_64(total_split_num * sizeof(bool));
}
static int64_t calcu_reduce_max_sum_buffer_size(
const int64_t total_split_num, const int32_t q_head_tile_size) {
return round_to_64(total_split_num * q_head_tile_size * sizeof(float));
}
static int64_t calcu_reduce_output_buffer_size(
const int64_t total_split_num, const int64_t q_head_tile_size,
const int64_t head_dim, const int64_t output_buffer_elem_size) {
return round_to_64(total_split_num * q_head_tile_size * head_dim *
output_buffer_elem_size);
}
private:
int8_t* thread_scratchpad_ptr;
int8_t* reduction_scratchpad_ptr;
int64_t reduction_scratchpad_size_per_kv_head;
// attention buffers
int64_t q_buffer_offset_;
int64_t logits_buffer_offset_;
int64_t output_buffer_offset_;
int64_t max_buffer_offset_;
int64_t sum_buffer_offset_;
// reduction buffers
int64_t reduce_flag_buffer_offset_;
int64_t reduce_output_buffer_offset_;
int64_t reduce_max_buffer_offset_;
int64_t reduce_sum_buffer_offset_;
};
class AttentionScheduler {
public:
struct ScheduleInput {
int32_t num_reqs;
int32_t elem_size;
int32_t q_buffer_elem_size;
int32_t logits_buffer_elem_size;
int32_t output_buffer_elem_size;
int32_t num_heads_q;
int32_t num_heads_kv;
int32_t head_dim;
int32_t* query_start_loc;
int32_t* seq_lens;
int32_t left_sliding_window_size;
int32_t right_sliding_window_size;
bool casual;
cpu_attention::ISA isa;
int32_t max_num_q_per_iter; // max Q head num can be hold in registers
int32_t kv_block_alignment; // context length alignment requirement
bool enable_kv_split;
};
static constexpr int32_t MaxQTileIterNum = 128;
AttentionScheduler() : available_cache_size_(get_available_l2_size()) {}
torch::Tensor schedule(const ScheduleInput& input) const {
const bool casual = input.casual;
const int32_t thread_num = omp_get_max_threads();
const int64_t cache_size = get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment;
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0);
if (!use_gqa) {
q_head_per_kv = 1; // fallback to MHA
}
const int32_t min_split_kv_len =
((max_num_q_per_iter * 4 + kv_len_alignment - 1) / kv_len_alignment) *
kv_len_alignment;
const int32_t max_num_q_token_per_iter = max_num_q_per_iter / q_head_per_kv;
const int64_t default_tile_size = calcu_default_tile_size(
cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, max_num_q_per_iter);
const int32_t default_tile_token_num = default_tile_size / q_head_per_kv;
const int32_t split_kv_q_token_num_threshold =
input.enable_kv_split ? 1 : 0;
const int32_t left_sliding_window_size = input.left_sliding_window_size;
const int32_t right_sliding_window_size = input.right_sliding_window_size;
TORCH_CHECK_LE(split_kv_q_token_num_threshold * q_head_per_kv, 16);
// get total kv len
int64_t total_kv_len = 0;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t seq_len = input.seq_lens[req_id];
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
const int32_t kv_start_pos = 0;
const int32_t kv_end_pos = seq_len;
for (int32_t token_id = 0; token_id < q_token_num;
token_id += max_num_q_token_per_iter) {
const int32_t q_tile_token_num =
std::min(max_num_q_token_per_iter, q_token_num - token_id);
const int32_t q_tile_pos_left = q_start_pos + token_id;
const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
left_sliding_window_size, right_sliding_window_size);
const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
kv_len_alignment);
int32_t curr_kv_len =
aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
total_kv_len += curr_kv_len;
}
}
const int64_t kv_len_per_thread =
(((total_kv_len / thread_num) + kv_len_alignment - 1) /
kv_len_alignment) *
kv_len_alignment * (use_gqa ? input.num_heads_kv : input.num_heads_q);
std::vector<AttentionWorkItemGroup> workitems;
std::vector<ReductionWorkItemGroup> reduce_workitems;
workitems.reserve(1024);
reduce_workitems.reserve(1024);
std::vector<int32_t> workitem_num_per_thread(thread_num, 0);
// split tasks
int32_t curr_thread_id = 0;
int64_t remaining_kv_len = kv_len_per_thread;
int32_t cum_split_num = 0;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t seq_len = input.seq_lens[req_id];
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
const int32_t kv_start_pos = 0;
const int32_t kv_end_pos = seq_len;
int32_t local_split_id = 0;
AttentionWorkItemGroup curr_workitem(req_id, 0, 0, seq_len);
for (int32_t token_id = 0; token_id < q_token_num;
token_id += max_num_q_token_per_iter) {
const int32_t q_tile_token_num =
std::min(max_num_q_token_per_iter, q_token_num - token_id);
const int32_t q_tile_pos_left = q_start_pos + token_id;
const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
left_sliding_window_size, right_sliding_window_size);
const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
kv_len_alignment);
int32_t curr_kv_len =
aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
int32_t kv_token_pos_start = aligned_kv_tile_pos_left;
while (curr_kv_len > 0) {
if (curr_kv_len <= (remaining_kv_len + min_split_kv_len) ||
curr_thread_id == (thread_num - 1)) {
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += curr_kv_len;
remaining_kv_len -= curr_kv_len;
curr_kv_len = 0;
if (remaining_kv_len < 0) {
// stop to accept more workitems
remaining_kv_len -= min_split_kv_len;
}
if (curr_workitem.kv_split_pos_start != 0) {
// got a partial kv spilt, need to create a single workitem
curr_workitem.split_id = cum_split_num;
curr_workitem.local_split_id = local_split_id;
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
++reduce_workitems.back().split_num;
++cum_split_num;
curr_workitem = AttentionWorkItemGroup(
req_id, token_id + max_num_q_token_per_iter, 0, seq_len);
}
break;
}
if (remaining_kv_len < min_split_kv_len &&
(curr_workitem.total_kv_len > 0 ||
workitem_num_per_thread[curr_thread_id] > 0)) {
// remaining_kv_len is too short, and have allocated workitems, just
// leave to next thread
if (curr_workitem.total_kv_len > 0) {
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
}
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
// retry this iteration
continue;
}
// only split tail splits with q_tile_token_num <=
// split_kv_q_token_num_threshold
if (token_id + max_num_q_token_per_iter < q_token_num ||
q_tile_token_num > split_kv_q_token_num_threshold) {
// if requires a new q tile iteration and already has workitems,
// leave this workitem to next thread
if (curr_workitem.q_token_num % default_tile_token_num == 0 &&
(curr_workitem.total_kv_len > 0 ||
workitem_num_per_thread[curr_thread_id] > 0)) {
if (curr_workitem.total_kv_len > 0) {
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
}
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += curr_kv_len;
remaining_kv_len -= curr_kv_len;
curr_kv_len = 0;
break;
}
// split kv
if (curr_workitem.total_kv_len > 0) {
// write back curr workitem
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
if (kv_token_pos_start == aligned_kv_tile_pos_left) {
// first split, init the workitem
reduce_workitems.emplace_back(ReductionWorkItemGroup(
req_id, token_id, q_tile_token_num, cum_split_num));
}
int32_t spilt_size =
std::min(std::max(remaining_kv_len, (int64_t)min_split_kv_len),
(int64_t)curr_kv_len);
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, kv_token_pos_start,
kv_token_pos_start + spilt_size);
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += spilt_size;
curr_workitem.split_id = cum_split_num;
curr_workitem.local_split_id = local_split_id;
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
++reduce_workitems.back().split_num;
++cum_split_num;
++local_split_id;
kv_token_pos_start += spilt_size;
curr_kv_len -= spilt_size;
curr_workitem = AttentionWorkItemGroup(req_id, token_id,
kv_token_pos_start, seq_len);
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
}
}
if (curr_workitem.total_kv_len > 0) {
// write back curr workitem
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
}
int64_t metadata_tensor_size =
sizeof(AttentionMetadata) +
workitems.size() * sizeof(AttentionWorkItemGroup) +
reduce_workitems.size() * sizeof(ReductionWorkItemGroup);
auto options =
torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
torch::Tensor metadata_tensor =
torch::empty({metadata_tensor_size}, options);
AttentionMetadata* metadata_ptr = new (metadata_tensor.data_ptr())
AttentionMetadata(input.isa, workitems.size(), reduce_workitems.size(),
cum_split_num, split_kv_q_token_num_threshold);
AttentionWorkItemGroup* workitem_groups_ptr =
metadata_ptr->workitem_groups_ptr;
ReductionWorkItemGroup* reduction_items_ptr =
metadata_ptr->reduction_items_ptr;
std::memcpy(workitem_groups_ptr, workitems.data(),
workitems.size() * sizeof(AttentionWorkItemGroup));
std::memcpy(reduction_items_ptr, reduce_workitems.data(),
reduce_workitems.size() * sizeof(ReductionWorkItemGroup));
int32_t effective_thread_num = 0;
for (; effective_thread_num < thread_num; ++effective_thread_num) {
if (workitem_num_per_thread[effective_thread_num] == 0) {
break;
}
}
std::memcpy(metadata_ptr->cu_workitem_num_per_thread + 1,
workitem_num_per_thread.data(),
workitem_num_per_thread.size() * sizeof(int32_t));
for (int32_t i = 1; i <= thread_num; ++i) {
metadata_ptr->cu_workitem_num_per_thread[i] +=
metadata_ptr->cu_workitem_num_per_thread[i - 1];
}
metadata_ptr->effective_thread_num = effective_thread_num;
{
// when q_tile_size = max_num_q_per_iter, requires max
// attention_scratchpad_size
AttentionScratchPad sc(0, *metadata_ptr, 0x0);
int64_t n = AttentionScheduler::calcu_tile_size_with_constant_q(
cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, kv_len_alignment, max_num_q_per_iter, true);
sc.update(input.head_dim, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, max_num_q_per_iter, n);
metadata_ptr->attention_scratchpad_size_per_thread =
((sc.get_thread_scratchpad_size() + 63) / 64) * 64;
sc.update(0, metadata_ptr->reduction_split_num, input.head_dim,
q_head_per_kv * split_kv_q_token_num_threshold,
input.output_buffer_elem_size);
metadata_ptr->reduction_scratchpad_size_per_kv_head =
((sc.get_reduction_scratchpad_size() + 63) / 64) * 64;
}
int64_t scratchpad_size =
metadata_ptr->attention_scratchpad_size_per_thread *
metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q);
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
scratchpad_size);
// metadata_ptr->print();
// test out of boundary access
// {
// float* cache_ptr =
// DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// }
// }
return metadata_tensor;
}
FORCE_INLINE static std::pair<int32_t, int32_t> calcu_kv_tile_pos(
int32_t kv_left_pos, int32_t kv_right_pos, int32_t q_left_pos,
int32_t q_right_pos, int32_t sliding_window_left,
int32_t sliding_window_right) {
if (sliding_window_left != -1) {
kv_left_pos = std::max(kv_left_pos, q_left_pos - sliding_window_left);
}
if (sliding_window_right != -1) {
kv_right_pos = std::min(kv_right_pos, q_right_pos + sliding_window_right);
}
return {kv_left_pos, kv_right_pos};
}
FORCE_INLINE static std::pair<int32_t, int32_t> align_kv_tile_pos(
int32_t kv_left_pos, int32_t kv_right_pos, int32_t align_factor) {
kv_left_pos = (kv_left_pos / align_factor) * align_factor;
kv_right_pos =
((kv_right_pos + align_factor - 1) / align_factor) * align_factor;
return {kv_left_pos, kv_right_pos};
}
static int64_t calcu_default_tile_size(int64_t cache_size, int64_t head_dim,
int64_t elem_size,
int64_t q_buffer_elem_size,
int64_t logits_buffer_elem_size,
int64_t output_buffer_elem_size,
int64_t max_num_q_per_iter,
int64_t round_size) {
// For CPU, different from CUDA, Q@K^T results should also be hold in cache,
// using float32. Intermediate outputs should be float32 to be compatible
// with AMX Then the cache includes:
// - Q: q_tile_size * head_dim * q_buffer_elem_size
// - K, V: 2 * k_tile_size * head_dim * elem_size
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size
// By default, let tile_size = q_tile_size = k_tile_size. To record
// is_first_iter states in a static array, require the default tile <= 128 *
// max_num_q_per_iter
int64_t tile_size =
cache_size / (head_dim * (q_buffer_elem_size + 2 * elem_size +
output_buffer_elem_size) +
max_num_q_per_iter * logits_buffer_elem_size);
tile_size = std::min(tile_size, MaxQTileIterNum * max_num_q_per_iter);
int64_t rounded_tile_size = (tile_size / round_size) * round_size;
return std::max(rounded_tile_size, round_size);
}
static int64_t calcu_tile_size_with_constant_q(
int64_t cache_size, int64_t head_dim, int64_t elem_size,
int64_t q_buffer_elem_size, int64_t logits_buffer_elem_size,
int64_t output_buffer_elem_size, int64_t max_num_q_per_iter,
int64_t round_size, int64_t q_tile_size, bool one_round) {
// calculate tile_size with known q_tile_size
// If one_round is True, the outer Q tile loop time is 1, then the K,V will
// not be included in the cache
int64_t tile_size;
if (one_round) {
tile_size =
(cache_size - q_tile_size * head_dim *
(q_buffer_elem_size + output_buffer_elem_size)) /
(logits_buffer_elem_size * max_num_q_per_iter);
} else {
tile_size =
(cache_size - q_tile_size * head_dim *
(q_buffer_elem_size + output_buffer_elem_size)) /
(logits_buffer_elem_size * max_num_q_per_iter +
2 * head_dim * elem_size);
}
int64_t rounded_tile_size = (tile_size / round_size) * round_size;
return std::max(rounded_tile_size, round_size);
}
static int64_t get_available_l2_size() {
static int64_t size = []() {
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
return l2_cache_size >> 1; // use 50% of L2 cache
}();
return size;
}
private:
int64_t available_cache_size_;
};
struct AttentionInput {
AttentionMetadata* metadata;
int32_t num_tokens;
int32_t num_heads;
int32_t num_kv_heads;
int32_t block_size;
void* query;
int64_t query_num_tokens_stride;
int64_t query_num_heads_stride;
int64_t cache_num_blocks_stride;
int64_t cache_num_kv_heads_stride;
int64_t blt_num_tokens_stride;
void* key_cache;
void* value_cache;
void* output;
int32_t* query_start_loc;
int32_t* seq_lens;
int32_t* block_table;
float* alibi_slopes;
c10::BFloat16* s_aux;
float scale;
bool causal;
int32_t sliding_window_left;
int32_t sliding_window_right;
float softcap;
};
#define DEFINE_CPU_ATTENTION_PARAMS \
q_buffer_t *__restrict__ q_heads_buffer, \
kv_cache_t *__restrict__ k_head_cache_ptr, \
kv_cache_t *__restrict__ v_head_cache_ptr, \
logits_buffer_t *__restrict__ logits_buffer, \
float *__restrict__ partial_q_buffer, float *__restrict__ max_buffer, \
float *__restrict__ sum_buffer, int32_t *__restrict__ block_table, \
const int32_t kv_tile_start_pos, const int32_t kv_tile_end_pos, \
const int32_t kv_tile_token_num, \
const int64_t kv_cache_num_blocks_stride, const int32_t q_head_num, \
const int32_t q_token_num, const int32_t q_tile_start_pos, \
const int32_t q_heads_per_kv, const int32_t block_size, \
const int32_t left_window_size, const int32_t right_window_size, \
float scale, const float softcap_scale, \
const float *__restrict__ alibi_slopes, const bool is_first_iter, \
const bool use_sink, const bool debug_info
#define CPU_ATTENTION_PARAMS \
q_heads_buffer, k_head_cache_ptr, v_head_cache_ptr, logits_buffer, \
partial_q_buffer, max_buffer, sum_buffer, block_table, \
kv_tile_start_pos, kv_tile_end_pos, kv_tile_token_num, \
kv_cache_num_blocks_stride, q_head_num, q_token_num, q_tile_start_pos, \
q_heads_per_kv, block_size, left_window_size, right_window_size, scale, \
softcap_scale, alibi_slopes, is_first_iter, use_sink, debug_info
enum class AttentionGemmPhase { QK, PV };
template <typename T>
struct VecTypeTrait {
using vec_t = void;
};
template <>
struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
template <typename T>
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
template <typename attention_impl_t>
class AttentionMainLoop {
public:
using query_t = typename attention_impl_t::query_t;
using q_buffer_t = typename attention_impl_t::q_buffer_t;
using kv_cache_t = typename attention_impl_t::kv_cache_t;
using logits_buffer_t = typename attention_impl_t::logits_buffer_t;
using partial_output_buffer_t =
typename attention_impl_t::partial_output_buffer_t;
using prob_buffer_t = typename attention_impl_t::prob_buffer_t;
static constexpr int64_t max_q_head_num_per_iter =
attention_impl_t::MaxQHeadNumPerIteration;
static constexpr int64_t blocksize_alignment =
attention_impl_t::BlockSizeAlignment;
static constexpr int64_t headdim_alignment =
attention_impl_t::HeadDimAlignment;
static constexpr int64_t head_dim = attention_impl_t::HeadDim;
static constexpr ISA ISAType = attention_impl_t::ISAType;
static constexpr bool scale_on_logits =
attention_impl_t::scale_on_logits; // apply scale on logits, otherwise
// apply scale on q_buffer
template <typename tile_gemm_t>
class Attention {
public:
// Args:
// - q_heads_buffer: [MaxQHeadNumPerIteration, head_dim]
// - k_head_cache_ptr: [num_blocks, block_size * head_dim]
// - v_head_cache_ptr: [num_blocks, block_size * head_dim]
// - logits_buffer: [MaxQHeadNumPerIteration, kv_tile_token_num], store Q@K
// - logits partial_q_buffer: [MaxQHeadNumPerIteration, head_dim], store
// partial output
// - max_buffer: [MaxQHeadNumPerIteration, 1], store max logits
// - sum_buffer: [MaxQHeadNumPerIteration, 1], store sum of exp
// - block_table
// - kv_tile_start_pos: start position of KV cache, aligned to
// BlockSizeAlignment
// - kv_tile_end_pos: end position of KV cache, aligned to
// BlockSizeAlignment
// - kv_tile_token_num: KV token num, aligned to BlockSizeAlignment
// - kv_cache_num_blocks_stride
// - q_head_num: head num of q_tile
// - q_token_num: token num of q_tile, should be q_head_num /
// q_heads_per_kv
// - q_tile_start_pos: start pos of the first token in q_heads_buffer
// - q_heads_per_kv
// - block_size
// - left_window_size
// - right_window_size
// - scale
// - softcap_scale
// - alibi_slopes
// - is_first_iter
// - use_sink
// - debug_info
void operator()(DEFINE_CPU_ATTENTION_PARAMS) {
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
const int64_t k_cache_token_group_stride =
attention_impl_t::k_cache_token_group_stride(block_size);
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
const int64_t v_cache_token_group_stride =
attention_impl_t::v_cache_token_group_stride(block_size);
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
const int64_t v_cache_head_group_stride =
attention_impl_t::v_cache_head_group_stride(block_size);
const int32_t token_group_num = kv_tile_token_num / blocksize_alignment;
const int32_t token_group_num_per_block =
block_size / blocksize_alignment;
const int32_t start_block_idx = kv_tile_start_pos / block_size;
const int32_t start_block_offset = kv_tile_start_pos % block_size;
const int32_t start_block_group_offset =
start_block_offset / blocksize_alignment;
const int32_t end_block_idx =
(kv_tile_start_pos + kv_tile_token_num - 1) / block_size + 1;
// compute Q@K logits
{
int32_t curr_group_offset =
start_block_group_offset * k_cache_token_group_stride;
int32_t curr_group_num_in_block =
token_group_num_per_block - start_block_group_offset;
int32_t remaining_group_num = token_group_num;
logits_buffer_t* curr_logits_buffer = logits_buffer;
for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
++block_idx) {
int32_t physical_block_idx = block_table[block_idx];
kv_cache_t* k_cache_block_ptr =
k_head_cache_ptr +
physical_block_idx * kv_cache_num_blocks_stride +
curr_group_offset;
curr_group_num_in_block =
std::min(remaining_group_num, curr_group_num_in_block);
for (int32_t block_group_idx = 0;
block_group_idx < curr_group_num_in_block; ++block_group_idx) {
// logits_tile = q_tile @ k_tile, [MaxQHeadNumPerIteration,
// BlockSizeAlignment] = [MaxQHeadNumPerIteration, head_dim] @
// [head_dim, BlockSizeAlignment]
// By default, logits_buffer, q_buffer and k_cache are row-major,
// but may be packed by ISA implementation.
tile_gemm_t::template gemm<AttentionGemmPhase::QK, head_dim>(
q_head_num, q_heads_buffer, k_cache_block_ptr,
curr_logits_buffer, head_dim, block_size, kv_tile_token_num,
block_size, head_dim, false);
if constexpr (scale_on_logits) {
float* __restrict__ scale_curr_logits_buffer = curr_logits_buffer;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t i = 0; i < q_head_num; ++i) {
static_assert(blocksize_alignment % 16 == 0);
constexpr int32_t vec_num = blocksize_alignment / 16;
vec_op::unroll_loop<int32_t, vec_num>([&](int32_t vec_idx) {
vec_op::FP32Vec16 vec(scale_curr_logits_buffer +
vec_idx * 16);
vec = vec * scale_vec;
vec.save(scale_curr_logits_buffer + vec_idx * 16);
});
scale_curr_logits_buffer += kv_tile_token_num;
}
}
// Move buffer ptrs
k_cache_block_ptr += k_cache_token_group_stride;
curr_logits_buffer += blocksize_alignment;
}
// Update
remaining_group_num -= curr_group_num_in_block;
curr_group_offset = 0;
curr_group_num_in_block = token_group_num_per_block;
}
}
// process logits
{
// if (debug_info){
// print_logits("raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
// }
if (softcap_scale != 0.0f) {
apply_softcap(logits_buffer, kv_tile_token_num, q_head_num,
kv_tile_token_num, softcap_scale);
// print_logits("softcap raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
}
if (alibi_slopes != nullptr) {
apply_alibi_slopes(logits_buffer, alibi_slopes, kv_tile_token_num,
q_tile_start_pos, kv_tile_start_pos, q_token_num,
kv_tile_token_num, q_heads_per_kv);
// print_logits("alibi raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
}
apply_mask(logits_buffer, kv_tile_token_num, q_tile_start_pos,
kv_tile_start_pos, kv_tile_end_pos, q_token_num,
q_heads_per_kv, left_window_size, right_window_size);
// if (debug_info){
// print_logits("masked logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
// print_logits("old_max", max_buffer, 1, q_head_num, q_head_num);
// print_logits("old_sum", sum_buffer, 1, q_head_num, q_head_num);
// }
apply_softmax(logits_buffer, partial_q_buffer, max_buffer, sum_buffer,
kv_tile_token_num, q_head_num, kv_tile_token_num,
is_first_iter, use_sink);
// if (debug_info){
// print_logits("softmax logits",
// reinterpret_cast<prob_buffer_t*>(logits_buffer), q_head_num,
// kv_tile_token_num, kv_tile_token_num * sizeof(logits_buffer_t) /
// sizeof(prob_buffer_t));
// print_logits("new_max", max_buffer, 1, q_head_num, q_head_num);
// print_logits("new_sum", sum_buffer, 1, q_head_num, q_head_num);
// }
}
// compute P@V
{
int32_t curr_group_offset =
start_block_group_offset * v_cache_token_group_stride;
int32_t curr_group_num_in_block =
token_group_num_per_block - start_block_group_offset;
int32_t remaining_group_num = token_group_num;
int32_t head_dim_group_num = head_dim / headdim_alignment;
prob_buffer_t* curr_prob_buffer =
reinterpret_cast<prob_buffer_t*>(logits_buffer);
int64_t prob_buffer_stride =
kv_tile_token_num *
(sizeof(logits_buffer_t) / sizeof(prob_buffer_t));
partial_output_buffer_t* curr_partial_q_buffer = partial_q_buffer;
bool accum_c = !is_first_iter;
for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
++block_idx) {
int32_t physical_block_idx = block_table[block_idx];
kv_cache_t* v_cache_block_ptr =
v_head_cache_ptr +
physical_block_idx * kv_cache_num_blocks_stride +
curr_group_offset;
curr_group_num_in_block =
std::min(remaining_group_num, curr_group_num_in_block);
int32_t curr_token_num =
curr_group_num_in_block * blocksize_alignment;
for (int32_t head_dim_group_idx = 0;
head_dim_group_idx < head_dim_group_num; ++head_dim_group_idx) {
// output_tile = p_tile @ v_tile, [MaxQHeadNumPerIteration,
// HeadDimAlignment] = [MaxQHeadNumPerIteration, block_size] @
// [block_size, HeadDimAlignment]
tile_gemm_t::template gemm<AttentionGemmPhase::PV, -1>(
q_head_num, curr_prob_buffer, v_cache_block_ptr,
curr_partial_q_buffer, prob_buffer_stride, head_dim, head_dim,
block_size, curr_token_num, accum_c);
// Update
curr_partial_q_buffer += headdim_alignment;
v_cache_block_ptr += v_cache_head_group_stride;
}
// Update
remaining_group_num -= curr_group_num_in_block;
curr_group_offset = 0;
curr_group_num_in_block = token_group_num_per_block;
curr_prob_buffer += curr_token_num;
curr_partial_q_buffer = partial_q_buffer;
accum_c = true;
}
}
// if (debug_info) {
// print_logits("output", partial_q_buffer, q_head_num, head_dim,
// head_dim);
// }
}
void apply_mask(logits_buffer_t* __restrict__ logits_buffer,
const int64_t logits_buffer_stride,
const int32_t q_tile_start_pos,
const int32_t kv_tile_start_pos,
const int32_t kv_tile_end_pos, const int32_t q_token_num,
const int32_t q_heads_per_kv,
const int32_t sliding_window_left,
const int32_t sliding_window_right) {
// Apply mask
constexpr logits_buffer_t neg_inf =
-std::numeric_limits<logits_buffer_t>::infinity();
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
int32_t curr_token_pos = q_tile_start_pos;
for (int32_t token_idx = 0; token_idx < q_token_num; ++token_idx) {
int32_t left_kv_pos = [&]() {
int32_t pos = kv_tile_start_pos;
if (sliding_window_left != -1) {
pos = std::max(pos, curr_token_pos - sliding_window_left);
}
return pos;
}();
int32_t right_kv_pos = [&]() {
int32_t pos = kv_tile_end_pos;
if (sliding_window_right != -1) {
pos = std::min(pos,
std::max(kv_tile_start_pos,
curr_token_pos + sliding_window_right + 1));
}
return pos;
}();
int32_t left_invalid_token_num = left_kv_pos - kv_tile_start_pos;
int32_t right_invalid_token_num = kv_tile_end_pos - right_kv_pos;
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
logits_buffer_t* __restrict__ curr_logits_buffer_tail =
curr_logits_buffer + right_kv_pos - kv_tile_start_pos;
for (int32_t i = 0; i < left_invalid_token_num; ++i) {
curr_logits_buffer[i] = neg_inf;
}
for (int32_t i = 0; i < right_invalid_token_num; ++i) {
curr_logits_buffer_tail[i] = neg_inf;
}
curr_logits_buffer += logits_buffer_stride;
}
++curr_token_pos;
}
}
void apply_softmax(logits_buffer_t* __restrict__ logits_buffer,
float* __restrict__ partial_q_buffer,
float* __restrict__ max_buffer,
float* __restrict__ sum_buffer,
const int64_t logits_buffer_stride, int32_t q_head_num,
int32_t kv_tile_token_num, bool is_first_iter,
bool use_sink) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
using prob_buffer_vec_t = typename VecTypeTrait<prob_buffer_t>::vec_t;
static_assert(sizeof(prob_buffer_t) <= sizeof(logits_buffer_t));
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
float* __restrict__ curr_partial_q_buffer = partial_q_buffer;
const int32_t vec_num = kv_tile_token_num / 16;
const int32_t head_vec_num = head_dim / 16;
for (int32_t i = 0; i < q_head_num; ++i) {
float init_max_val = max_buffer[i];
float init_sum_val = sum_buffer[i];
// apply scale and compute max
vec_op::FP32Vec16 max_vec(init_max_val);
{
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
max_vec = vec.max(max_vec);
curr_logits_buffer_iter += 16;
}
}
float new_max_val = max_vec.reduce_max();
float rescale_factor = init_max_val - new_max_val;
// use same rescale threshold with FA4.
// https://github.com/Dao-AILab/flash-attention/blob/1b8e1e641c6a179be9a0538b7f40fd595050b735/flash_attn/cute/flash_fwd_sm100.py#L1271
bool need_rescale = rescale_factor < -8.0;
if (!need_rescale) {
new_max_val = init_max_val;
} else {
max_buffer[i] = new_max_val;
}
// sub max, compute exp and sum
max_vec = vec_op::FP32Vec16(new_max_val);
vec_op::FP32Vec16 sum_vec(0.0);
{
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
prob_buffer_t* __restrict__ curr_prob_buffer_iter =
reinterpret_cast<prob_buffer_t*>(curr_logits_buffer);
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec - max_vec;
// compute exp
#ifdef DEFINE_FAST_EXP
vec = fast_exp(vec);
prob_buffer_vec_t output_vec(vec);
output_vec.save(curr_prob_buffer_iter);
#else
vec.save(curr_logits_buffer_iter);
for (int32_t k = 0; k < 16; ++k) {
curr_logits_buffer_iter[k] = std::exp(curr_logits_buffer_iter[k]);
}
vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
sum_vec = sum_vec + vec;
curr_logits_buffer_iter += 16;
curr_prob_buffer_iter += 16;
}
}
float new_sum_val = sum_vec.reduce_sum();
// rescale sum and partial outputs
if (need_rescale) {
// compute rescale factor
#ifdef DEFINE_FAST_EXP
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
rescale_factor_vec = fast_exp(rescale_factor_vec);
rescale_factor = rescale_factor_vec.get_last_elem();
#else
rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
#endif
// rescale sum
new_sum_val += rescale_factor * init_sum_val;
// rescale output
if (!is_first_iter) {
float* __restrict__ curr_partial_q_buffer_iter =
curr_partial_q_buffer;
for (int32_t j = 0; j < head_vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_partial_q_buffer_iter);
vec = vec * rescale_factor_vec;
vec.save(curr_partial_q_buffer_iter);
curr_partial_q_buffer_iter += 16;
}
}
} else {
new_sum_val += init_sum_val;
}
sum_buffer[i] = new_sum_val;
curr_logits_buffer += logits_buffer_stride;
curr_partial_q_buffer += head_dim;
}
}
void apply_softcap(logits_buffer_t* __restrict__ logits_buffer,
const int64_t logits_buffer_stride, int32_t q_head_num,
int32_t kv_tile_token_num, float softcap_scale) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
float inv_softcap_scale = 1.0 / softcap_scale;
vec_op::FP32Vec16 softcap_scale_vec(softcap_scale);
vec_op::FP32Vec16 inv_softcap_scale_vec(inv_softcap_scale);
vec_op::FP32Vec16 ones_vec(1.0);
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
const int32_t vec_num = kv_tile_token_num / 16;
for (int32_t i = 0; i < q_head_num; ++i) {
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec * inv_softcap_scale_vec;
#ifdef DEFINE_FAST_EXP
vec = fast_exp(vec);
vec_op::FP32Vec16 inv_vec = ones_vec / vec;
vec = (vec - inv_vec) / (vec + inv_vec);
#else
vec.save(curr_logits_buffer_iter);
for (int k = 0; k < 16; ++k) {
curr_logits_buffer_iter[k] = std::tanh(curr_logits_buffer_iter[k]);
}
vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
vec = vec * softcap_scale_vec;
vec.save(curr_logits_buffer_iter);
curr_logits_buffer_iter += 16;
}
curr_logits_buffer += logits_buffer_stride;
}
}
void apply_alibi_slopes(logits_buffer_t* __restrict__ logits_buffer,
const float* __restrict__ alibi_slopes,
const int64_t logits_buffer_stride,
const int32_t q_tile_start_pos,
const int32_t kv_tile_start_pos,
const int32_t q_token_num,
const int32_t kv_tile_token_num,
const int32_t q_heads_per_kv) {
alignas(64) constexpr float initial_arange_vals[16] = {
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
const int32_t vec_num = kv_tile_token_num / 16;
vec_op::FP32Vec16 initial_arange_vals_vec(initial_arange_vals);
initial_arange_vals_vec =
initial_arange_vals_vec + vec_op::FP32Vec16((float)kv_tile_start_pos);
vec_op::FP32Vec16 pos_offset_vec(16.0);
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
for (int32_t i = 0; i < q_token_num; ++i) {
vec_op::FP32Vec16 curr_q_pos_vec((float)(i + q_tile_start_pos));
for (int32_t j = 0; j < q_heads_per_kv; ++j) {
vec_op::FP32Vec16 alibi_scale_vec(alibi_slopes[j]);
vec_op::FP32Vec16 curr_kv_pos_vec(initial_arange_vals_vec);
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t k = 0; k < vec_num; ++k) {
vec_op::FP32Vec16 alibi_bias_vec =
alibi_scale_vec * (curr_kv_pos_vec - curr_q_pos_vec);
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec + alibi_bias_vec;
vec.save(curr_logits_buffer_iter);
curr_kv_pos_vec = curr_kv_pos_vec + pos_offset_vec;
curr_logits_buffer_iter += 16;
}
curr_logits_buffer += logits_buffer_stride;
}
}
}
};
public:
void operator()(const AttentionInput* input) {
const int thread_num = omp_get_max_threads();
TORCH_CHECK_EQ(input->metadata->thread_num, thread_num);
std::atomic<int32_t> guard_counter(0);
std::atomic<int32_t>* guard_counter_ptr = &guard_counter;
#pragma omp parallel for schedule(static, 1)
for (int thread_id = 0; thread_id < thread_num; ++thread_id) {
AttentionMetadata& metadata = *input->metadata;
if (metadata.workitem_group_num == 0) {
continue;
}
attention_impl_t attn_impl;
// general information
const int32_t q_head_num = input->num_heads;
const int32_t kv_head_num = input->num_kv_heads;
const int32_t q_heads_per_kv = q_head_num / kv_head_num;
const bool use_gqa =
(max_q_head_num_per_iter % q_heads_per_kv == 0) ? true : false;
const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num;
const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1;
TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter);
const int32_t max_q_token_num_per_iter =
max_q_head_num_per_iter / actual_q_heads_per_kv;
const int64_t q_token_num_stride = input->query_num_tokens_stride;
const int64_t q_head_num_stride = input->query_num_heads_stride;
const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride;
const int64_t kv_cache_block_num_stride = input->cache_num_blocks_stride;
const int32_t sliding_window_left = input->sliding_window_left;
const int32_t sliding_window_right = input->sliding_window_right;
const int32_t block_size = input->block_size;
const float scale = input->scale;
const float softcap_scale = input->softcap;
const float* alibi_slopes = input->alibi_slopes;
const c10::BFloat16* s_aux = input->s_aux;
const bool casual = input->causal;
int32_t* const block_table = input->block_table;
const int64_t block_table_stride = input->blt_num_tokens_stride;
// init buffers
void* scratchpad_ptr =
DNNLScratchPadManager::get_dnnl_scratchpad_manager()
->get_data<void>();
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
const int32_t total_reduction_split_num = metadata.reduction_split_num;
if (metadata.reduction_split_num > 0) {
// reset split flag
for (int32_t head_idx = thread_id; head_idx < actual_kv_head_num;
head_idx += thread_num) {
buffer_manager.update(head_idx, total_reduction_split_num, head_dim,
0, sizeof(partial_output_buffer_t));
volatile bool* __restrict__ curr_flag_ptr =
buffer_manager.get_reduce_flag_buffer();
for (int32_t split_idx = 0; split_idx < total_reduction_split_num;
++split_idx) {
curr_flag_ptr[split_idx] = false;
}
}
}
const int64_t available_cache_size =
AttentionScheduler::get_available_l2_size();
const int32_t default_tile_size =
AttentionScheduler::calcu_default_tile_size(
available_cache_size, head_dim, sizeof(kv_cache_t),
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
max_q_head_num_per_iter);
const int32_t default_q_tile_token_num =
default_tile_size / actual_q_heads_per_kv;
AttentionWorkItemGroup* const workitem_groups =
metadata.workitem_groups_ptr;
const int32_t* cu_workitem_num_per_thread =
metadata.cu_workitem_num_per_thread;
ReductionWorkItemGroup* const reduction_items =
metadata.reduction_items_ptr;
const int32_t effective_thread_num = metadata.effective_thread_num;
const int32_t reduction_item_num = metadata.reduction_item_num;
const int32_t split_kv_q_token_num_threshold =
metadata.split_kv_q_token_num_threshold;
const int32_t workitem_groups_counter_num =
actual_kv_head_num * effective_thread_num;
const int32_t reduction_items_counter_num =
actual_kv_head_num * reduction_item_num;
const int32_t total_counter_num =
workitem_groups_counter_num + reduction_items_counter_num;
if (metadata.reduction_split_num > 0) {
++(*guard_counter_ptr);
while (guard_counter_ptr->load() != thread_num) {
#ifdef FAST_SPINNING
FAST_SPINNING
#else
std::this_thread::yield();
#endif
}
}
// main loop
for (;;) {
int64_t task_idx = metadata.acquire_counter();
if (task_idx >= total_counter_num) {
// no more tasks, leave loop
break;
}
if (task_idx < workitem_groups_counter_num) {
// attention task
// map task_idx to workitem_groups
const int32_t kv_head_idx = task_idx / effective_thread_num;
const int32_t thread_offset = task_idx % effective_thread_num;
AttentionWorkItemGroup* const curr_workitem_groups =
workitem_groups + cu_workitem_num_per_thread[thread_offset];
const int32_t curr_workitem_groups_num =
cu_workitem_num_per_thread[thread_offset + 1] -
cu_workitem_num_per_thread[thread_offset];
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
for (int32_t workitem_group_idx = 0;
workitem_group_idx < curr_workitem_groups_num;
++workitem_group_idx) {
AttentionWorkItemGroup* const current_workitem_group =
&curr_workitem_groups[workitem_group_idx];
const int32_t current_group_idx = current_workitem_group->req_id;
const int32_t kv_start_pos =
current_workitem_group->kv_split_pos_start;
const int32_t kv_end_pos = current_workitem_group->kv_split_pos_end;
const int32_t curr_spilt_id = current_workitem_group->split_id;
const int32_t q_token_id_start =
current_workitem_group->q_token_id_start;
const int32_t q_token_num = current_workitem_group->q_token_num;
// taskgroup general information
const int32_t q_end = input->query_start_loc[current_group_idx + 1];
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t seq_len = input->seq_lens[current_group_idx];
const int32_t q_start_pos =
(casual ? seq_len - (q_end - q_start) : 0);
const int32_t block_num = (seq_len + block_size - 1) / block_size;
// Only apply sink for the first KV split
bool use_sink = (s_aux != nullptr &&
current_workitem_group->local_split_id == 0);
for (int32_t q_token_offset = 0; q_token_offset < q_token_num;
q_token_offset += default_q_tile_token_num) {
bool first_iter_flag[AttentionScheduler::MaxQTileIterNum];
for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum;
++i) {
first_iter_flag[i] = true;
}
const int32_t q_token_start_idx =
q_start + q_token_offset + q_token_id_start;
const int32_t actual_q_token_num = std::min(
default_q_tile_token_num, q_token_num - q_token_offset);
const int32_t q_head_tile_size =
actual_q_token_num * actual_q_heads_per_kv;
const int32_t rounded_q_head_tile_size =
((q_head_tile_size + max_q_head_num_per_iter - 1) /
max_q_head_num_per_iter) *
max_q_head_num_per_iter;
const int32_t kv_tile_size =
AttentionScheduler::calcu_tile_size_with_constant_q(
available_cache_size, head_dim, sizeof(kv_cache_t),
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
blocksize_alignment, rounded_q_head_tile_size,
rounded_q_head_tile_size <= max_q_head_num_per_iter);
// update buffers
buffer_manager.update(
head_dim, sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
rounded_q_head_tile_size, kv_tile_size);
q_buffer_t* q_buffer = buffer_manager.get_q_buffer<q_buffer_t>();
float* logits_buffer = buffer_manager.get_logits_buffer();
float* partial_q_buffer = buffer_manager.get_output_buffer();
float* max_buffer = buffer_manager.get_max_buffer();
float* sum_buffer = buffer_manager.get_sum_buffer();
const int32_t q_tile_start_pos =
q_start_pos + q_token_offset + q_token_id_start;
const int32_t q_tile_end_pos =
q_tile_start_pos + actual_q_token_num;
const auto [kv_tile_start_pos, kv_tile_end_pos] =
AttentionScheduler::calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_start_pos,
q_tile_end_pos, sliding_window_left,
sliding_window_right);
const auto [rounded_kv_tile_start_pos, rounded_kv_tile_end_pos] =
AttentionScheduler::align_kv_tile_pos(
kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment);
int32_t curr_kv_head_idx =
use_gqa ? kv_head_idx
: (kv_head_idx /
q_heads_per_kv); // for GQA disabled case
// std::printf("thread_id: %d, req_id: %d, q_token_start: %d,
// q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx:
// %d, kv_pos_start: %d, kv_pos_end: %d\n",
// thread_id, current_group_idx,
// q_token_start_idx, q_token_start_idx +
// actual_q_token_num, q_head_start_idx,
// q_head_start_idx + actual_q_heads_per_kv,
// curr_kv_head_idx, kv_tile_start_pos,
// kv_tile_end_pos);
// move buffers
kv_cache_t* curr_k_cache =
reinterpret_cast<kv_cache_t*>(input->key_cache) +
curr_kv_head_idx * kv_cache_head_num_stride;
kv_cache_t* curr_v_cache =
reinterpret_cast<kv_cache_t*>(input->value_cache) +
curr_kv_head_idx * kv_cache_head_num_stride;
query_t* const q_tile_ptr =
reinterpret_cast<query_t*>(input->query) +
q_token_start_idx * q_token_num_stride +
q_head_start_idx * q_head_num_stride;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
int32_t* curr_block_table =
block_table + current_group_idx * block_table_stride;
const float* curr_alibi_slopes =
(alibi_slopes != nullptr ? alibi_slopes + q_head_start_idx
: nullptr);
const c10::BFloat16* curr_s_aux =
(s_aux != nullptr ? s_aux + q_head_start_idx : nullptr);
// copy the Q tile to q_buffer, the logical layout of q_buffer is
// [actual_q_token_num, actual_q_heads_per_kv, head_dim]
{
attn_impl.copy_q_heads_tile(
q_tile_ptr, q_buffer, actual_q_token_num,
actual_q_heads_per_kv, q_token_num_stride,
q_head_num_stride, scale);
}
if (use_sink) {
alignas(64) float s_aux_fp32[16];
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 1.0f;
curr_max_buffer[head_idx] = s_aux_fp32[head_idx];
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
}
} else {
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 0.0f;
curr_max_buffer[head_idx] =
std::numeric_limits<float>::lowest();
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
}
}
// compute loop
for (int32_t kv_tile_pos = rounded_kv_tile_start_pos;
kv_tile_pos < rounded_kv_tile_end_pos;
kv_tile_pos += kv_tile_size) {
const int32_t kv_tile_pos_left = kv_tile_pos;
const int32_t kv_tile_pos_right = std::min(
kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos);
for (int32_t q_head_tile_token_offset = 0;
q_head_tile_token_offset < actual_q_token_num;
q_head_tile_token_offset += max_q_token_num_per_iter) {
const int32_t q_tile_pos_left =
q_tile_start_pos + q_head_tile_token_offset;
const int32_t q_tile_token_num =
std::min(max_q_token_num_per_iter,
actual_q_token_num - q_head_tile_token_offset);
const int32_t q_tile_head_offset =
q_head_tile_token_offset * actual_q_heads_per_kv;
const int32_t q_tile_head_num =
q_tile_token_num * actual_q_heads_per_kv;
const int32_t q_tile_pos_right =
q_tile_pos_left + q_tile_token_num;
const auto [actual_kv_tile_pos_left,
actual_kv_tile_pos_right] =
AttentionScheduler::calcu_kv_tile_pos(
kv_tile_pos_left, kv_tile_pos_right, q_tile_pos_left,
q_tile_pos_right, sliding_window_left,
sliding_window_right);
const int32_t q_iter_idx =
q_head_tile_token_offset / max_q_token_num_per_iter;
if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) {
continue;
}
// align kv_pos to blocksize_alignment
const auto [aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right] =
AttentionScheduler::align_kv_tile_pos(
actual_kv_tile_pos_left, actual_kv_tile_pos_right,
blocksize_alignment);
const int32_t actual_kv_token_num =
aligned_actual_kv_tile_pos_right -
aligned_actual_kv_tile_pos_left;
// std::printf("\tq_iter_idx: %d, q_token_start: %d,
// q_token_end: %d, q_token_num: %d, q_head_num: %d,
// q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,
// kv_pos_end: %d\n",
// q_iter_idx, q_token_start_idx +
// q_head_tile_token_offset, q_token_start_idx +
// q_head_tile_token_offset + q_tile_token_num,
// q_tile_token_num, q_tile_head_num,
// q_tile_pos_left, q_tile_pos_right,
// aligned_actual_kv_tile_pos_left,
// aligned_actual_kv_tile_pos_right);
// Move buffers
q_buffer_t* curr_q_heads_buffer =
q_buffer + q_tile_head_offset * head_dim;
float* curr_partial_q_buffer =
partial_q_buffer + q_tile_head_offset * head_dim;
float* curr_max_buffer = max_buffer + q_tile_head_offset;
float* curr_sum_buffer = sum_buffer + q_tile_head_offset;
bool debug_info = false;
// bool debug_info = (
// q_head_start_idx == 4 &&
// (q_token_start_idx + q_head_tile_token_offset) <=
// 4
// && (q_token_start_idx + q_head_tile_token_offset +
// q_tile_token_num) > 4
// );
// if (debug_info) {
// std::printf("\tq_iter_idx: %d, q_token_start: %d,"
// "q_token_end: %d, q_token_num: %d, q_head_num: %d,"
// "q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,"
// "kv_pos_end: %d\n",
// q_iter_idx, q_token_start_idx +
// q_head_tile_token_offset, q_token_start_idx
// + q_head_tile_token_offset +
// q_tile_token_num, q_tile_token_num,
// q_tile_head_num, q_tile_pos_left,
// q_tile_pos_right,
// aligned_actual_kv_tile_pos_left,
// aligned_actual_kv_tile_pos_right);
// }
attn_impl.template execute_attention<Attention>(
curr_q_heads_buffer, curr_k_cache, curr_v_cache,
logits_buffer, curr_partial_q_buffer, curr_max_buffer,
curr_sum_buffer, curr_block_table,
aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right, actual_kv_token_num,
kv_cache_block_num_stride, q_tile_head_num,
q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv,
block_size, sliding_window_left, sliding_window_right,
scale, softcap_scale, curr_alibi_slopes,
first_iter_flag[q_iter_idx], use_sink, debug_info);
first_iter_flag[q_iter_idx] = false;
}
}
// write back partial results to output buffer or reduction buffer
{
if (curr_spilt_id == -1) {
final_output(partial_q_buffer,
reinterpret_cast<query_t*>(input->output) +
output_buffer_offset,
sum_buffer, actual_q_heads_per_kv,
actual_q_token_num, q_head_num);
} else {
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
buffer_manager.get_reduce_flag_buffer() + curr_spilt_id;
float* split_output_buffer =
buffer_manager.get_reduce_output_buffer() +
curr_spilt_id * stride * head_dim;
float* split_max_buffer =
buffer_manager.get_reduce_max_buffer() +
curr_spilt_id * stride;
float* split_sum_buffer =
buffer_manager.get_reduce_sum_buffer() +
curr_spilt_id * stride;
partial_output(partial_q_buffer, max_buffer, sum_buffer,
q_head_tile_size, split_output_buffer,
split_max_buffer, split_sum_buffer,
split_flag_buffer);
}
}
}
}
} else {
task_idx -= workitem_groups_counter_num;
const int32_t kv_head_idx = task_idx / reduction_item_num;
const int32_t item_offset = task_idx % reduction_item_num;
ReductionWorkItemGroup* const curr_workitem_groups =
reduction_items + item_offset;
const int32_t curr_output_token_idx =
curr_workitem_groups->q_token_id_start;
const int32_t curr_output_token_num =
curr_workitem_groups->q_token_id_num;
const int32_t curr_split_id = curr_workitem_groups->split_start_id;
const int32_t curr_split_num = curr_workitem_groups->split_num;
const int32_t current_group_idx = curr_workitem_groups->req_id;
const int32_t curr_output_head_num =
curr_output_token_num * actual_q_heads_per_kv;
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t q_token_start_idx = q_start + curr_output_token_idx;
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
buffer_manager.get_reduce_flag_buffer() + curr_split_id;
float* split_output_buffer =
buffer_manager.get_reduce_output_buffer() +
curr_split_id * stride * head_dim;
float* split_max_buffer =
buffer_manager.get_reduce_max_buffer() + curr_split_id * stride;
float* split_sum_buffer =
buffer_manager.get_reduce_sum_buffer() + curr_split_id * stride;
reduce_splits(split_output_buffer, split_max_buffer, split_sum_buffer,
split_flag_buffer, stride, curr_output_head_num,
curr_split_num);
final_output(
split_output_buffer,
reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
q_head_num);
}
}
}
// Reset counter for next call
input->metadata->reset_counter();
}
void reduce_splits(float* __restrict__ split_output_buffer,
float* __restrict__ split_max_buffer,
float* __restrict__ split_sum_buffer,
volatile bool* __restrict__ flags,
const int32_t head_num_per_split,
const int32_t curr_head_num, const int32_t split_num) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
// restrict curr_head_num <= 16 in the scheduler
// elems in split_max_buffer, split_sum_buffer are not cache alignment, use
// local buffers to reduce false-sharing
alignas(64) float local_max[16];
alignas(64) float local_sum[16];
float* __restrict__ curr_split_output_buffer = split_output_buffer;
float* __restrict__ curr_split_max_buffer = split_max_buffer;
float* __restrict__ curr_split_sum_buffer = split_sum_buffer;
constexpr int32_t head_dim_group_num = head_dim / 16;
for (int32_t split_idx = 0; split_idx < split_num; ++split_idx) {
while (!flags[split_idx]) {
#ifdef FAST_SPINNING
FAST_SPINNING
#else
std::this_thread::yield();
#endif
}
std::atomic_thread_fence(std::memory_order_acquire);
if (split_idx > 0) {
float* __restrict__ curr_output_buffer = split_output_buffer;
float* __restrict__ curr_split_output_buffer_iter =
curr_split_output_buffer;
for (int32_t head_idx = 0; head_idx < curr_head_num; ++head_idx) {
float final_max = local_max[head_idx];
float curr_max = curr_split_max_buffer[head_idx];
float final_sum = local_sum[head_idx];
float curr_sum = curr_split_sum_buffer[head_idx];
float* __restrict__ non_scale_output_iter =
final_max > curr_max ? curr_output_buffer
: curr_split_output_buffer_iter;
float* __restrict__ scale_output_iter =
final_max > curr_max ? curr_split_output_buffer_iter
: curr_output_buffer;
float rescale_factor = final_max > curr_max ? curr_max - final_max
: final_max - curr_max;
#ifdef DEFINE_FAST_EXP
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
rescale_factor_vec = fast_exp(rescale_factor_vec);
rescale_factor = rescale_factor_vec.get_last_elem();
#else
rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
#endif
local_sum[head_idx] = final_max > curr_max
? final_sum + rescale_factor * curr_sum
: rescale_factor * final_sum + curr_sum;
final_max = std::max(final_max, curr_max);
local_max[head_idx] = final_max;
for (int32_t i = 0; i < head_dim_group_num; ++i) {
vec_op::FP32Vec16 non_scale_vec(non_scale_output_iter);
vec_op::FP32Vec16 scale_vec(scale_output_iter);
vec_op::FP32Vec16 final_vec =
non_scale_vec + scale_vec * rescale_factor_vec;
final_vec.save(curr_output_buffer);
non_scale_output_iter += 16;
scale_output_iter += 16;
curr_output_buffer += 16;
}
curr_split_output_buffer_iter += head_dim;
}
} else {
vec_op::FP32Vec16 final_max(split_max_buffer);
final_max.save(local_max);
vec_op::FP32Vec16 final_sum(split_sum_buffer);
final_sum.save(local_sum);
}
curr_split_output_buffer += head_num_per_split * head_dim;
curr_split_max_buffer += head_num_per_split;
curr_split_sum_buffer += head_num_per_split;
}
// write back final max and sum
for (int32_t i = 0; i < curr_head_num; ++i) {
split_max_buffer[i] = local_max[i];
split_sum_buffer[i] = local_sum[i];
}
}
void partial_output(float* __restrict__ partial_output_buffer,
float* __restrict__ partial_max_buffer,
float* __restrict__ partial_sum_buffer,
int32_t curr_head_num,
float* __restrict__ split_output_buffer,
float* __restrict__ split_max_buffer,
float* __restrict__ split_sum_buffer,
volatile bool* __restrict__ flag) {
float* __restrict__ curr_partial_output_buffer = partial_output_buffer;
float* __restrict__ curr_split_output_buffer = split_output_buffer;
constexpr int32_t head_dim_group_num = head_dim / 16;
for (int32_t i = 0; i < curr_head_num; ++i) {
split_max_buffer[i] = partial_max_buffer[i];
split_sum_buffer[i] = partial_sum_buffer[i];
for (int32_t j = 0; j < head_dim_group_num; ++j) {
vec_op::FP32Vec16 vec(curr_partial_output_buffer);
vec.save(curr_split_output_buffer);
curr_partial_output_buffer += 16;
curr_split_output_buffer += 16;
}
}
std::atomic_thread_fence(std::memory_order_release);
*flag = true;
}
void final_output(float* __restrict__ partial_q_buffer,
query_t* __restrict__ curr_output_buffer,
float* __restrict__ sum_buffer,
const int32_t q_heads_per_kv,
const int32_t actual_q_token_num,
const int32_t q_head_num) {
// final output
using output_vec_t = typename VecTypeTrait<query_t>::vec_t;
float* __restrict__ curr_partial_output_buffer = partial_q_buffer;
float* __restrict__ curr_sum_buffer = sum_buffer;
constexpr int32_t group_num_per_head = head_dim / 16;
const int32_t partial_q_buffer_stride = q_heads_per_kv * head_dim;
const int32_t output_buffer_stride = q_head_num * head_dim;
for (int32_t token_idx = 0; token_idx < actual_q_token_num; ++token_idx) {
float* __restrict__ curr_partial_output_buffer_iter =
curr_partial_output_buffer;
query_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
vec_op::FP32Vec16 inv_sum_scale_vec(1.0 / *curr_sum_buffer);
for (int32_t i = 0; i < group_num_per_head; ++i) {
vec_op::FP32Vec16 vec(curr_partial_output_buffer_iter);
// divide the final sum val of softmax here
vec = inv_sum_scale_vec * vec;
// cast to query type
output_vec_t output_vec(vec);
output_vec.save(curr_output_buffer_iter);
// update
curr_partial_output_buffer_iter += 16;
curr_output_buffer_iter += 16;
}
// update
curr_sum_buffer += 1;
}
// update
curr_partial_output_buffer += partial_q_buffer_stride;
curr_output_buffer += output_buffer_stride;
}
}
};
} // namespace cpu_attention
#endif
#ifndef CPU_ATTN_MACROS_H
#define CPU_ATTN_MACROS_H
// x86_64
#ifdef __x86_64__
#define FAST_SPINNING _mm_pause();
#ifdef __AVX512F__
#define DEFINE_FAST_EXP \
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \
const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \
const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \
const __m512 vec_exp_log2ef = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \
const __m512 vec_half = _mm512_set1_ps(0.5f); \
const __m512 vec_one = _mm512_set1_ps(1.f); \
const __m512 vec_zero = _mm512_set1_ps(0.f); \
const __m512 vec_two = _mm512_set1_ps(2.f); \
const __m512 vec_ln2f = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \
const __m512 vec_ln_flt_min = \
_mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \
const __m512 vec_ln_flt_max = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \
auto vec_fx_i = _mm512_cvt_roundps_epi32( \
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \
vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \
auto vec_res = \
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \
vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \
vec_two_pow_n, vec_zero); \
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \
vec_res = _mm512_mul_ps(vec_res, vec_two); \
vec_op::FP32Vec16 res(vec_res); \
return res; \
};
#endif
#endif
#endif
\ No newline at end of file
#ifndef CPU_ATTN_VEC_HPP
#define CPU_ATTN_VEC_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
template <typename kv_cache_t>
class TileGemm82 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M <= 8);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
float* __restrict__ curr_c_0 = c_tile;
float* __restrict__ curr_c_1 = c_tile + 16;
vec_op::FP32Vec16 c_regs[M * 2];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
float* __restrict__ curr_m_c_1 = curr_c_1;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
// update
curr_m_c_0 += ldc;
curr_m_c_1 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
load_vec_t b_1_reg(curr_b_1);
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
curr_b_1 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2].save(curr_c_0);
c_regs[i * 2 + 1].save(curr_c_1);
// update
curr_c_0 += ldc;
curr_c_1 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
32; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
32; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 8;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm82<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
// Copy q to q_buffer and cast it to fp32
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
float* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
static_assert(head_dim % 16 == 0);
constexpr int32_t unroll_size = head_dim / 16;
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
scalar_t* __restrict__ curr_q =
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
float* __restrict__ curr_q_buffer =
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
q_head_idx * head_dim;
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
load_vec_t vec(curr_q);
vec_op::FP32Vec16 fp32_vec(vec);
fp32_vec = fp32_vec * scale_vec;
fp32_vec.save(curr_q_buffer);
curr_q += 16;
curr_q_buffer += 16;
});
}
}
}
// reshape K as column-major and V as row-major
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key as column-major
const scalar_t* key_start_ptr = key +
token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_cache_start_ptr =
key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_cache_start_ptr[j] = key_start_ptr[i];
}
}
{
// Write Value as row-major
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset * head_dim;
std::memcpy(value_cache_start_ptr, value_start_ptr,
sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#endif
#ifndef CPU_ATTN_VEC16_HPP
#define CPU_ATTN_VEC16_HPP
#include "cpu_attn_vec.hpp"
namespace cpu_attention {
namespace {
// 16-1-16 pattern, 16 regs for A, 1 regs for B, 16 regs for C, [16, K] @ [k,
// 16]
template <typename kv_cache_t>
class TileGemm161 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 9:
case 10:
case 11:
case 12:
gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 13:
case 14:
case 15:
case 16:
gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M <= 16);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
float* __restrict__ curr_c_0 = c_tile;
vec_op::FP32Vec16 c_regs[M];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
// update
curr_m_c_0 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i].save(curr_c_0);
// update
curr_c_0 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
16; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
16; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 16;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC16;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm161<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
};
} // namespace cpu_attention
#endif
......@@ -40,6 +40,23 @@ namespace vec_op {
#define FORCE_INLINE __attribute__((always_inline)) inline
// Function to get the timestamp using RDTSCP
FORCE_INLINE uint64_t bench_timestamp() {
unsigned int cycles_low, cycles_high;
asm volatile(
".intel_syntax noprefix\n\t"
"CPUID\n\t" // Serialize instruction stream to ensure previous
// instructions complete
"RDTSCP\n\t" // Read TSC and core ID
"mov %0, edx\n\t" // Store high 32 bits of TSC
"mov %1, eax\n\t" // Store low 32 bits of TSC
".att_syntax"
: "=r"(cycles_high), "=r"(cycles_low)::"rax", "rbx", "rcx",
"rdx" // Clobbered registers
);
return (uint64_t)cycles_high << 32 | cycles_low;
}
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
......@@ -407,6 +424,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
......@@ -446,9 +465,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
explicit FP32Vec16(const FP32Vec16& data)
: reg_low(data.reg_low), reg_high(data.reg_high) {}
explicit FP32Vec16(const FP32Vec4& data)
: reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
......@@ -504,6 +520,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
_mm256_div_ps(reg_high, b.reg_high));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(_mm256_max_ps(reg_low, b.reg_low),
_mm256_max_ps(reg_high, b.reg_high));
}
float reduce_max() const {
__m256 v = _mm256_max_ps(reg_low, reg_high);
// Permute to compare elements within 128-bit lanes
__m256 v_shuffled = _mm256_permute_ps(
v, 0b00001011); // Swap halves within each 128-bit lane
__m256 v_max = _mm256_max_ps(v, v_shuffled);
v_shuffled = _mm256_permute_ps(
v_max, 0b00000001); // Shuffle elements within each 128-bit lane
v_max = _mm256_max_ps(v_max, v_shuffled);
// Permute to compare elements between 128-bit lanes
v_shuffled =
_mm256_permute2f128_ps(v_max, v_max, 0b00000001); // Swap 128-bit lanes
v_max = _mm256_max_ps(v_max, v_shuffled);
// At this point, the maximum value is present in all elements of v_max.
// Extract the first element for the scalar result.
return _mm256_cvtss_f32(v_max); // Extract the lowest 32-bit float
}
float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high);
......@@ -642,7 +684,7 @@ inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
: reg(_mm256_insertf128_si256(
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
#endif
#ifdef __AVX512BF16__
......
......@@ -5,6 +5,7 @@
#include "common/memory.hpp"
#include "dnnl_helper.h"
#include "scratchpad_manager.h"
static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
......@@ -22,23 +23,6 @@ void release_dnnl_matmul_handler(int64_t handler) {
delete ptr;
}
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
template <typename KT, typename VT>
class DNNLPrimitiveCache {
public:
......
......@@ -59,30 +59,6 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type;
}
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
class DNNLMatMulPrimitiveHandler {
public:
virtual ~DNNLMatMulPrimitiveHandler() = default;
......
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif
......@@ -192,7 +192,7 @@ class SHMManager {
const int group_size)
: _rank(rank),
_group_size(group_size),
_thread_num(torch::get_num_threads()),
_thread_num(omp_get_max_threads()),
_shm_names({""}),
_shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) {
......
......@@ -74,25 +74,35 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split);
void cpu_attn_reshape_and_cache(const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
const torch::Tensor& slot_mapping,
const std::string& isa);
void cpu_attention_with_kv_cache(
const torch::Tensor& query, const torch::Tensor& key_cache,
const torch::Tensor& value_cache, torch::Tensor& output,
const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes,
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, const double softcap,
const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
ops.def(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
......@@ -102,20 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
// Activation ops
// Activation function used in SwiGLU.
......@@ -259,37 +255,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
// CPU attention kernels
ops.def(
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
"int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
"query_start_loc, bool casual, int window_size, str isa_hint, bool "
"enable_kv_split) -> Tensor",
&get_scheduler_metadata);
ops.def(
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
"isa) -> ()",
&cpu_attn_reshape_and_cache);
ops.def(
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
&cpu_attention_with_kv_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
......
......@@ -17,6 +17,7 @@
# VLLM_CPU_DISABLE_AVX512=false (default)|true
# VLLM_CPU_AVX512BF16=false (default)|true
# VLLM_CPU_AVX512VNNI=false (default)|true
# VLLM_CPU_AMXBF16=false (default)|true
#
######################### COMMON BASE IMAGE #########################
......@@ -92,6 +93,9 @@ ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
ARG VLLM_CPU_AVX512VNNI=0
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
# Support for building with AMXBF16 ISA: docker build --build-arg VLLM_CPU_AMXBF16="true" ...
ARG VLLM_CPU_AMXBF16=0
ENV VLLM_CPU_AMXBF16=${VLLM_CPU_AMXBF16}
WORKDIR /workspace/vllm
......
......@@ -171,6 +171,8 @@ This value is 4GB by default. Larger space can support more concurrent requests,
First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`.
Use multiples of 32 as `--block-size`, which is 128 by default.
Inference batch size is an important parameter for the performance. A larger batch usually provides higher throughput, a smaller batch provides lower latency. Tuning the max batch size starting from the default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment