Commit fbeb8a6f authored by raojy

raw_vllm

parent 2ca8867f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
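// Append the start offsets of the blocks covering [range_start, range_end)
// (stepping by block_size, clipped to kv_seqlen) to block_offset and return
// the updated block count.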
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
int64_t range_end, int64_t block_size,
int64_t input_block_count, int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
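// One thread handles one row block of BLOCK_SIZE_M query rows.
// Grid: (N_HEADS, BATCH, cdiv(N_ROWS, blockDim.x)); see the launcher below.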
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
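// Skip slashes that fall outside this row block, then convert the slash
// (diagonal) index into a kv-column endpoint for the block; clamping to
// BLOCK_SIZE_M keeps range_start = s_idx - BLOCK_SIZE_M non-negative.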
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if the slashes are finished but verticals remain, save the current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
/**
* Implements Algorithm 4 from the paper https://arxiv.org/abs/2407.02490.
*
* This function builds the index of each row of blocks from vertical indices
* and slash indices. The vertical indices are treated as points, while the
* slash indices are converted to ranges. The output consists of the merged
* ranges and separate column indices, where the ranges are represented by
* block indices.
*
* The implementation is adapted from the original MInference repo:
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
*/
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
causal);
}
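// Minimal usage sketch for convert_vertical_slash_indexes (illustrative only;
// the tensor names and allocation code below are assumptions, not taken from
// this repo). All index tensors are int32 CUDA tensors; outputs must be
// pre-allocated with the shapes documented above:
//
//   int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
//   auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   auto block_count  = torch::zeros({batch, heads, num_rows}, opts);
//   auto block_offset = torch::zeros({batch, heads, num_rows, nnz_s}, opts);
//   auto column_count = torch::zeros({batch, heads, num_rows}, opts);
//   auto column_index = torch::zeros({batch, heads, num_rows, nnz_v}, opts);
//   convert_vertical_slash_indexes(block_count, block_offset, column_count,
//                                  column_index, q_seqlens, kv_seqlens,
//                                  vertical_indexes, slash_indexes,
//                                  context_size, block_size_M, block_size_N,
//                                  /*causal=*/true);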
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has its own max top-k NNZ_V, NNZ_S. (The NNZ_V, NNZ_S
// above are the buffer sizes, used to compute offsets.)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if the slashes are finished but verticals remain, save the current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv, int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
NNZ_V, NNZ_S, causal);
}
/**
* Implements Algorithm 4 from the paper https://arxiv.org/abs/2407.02490.
*
* Like the above convert_vertical_slash_indexes, but with
* pre-computed vertical and slash counts.
*/
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, // [N_HEADS, ]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}
#pragma once
#include <torch/all.h>
#include <c10/util/Optional.h>
#include <map>
#include <vector>
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
// Just for unit tests
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt);
// TODO(hc): cp_gather_cache needs to support scaled kv-cache in the future.
void cp_gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>
#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
#else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#include <algorithm>
#include <cassert>
#include <cfloat>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
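// Divisor used when computing fp8 quantization scales below; gfx942 uses the
// FNUZ fp8 format, which has a smaller maximum representable value.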
#if defined(__gfx942__)
constexpr float kFp8ScaleDivisor = 224.f;
#else
constexpr float kFp8ScaleDivisor = 448.f;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost;
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
memcpy_type = cudaMemcpyHostToDevice;
} else {
TORCH_CHECK(false, "Invalid device combination");
}
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
const int64_t num_blocks = block_mapping.size(0);
for (int64_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
block_size_in_bytes, memcpy_type, stream);
}
}
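// Usage sketch (illustrative; variable names are assumptions): block_mapping
// must be a CPU int64 tensor of shape [num_pairs, 2] holding
// (src_block, dst_block) pairs, e.g.
//
//   auto block_mapping =
//       torch::tensor({{0, 3}, {1, 7}}, torch::dtype(torch::kInt64));
//   swap_blocks(gpu_cache, cpu_cache, block_size_in_bytes, block_mapping);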
namespace vllm {
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping,
const int numel_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
const int64_t src_block_offset = src_block_number * numel_per_block;
const int64_t dst_block_offset = dst_block_number * numel_per_block;
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int64_t src_offset = src_block_offset + i;
int64_t dst_offset = dst_block_offset + i;
key_cache[dst_offset] = key_cache[src_offset];
}
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int64_t src_offset = src_block_offset + i;
int64_t dst_offset = dst_block_offset + i;
value_cache[dst_offset] = value_cache[src_offset];
}
}
// Kernel for MLA, which works on a single joint kv_cache
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
const int mem_footprint_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
int64_t src_block = block_mapping[2 * pair_idx];
int64_t dst_block = block_mapping[2 * pair_idx + 1];
int64_t src_offset = src_block * mem_footprint_per_block;
int64_t dst_offset = dst_block * mem_footprint_per_block;
for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
cache[dst_offset + i] = cache[src_offset + i];
}
}
} // namespace vllm
namespace vllm {
// Used to copy/convert one element
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
float scale;
__device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst = static_cast<OutT>(src);
} else {
dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
}
}
};
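// For non-fp8 caches (kAuto) the scale is unused; callers below pass 0.f as a
// placeholder.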
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
// block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x,
const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int h_block_count = head_size / x; // number of x-element groups per head
const int h_block_idx = threadIdx.x;
if (h_block_idx >= num_heads * h_block_count) {
return;
}
const int head_idx = h_block_idx / h_block_count;
const int h_block = h_block_idx % h_block_count;
const scalar_t* __restrict__ key_src =
key + token_idx * key_stride + head_idx * head_size + h_block * x;
const int64_t src_value_start =
token_idx * value_stride + head_idx * head_size + h_block * x;
cache_t* __restrict__ key_dst =
key_cache + block_idx * num_heads * h_block_count * block_size * x +
head_idx * h_block_count * block_size * x + h_block * block_size * x +
block_offset * x;
const int64_t tgt_value_start =
block_idx * num_heads * h_block_count * x * block_size +
head_idx * h_block_count * x * block_size + h_block * x * block_size +
block_offset;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
const scalar_t* __restrict__ value_src = value + src_value_start;
cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
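// The value cache stores each head as [head_size, block_size], so the x
// consecutive source elements below are written with a stride of block_size.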
#pragma unroll
for (int i = 0; i < x; i++) {
v_op(value_dst[i * block_size], value_src[i]);
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // NHD or HND, shape see comments below
cache_t* __restrict__ value_cache, // same above
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale,
const int kv_scale_stride) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n_elems = num_heads * head_size;
// pointers to the beginning of the source row for this token.
const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
// find the start position inside the kv-cache for this token.
cache_t* __restrict__ key_dst =
key_cache + block_idx * block_stride + block_offset * page_stride;
cache_t* __restrict__ value_dst =
value_cache + block_idx * block_stride + block_offset * page_stride;
// this is true for the NHD layout where `head_stride == head_size`
const bool is_contiguous_heads = (head_stride == head_size);
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
if (is_contiguous_heads && kv_scale_stride == 0) {
// NHD layout and k/v_scales are [1] (i.e. single scale for all heads)
// kv cache: [num_blocks, block_size, num_heads, head_size]
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
blockDim.x, k_op);
vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
threadIdx.x, blockDim.x, v_op);
} else {
// HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head)
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const int lane = threadIdx.x & 31; // 0..31 within warp
const int warp_id = threadIdx.x >> 5; // warp index within block
const int warps_per_block = blockDim.x >> 5;
for (int head = warp_id; head < num_heads; head += warps_per_block) {
const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
cache_t* __restrict__ k_dst_h =
key_dst + static_cast<int64_t>(head) * head_stride;
cache_t* __restrict__ v_dst_h =
value_dst + static_cast<int64_t>(head) * head_stride;
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: k_scale[head * kv_scale_stride];
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: v_scale[head * kv_scale_stride];
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
k_op);
vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
v_op);
}
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
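// fp8_ds_mla cache entry layout per token (656 bytes; see the checks in
// concat_and_cache_mla below): 512 bytes of fp8 NoPE values, then 4 float32
// tile scales (16 bytes), then 64 bf16 RoPE values (128 bytes).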
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_ds_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int64_t dst_idx_start =
block_idx * block_stride + block_offset * entry_stride;
// For the NoPE part, each tile of 128 elements is handled by half of one warp
// (16 threads). There are 4 total tiles, so 2 warps (64 threads).
// Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
// The RoPE part (last 64 elements) is handled by another 1 warp (32 threads).
// So in total, we use 3 warps (96 threads) per block.
// Cast kv_cache to 16_bit for RoPE values
scalar_t* kv_cache_16bit =
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
// The last warp handles the RoPE part
if (threadIdx.x >= 64) {
// Each thread handles two elements of RoPE
const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
// Vectorized load of two 16-bit values, performed as one 32-bit load
const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
// RoPE values start after the packed 8-bit NoPE values and the
// 32-bit scales
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
// Vectorized store of two 16-bit values, performed as one 32-bit store
*reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
return;
}
// The first two warps handle the NoPE part
const int8_t warp_idx = threadIdx.x >> 5;
const int8_t lane_idx = threadIdx.x & 31;
const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
// Each thread handles 8 elements of NoPE
// Load the NoPE elements for this thread into registers
const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
// Vectorized load of eight 16-bit values, performed as an int4 load
const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
// Max absolute value of this thread's elements
float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])),
fmaxf(fabsf(vals[2]), fabsf(vals[3]))),
fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])),
fmaxf(fabsf(vals[6]), fabsf(vals[7]))));
// Warp-level reduction to find the max absolute value in each half-warp
#pragma unroll
for (int offset = 8; offset > 0; offset /= 2) {
max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16));
}
// Compute the scale for the tile
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
kv_cache_32bit[dst_idx] = tile_scale;
}
// Now all threads in the block scale and write their elements
// NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes)
const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
uint8_t result[8];
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] =
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
vals[i], tile_scale);
}
// Store as aligned 64-bit writes
*reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
*reinterpret_cast<const uint64_t*>(result);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim, // dimension of each head
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
}
float2 k_val = (reinterpret_cast<const float2*>(
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
float amax = 0.0f;
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
}
// Reduced amax
for (int mask = 16; mask > 0; mask /= 2) {
#ifdef USE_ROCM
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
}
float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
}
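// Lane 0 (threadIdx.x == 0) writes the float32 scale for its quantization
// block; scales are stored after the quantized K region within each cache
// block.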
if (threadIdx.x == 0) {
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride +
cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size,
// cache_stride]
char* __restrict__ dst_k, // [num_tokens, head_dim]
char* __restrict__ dst_scale, // [num_tokens, head_dim / quant_block_size *
// 4]
const int* __restrict__ block_table, // [batch_size, num_blocks]
const int* __restrict__ cu_seq_lens, // [batch_size + 1]
const int batch_size, // batch size
const int64_t token_stride, // stride for each token in dst_k
const int64_t head_dim, // dimension of each head
const int64_t block_stride, // stride for each block in kv_cache
const int64_t cache_token_stride, // stride for each token in kv_cache
const int64_t cache_block_size, // num_tokens for each block in kv_cache
const int num_blocks, // number of blocks
const int num_tokens, // number of tokens
const int quant_block_size // quantization block size
) {
constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE];
for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
iter++) {
int tid = iter * blockDim.x + threadIdx.x;
if (tid < batch_size) {
const int seq_start = cu_seq_lens[tid];
const int seq_end = cu_seq_lens[tid + 1];
if (token_idx >= seq_start && token_idx < seq_end) {
batch_idx[threadIdx.y] = tid;
}
}
}
#ifndef USE_ROCM
__syncwarp();
#endif
if (head_idx >= head_dim || token_idx >= num_tokens) {
return;
}
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_idx * block_stride;
const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
if (threadIdx.x == 0) {
const int64_t src_scale_offset =
src_block_offset + cache_block_size * head_dim +
cache_inblock_offset * 4 / quant_block_size;
reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
}
}
} // namespace vllm
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int head_div_x = head_size / x;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_div_x, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE);
}
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()), \
kv_scale_stride);
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, // [1] or [num_heads]
torch::Tensor& v_scale) { // [1] or [num_heads]
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(1);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int64_t block_stride = key_cache.stride(0);
int64_t page_stride = key_cache.stride(1);
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
"k_scale and v_scale must have the same shape");
TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
"k_scale and v_scale must be of shape [1] or [num_heads]");
int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_FLASH);
}
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
if (kv_cache_dtype == "fp8_ds_mla") {
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
TORCH_CHECK(kv_c.itemsize() == 2,
"kv_c.itemsize() must be 2 for fp8_ds_mla");
TORCH_CHECK(k_pe.itemsize() == 2,
"k_pe.itemsize() must be 2 for fp8_ds_mla");
} else {
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
}
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "fp8_ds_mla") {
dim3 grid(num_tokens);
// For the NoPE part, each tile of 128 elements is handled by half of one
// warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
// Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
// The RoPE part (last 64 elements) is handled by another 1 warp (32
// threads). So in total, we use 3 warps (96 threads) per block.
dim3 block(96);
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_DS_MLA);
} else {
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const float scale,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
}
}
} // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0);
dim3 grid(num_blocks);
dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
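// NOTE: Half inputs are passed to the kernel as uint16_t; the fp8 conversion
// helpers interpret the raw 16-bit pattern as half.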
if (kv_cache_dtype == "auto") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}
namespace vllm {
// Grid: (num_tokens); each thread block handles one token, striding over
// tokens if the grid is smaller than num_tokens.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt,
int ENTRY_SIZE, int CTA_SIZE>
__global__ void gather_and_maybe_dequant_cache(
const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK]
const int32_t num_tokens, const int32_t block_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
using ltype = vllm::vec_n_t<cache_t, vec_size>;
using stype = vllm::vec_n_t<scalar_t, vec_size>;
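// ltype/stype pack vec_size elements, so each thread writes one float4-worth
// (16 bytes) of dst per iteration.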
// This assert is kept for readability; it is compiled out in release builds.
assert(CTA_SIZE == blockDim.x);
#pragma unroll
for (int token_id = blockIdx.x; token_id < num_tokens;
token_id += gridDim.x) {
int64_t batch_id = token_to_seq[token_id];
int64_t batch_start = cu_seq_lens[batch_id];
int64_t batch_end = cu_seq_lens[batch_id + 1];
int32_t batch_offset = token_id - batch_start;
if (token_id >= batch_end) return;
int32_t offset = 0;
if (seq_starts != nullptr) {
offset = seq_starts[batch_id];
}
batch_offset += offset;
int32_t block_table_id = batch_offset / block_size;
int32_t slot_id = batch_offset % block_size;
int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
int32_t block_id = block_table[block_table_offset];
int64_t cache_offset =
block_id * cache_block_stride + slot_id * cache_entry_stride;
constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
scalar_t* dst_ = dst + token_id * dst_entry_stride;
cache_t* src_ = const_cast<cache_t*>(src_cache) + cache_offset;
#pragma unroll
for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
reinterpret_cast<stype*>(dst_)[idx] =
static_cast<stype>(reinterpret_cast<ltype*>(src_)[idx]);
} else {
ltype loaded_val = reinterpret_cast<ltype*>(src_)[idx];
stype store_val;
#pragma unroll
for (int j = 0; j < vec_size; ++j) {
store_val.val[j] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(
loaded_val.val[j], *scale);
}
reinterpret_cast<stype*>(dst_)[idx] = store_val;
}
}
// process tail
constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
dst_ = dst_ + ENTRY_SIZE - tail_cnt;
src_ = src_ + ENTRY_SIZE - tail_cnt;
#pragma unroll
for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst_[idx] = static_cast<scalar_t>(src_[idx]);
} else {
dst_[idx] =
fp8::scaled_convert<scalar_t, cache_t, kv_dt>(src_[idx], *scale);
}
}
}
}
} // namespace vllm
// Macro to dispatch the kernel based on the data type.
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576, \
thread_block_size> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
token_to_seq.data_ptr<int32_t>(), num_tokens, block_size, \
block_table_stride, cache_block_stride, cache_entry_stride, \
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - token_to_seq contains the back mapping from token_id to batch_id
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
TORCH_CHECK(head_dim == 576,
"gather_and_maybe_dequant_cache only supports head_dim == 576 "
"for better performance");
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
constexpr int32_t thread_block_size = 64;
dim3 grid(num_tokens);
dim3 block(thread_block_size);
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}
namespace vllm {
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
const bool is_active_split = (split_start < tot_slots);
if (!is_active_split) return;
// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;
const int tid = threadIdx.x;
// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}
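// CPU reference sketch (illustrative only; example_dequant_mla_token is a
// hypothetical helper, not used by the kernel above) of the 656-byte
// per-token layout it consumes: 512 fp8 values, then 4 float scales (one per
// 128-element tile), then 64 bf16 rope values. For simplicity it produces
// float outputs and takes the e4m3 decode as a function pointer.
#include <cstdint>
#include <cstring>

inline void example_dequant_mla_token(const uint8_t* token_ptr /* 656 bytes */,
                                      float (*fp8_to_float)(uint8_t),
                                      float* dst /* 576 floats */) {
  float scales[4];
  std::memcpy(scales, token_ptr + 512, sizeof(scales));
  for (int i = 0; i < 512; ++i) {
    // One dequant scale per 128-element tile (i >> 7 == i / 128).
    dst[i] = fp8_to_float(token_ptr[i]) * scales[i >> 7];
  }
  const uint16_t* rope = reinterpret_cast<const uint16_t*>(token_ptr + 512 + 16);
  for (int i = 0; i < 64; ++i) {
    // bf16 -> float: the 16 payload bits form the high half of a float32.
    uint32_t bits = static_cast<uint32_t>(rope[i]) << 16;
    std::memcpy(&dst[512 + i], &bits, sizeof(float));
  }
}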
template <typename scalar_t>
// Note(hc): cp_gather_cache allows seq_starts that are not divisible by
// block_size.
__global__ void cp_gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRY_SIZE]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const int32_t* __restrict__ seq_starts // Optional: starting offsets per
// batch
) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = cu_seq_lens[bid];
const int32_t seq_end = cu_seq_lens[bid + 1];
const int32_t seq_len = seq_end - seq_start;
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
const bool is_active_split = (split_start < tot_slots);
if (!is_active_split) return;
// Adjust the pointer for the block_table for this batch.
// If seq_starts is provided, compute an offset based on it
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
if (seq_starts != nullptr) {
offset += seq_starts[bid];
}
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
// Adjust dst pointer based on the cumulative sequence lengths.
dst += seq_start * dst_entry_stride;
auto copy_entry = [&](const scalar_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
};
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + pid * dst_entry_stride;
copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr);
offset += 1;
// bump to next block
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}
} // namespace vllm
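// Minimal sketch (illustrative only; example_slot_to_cache_offset is a
// hypothetical helper) of the slot-walking arithmetic used by the kernel
// above: a token's slot maps to a cache block via block_table and to an
// in-block entry via the remainder; the kernel simply advances the
// (offset_div, offset) pair instead of re-dividing for every token.
#include <cstdint>

inline int64_t example_slot_to_cache_offset(const int32_t* batch_block_table,
                                            int32_t slot, int32_t block_size,
                                            int64_t cache_block_stride,
                                            int64_t cache_entry_stride) {
  const int32_t block_id = batch_block_table[slot / block_size];
  const int32_t entry = slot % block_size;
  return block_id * cache_block_stride + entry * cache_entry_stride;
}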
// Macro to dispatch the kernel based on the data type.
#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \
vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting slot index by
// seq_starts[bid]
void cp_gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size.
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(1024);
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");
const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
if (dtype_bits == 32) {
CALL_CP_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_CP_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_CP_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
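// The dispatch above only cares about element width, not semantics: fp16 and
// bf16 tensors are both copied as uint16_t. A compact sketch of that idea
// (illustrative only; the example_* names are hypothetical):
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename StorageT>
inline void example_copy_entries(const void* src, void* dst, std::size_t n) {
  std::memcpy(dst, src, n * sizeof(StorageT));
}

inline void example_dispatch_by_width(int dtype_bits, const void* src,
                                      void* dst, std::size_t n) {
  if (dtype_bits == 32) {
    example_copy_entries<uint32_t>(src, dst, n);
  } else if (dtype_bits == 16) {
    example_copy_entries<uint16_t>(src, dst, n);
  } else if (dtype_bits == 8) {
    example_copy_entries<uint8_t>(src, dst, n);
  }
}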
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32");
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");
auto dtype = src_cache.scalar_type();
TORCH_CHECK(
dtype == at::ScalarType::Byte || // uint8
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
src_cache.dtype());
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
const uint8_t* src_ptr = nullptr;
if (dtype == at::ScalarType::Byte) {
src_ptr = src_cache.data_ptr<uint8_t>();
} else {
// float8_e4m3fn or float8_e5m2
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
}
// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
cache_block_size, cache_stride, use_ue8m0);
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt) {
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
static const std::string kv_cache_dtype = "fp8_e4m3";
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), kv_cache_dtype,
CALL_INDEXER_K_QUANT_AND_CACHE);
}
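// Worked example (illustrative; the sizes are hypothetical) of the launch
// geometry above: with head_dim = 128, quant_block_size = 32 and vec_size = 4,
// each y-block covers quant_block_size * vec_size = 128 elements, so the grid
// is (num_tokens, 1) and each block is (32, 4) = 128 threads.
#include <cassert>

inline void example_indexer_launch_dims() {
  const int head_dim = 128, quant_block_size = 32, vec_size = 4;
  const int grid_y = (head_dim + quant_block_size * vec_size - 1) /
                     (quant_block_size * vec_size);
  assert(grid_y == 1);           // one y-block spans the whole head
  assert(32 * vec_size == 128);  // block(32, 4) = 128 threads per token
}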
// Macro to dispatch the kernel based on the number of tokens.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
<<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, \
(head_dim + 8 * vec_size - 1) / (8 * vec_size)), \
dim3(8, BLOCK_Y_SIZE), 0, stream>>>( \
reinterpret_cast<char*>(kv_cache.data_ptr()), \
reinterpret_cast<char*>(dst_k.data_ptr()), \
reinterpret_cast<char*>(dst_scale.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \
kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \
num_tokens, quant_block_size);
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens // [batch_size + 1]
) {
int batch_size = block_table.size(0);
int num_tokens = dst_k.size(0);
int head_dim = dst_k.size(1);
int quant_block_size = head_dim * 4 / dst_scale.size(1);
TORCH_CHECK(kv_cache.device() == dst_k.device(),
"kv_cache and dst_k must be on the same device");
TORCH_CHECK(kv_cache.device() == dst_scale.device(),
"kv_cache and dst_scale must be on the same device");
TORCH_CHECK(kv_cache.device() == block_table.device(),
"kv_cache and block_table must be on the same device");
TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
"kv_cache and cu_seq_lens must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 16;
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_tokens < 32) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
} else if (num_tokens < 64) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
} else if (num_tokens < 128) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
} else if (num_tokens < 256) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
} else if (num_tokens < 512) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
} else {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
}
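// Compact sketch (illustrative only; example_select_block_y is hypothetical)
// of the BLOCK_Y_SIZE ladder above, which grows the per-block token count
// with num_tokens:
inline int example_select_block_y(int num_tokens) {
  int block_y = 1;
  while (block_y < 32 && num_tokens >= 32 * block_y) {
    block_y *= 2;  // 1, 2, 4, 8, 16, 32 as num_tokens crosses 32, 64, ...
  }
  return block_y;
}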
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
#else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
namespace vllm {
// NOTE: Be EXTRA careful with raw_kv_scalar_t; for __half and __nv_bfloat16
// it uses u16 as the backing type.
template <typename qk_t, bool IS_NEOX, typename raw_kv_scalar_t,
typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_rope_fused_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
qk_t* __restrict__ q_pe, // [num_tokens, num_q_heads, rot_dim]
qk_t* __restrict__ k_pe, // [num_tokens, rot_dim]
const qk_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const qk_t* __restrict__ rope_cos_sin_cache, // [max_position, 2,
// rot_dim // 2]
const int rot_dim, const int64_t q_pe_stride_token,
const int64_t q_pe_stride_head, const int64_t k_pe_stride,
const int64_t kv_c_stride, const int num_q_heads,
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// rot_dim)]
const int64_t* __restrict__ kv_cache_slot_mapping, // [num_tokens]
const int block_stride, const int entry_stride, const int kv_lora_rank,
const int block_size, const float* kv_cache_quant_scale) {
// Each thread block is responsible for one token.
const int64_t token_idx = blockIdx.x;
const int64_t pos = positions[token_idx];
const qk_t* cos_sin_ptr = rope_cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
// Q ROPE
const int nq = num_q_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
int head_idx = i / embed_dim;
int pair_idx = i % embed_dim;
// NOTE: Would be nice to have interleaved sin/cos so we could just load
// both at the same time.
qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx);
qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
qk_t* q_pe_head_ptr =
q_pe + token_idx * q_pe_stride_token + head_idx * q_pe_stride_head;
int pair_idx_x, pair_idx_y;
if constexpr (IS_NEOX) {
// GPT-NeoX style rotary embedding.
pair_idx_x = pair_idx;
pair_idx_y = embed_dim + pair_idx;
} else {
// GPT-J style rotary embedding.
pair_idx_x = pair_idx * 2;
pair_idx_y = pair_idx * 2 + 1;
}
qk_t x_src = q_pe_head_ptr[pair_idx_x];
qk_t y_src = q_pe_head_ptr[pair_idx_y];
qk_t x_dst = x_src * cos - y_src * sin;
qk_t y_dst = y_src * cos + x_src * sin;
q_pe_head_ptr[pair_idx_x] = x_dst;
q_pe_head_ptr[pair_idx_y] = y_dst;
}
const int64_t slot_idx = kv_cache_slot_mapping[token_idx];
const int64_t block_idx = slot_idx / block_size;
const int64_t entry_idx = slot_idx % block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
// K with 1 HEAD
for (int i = threadIdx.x; i < embed_dim; i += blockDim.x) {
int pair_idx = i;
qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx);
qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
qk_t* k_pe_head_ptr = k_pe + token_idx * k_pe_stride;
int pair_idx_x, pair_idx_y;
if constexpr (IS_NEOX) {
// GPT-NeoX style rotary embedding.
pair_idx_x = pair_idx;
pair_idx_y = embed_dim + pair_idx;
} else {
// GPT-J style rotary embedding.
pair_idx_x = pair_idx * 2;
pair_idx_y = pair_idx * 2 + 1;
}
qk_t x_src = k_pe_head_ptr[pair_idx_x];
qk_t y_src = k_pe_head_ptr[pair_idx_y];
qk_t x_dst = x_src * cos - y_src * sin;
qk_t y_dst = y_src * cos + x_src * sin;
k_pe_head_ptr[pair_idx_x] = x_dst;
k_pe_head_ptr[pair_idx_y] = y_dst;
// NOTE Why is this monster necessary?
// When K is of type float16, the actual template replacement for
// raw_kv_scalar_t will be u16. That's why the reinterpretation happens at the
// last moment; otherwise the CUDA ALU path would break.
const raw_kv_scalar_t raw_x_value =
*reinterpret_cast<const raw_kv_scalar_t*>(&x_dst);
const raw_kv_scalar_t raw_y_value =
*reinterpret_cast<const raw_kv_scalar_t*>(&y_dst);
cache_t* kv_cache_ptr = kv_cache + block_idx * block_stride +
entry_idx * entry_stride + kv_lora_rank;
// MLA Cache Store
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
kv_cache_ptr[pair_idx_x] = raw_x_value;
kv_cache_ptr[pair_idx_y] = raw_y_value;
} else {
kv_cache_ptr[pair_idx_x] =
fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
raw_x_value, *kv_cache_quant_scale);
kv_cache_ptr[pair_idx_y] =
fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
raw_y_value, *kv_cache_quant_scale);
}
}
// NOPE
for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) {
const qk_t* src_ptr = kv_c + token_idx * kv_c_stride + i;
const raw_kv_scalar_t src_value =
*reinterpret_cast<const raw_kv_scalar_t*>(src_ptr);
cache_t* kv_cache_ptr =
kv_cache + block_idx * block_stride + entry_idx * entry_stride;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
kv_cache_ptr[i] = src_value;
} else {
kv_cache_ptr[i] = fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
src_value, *kv_cache_quant_scale);
}
}
}
} // namespace vllm
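// Scalar reference (illustrative sketch; example_rope_rotate is a
// hypothetical helper) of the per-pair rotation performed by the kernel
// above. For NeoX style the pair is (i, i + rot_dim/2); for GPT-J style it is
// (2*i, 2*i + 1). cos/sin are read from the half-dim tables at index i.
inline void example_rope_rotate(float* vec, int rot_dim, const float* cos_tab,
                                const float* sin_tab, bool is_neox) {
  const int embed_dim = rot_dim / 2;
  for (int i = 0; i < embed_dim; ++i) {
    const int ix = is_neox ? i : 2 * i;
    const int iy = is_neox ? embed_dim + i : 2 * i + 1;
    const float x = vec[ix], y = vec[iy];
    vec[ix] = x * cos_tab[i] - y * sin_tab[i];
    vec[iy] = y * cos_tab[i] + x * sin_tab[i];
  }
}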
#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE) \
do { \
VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] { \
using qk_t = scalar_t; \
if (rope_is_neox) { \
vllm::concat_and_cache_mla_rope_fused_kernel<qk_t, true, RAW_KV_T, \
CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim, \
q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \
num_q_heads, reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride, \
entry_stride, kv_lora_rank, block_size, \
kv_cache_quant_scale.data_ptr<float>()); \
} else { \
vllm::concat_and_cache_mla_rope_fused_kernel<qk_t, false, RAW_KV_T, \
CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim, \
q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \
num_q_heads, reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride, \
entry_stride, kv_lora_rank, block_size, \
kv_cache_quant_scale.data_ptr<float>()); \
} \
}); \
} while (false)
// Executes RoPE on q_pe and k_pe, then writes k_pe and kv_c into the kv cache.
// q_pe and k_pe are modified in place.
// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and
// concat_and_cache_mla.
void concat_and_cache_mla_rope_fused(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim]
torch::Tensor& k_pe, // [num_tokens, rot_dim]
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& rope_cos_sin_cache, // [max_position, rot_dim]
bool rope_is_neox,
torch::Tensor&
kv_cache_slot_mapping, // [num_tokens] or [num_actual_tokens]
torch::Tensor&
kv_cache, // [num_blocks, block_size, (kv_lora_rank + rot_dim)]
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) {
const int64_t num_tokens = q_pe.size(0);
const int num_q_heads = q_pe.size(1);
const int rot_dim = q_pe.size(2);
const int kv_lora_rank = kv_c.size(1);
TORCH_CHECK(positions.size(0) >=
num_tokens); // CUDA Graphs might pad this for us
TORCH_CHECK_EQ(positions.dim(), 1);
TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long);
TORCH_CHECK_EQ(q_pe.size(0), num_tokens);
TORCH_CHECK_EQ(q_pe.size(1), num_q_heads);
TORCH_CHECK_EQ(q_pe.size(2), rot_dim);
TORCH_CHECK_EQ(q_pe.dim(), 3);
TORCH_CHECK_EQ(k_pe.size(0), num_tokens);
TORCH_CHECK_EQ(k_pe.size(1), rot_dim);
TORCH_CHECK_EQ(k_pe.dim(), 2);
TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type());
TORCH_CHECK_EQ(kv_c.size(0), num_tokens);
TORCH_CHECK_EQ(kv_c.size(1), kv_lora_rank);
TORCH_CHECK_EQ(kv_c.dim(), 2);
TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type());
TORCH_CHECK_EQ(kv_c.dtype(), q_pe.dtype());
TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim);
TORCH_CHECK_EQ(rope_cos_sin_cache.scalar_type(), q_pe.scalar_type());
TORCH_CHECK_EQ(kv_cache_slot_mapping.size(0), num_tokens);
TORCH_CHECK_EQ(kv_cache_slot_mapping.scalar_type(), c10::ScalarType::Long);
TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim);
TORCH_CHECK_EQ(kv_cache.dim(), 3);
TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1);
TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float);
int64_t q_pe_stride_token = q_pe.stride(0);
int64_t q_pe_stride_head = q_pe.stride(1);
int64_t k_pe_stride = k_pe.stride(0);
int64_t kv_c_stride = kv_c.stride(0);
int block_size = kv_cache.size(1);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
int rope_block_size = std::min(num_q_heads * rot_dim / 2, 512);
int mla_block_size = kv_lora_rank;
int thread_block_size =
std::min(std::max(rope_block_size, mla_block_size), 512);
dim3 grid(num_tokens, 1, 1);
dim3 block(thread_block_size, 1, 1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(positions));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED);
}
#pragma once
#include <cstdlib>
#include <string>
#include <cctype>
namespace vllm {
// vllm_is_batch_invariant() returns true if the environment variable
// VLLM_BATCH_INVARIANT is set to a non-zero value.
inline bool vllm_is_batch_invariant() {
static bool cached = []() {
std::string env_key = "VLLM_BATCH_INVARIANT";
const char* val = std::getenv(env_key.c_str());
return val && std::atoi(val) != 0;
}();
return cached;
}
} // namespace vllm
#pragma once
#define VLLM_IMPLIES(p, q) (!(p) || (q))
#pragma once
#include <climits>
#include <iostream>
inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}
// Round a down to the previous multiple of b. The caller is responsible for
// making sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
return a % b == 0 ? a : (a / b) * b;
}
// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
return a % b == 0 ? a : ((a / b) + 1) * b;
}
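// Compile-time sanity examples for the helpers above (illustrative only;
// assumes a GCC/Clang toolchain where __builtin_clz is usable in constant
// expressions):
static_assert(next_pow_2(5) == 8);
static_assert(next_pow_2(64) == 64);
static_assert(div_ceil(10, 4) == 3);
static_assert(round_to_previous_multiple_of(10, 4) == 8);
static_assert(round_to_next_multiple_of(10, 4) == 12);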
#pragma once
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
#pragma once
// For TORCH_CHECK
#include <torch/library.h>
namespace vllm {
//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in
// vllm/scalar_type.py; these type definitions should be kept up to date with
// any Python API changes here.
//
class ScalarType {
public:
enum NanRepr : uint8_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
int32_t bias, bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
}
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits, false, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(uint8_t exponent,
uint8_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
TORCH_CHECK(mantissa > 0 && exponent > 0);
TORCH_CHECK(nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
nan_repr);
}
uint8_t const exponent; // size of the exponent field (0 for integer types)
uint8_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t const bias; // stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
using Id = int64_t;
private:
// Field size in id
template <typename T_>
static constexpr size_t member_id_field_width() {
using T = std::decay_t<T_>;
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
return reduce_members_helper(f, new_val, rest...);
} else {
return new_val;
};
}
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
finite_values_only, nan_repr);
};
template <typename Fn, typename Init>
static constexpr auto reduce_member_types(Fn f, Init init) {
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
return dummy_type.reduce_members(f, init);
};
static constexpr auto id_size_bits() {
return reduce_member_types(
[](int acc, auto member) -> int {
return acc + member_id_field_width<decltype(member)>();
},
0);
}
public:
// Unique id for this scalar type that can be computed at compile time for
// C++17 template specialization. This is not needed once we migrate to
// C++20 and can pass literal classes as template parameters.
constexpr Id id() const {
static_assert(id_size_bits() <= sizeof(Id) * 8,
"ScalarType id is too large to be stored");
auto or_and_advance = [](std::pair<Id, uint32_t> result,
auto member) -> std::pair<Id, uint32_t> {
auto [id, bit_offset] = result;
auto constexpr bits = member_id_field_width<decltype(member)>();
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
<< bit_offset,
bit_offset + bits};
};
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
}
// Create a ScalarType from an id, for C++17 template specialization;
// this is not needed once we migrate to C++20 and can pass literal
// classes as template parameters.
static constexpr ScalarType from_id(Id id) {
auto extract_and_advance = [id](auto result, auto member) {
using T = decltype(member);
auto [tuple, bit_offset] = result;
auto constexpr bits = member_id_field_width<T>();
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
((uint64_t(1) << bits) - 1));
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
};
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
std::pair<std::tuple<>, int>{});
return std::apply([](auto... args) { return ScalarType(args...); },
tuple_args);
}
constexpr int64_t size_bits() const {
return mantissa + exponent + is_signed();
}
constexpr bool is_signed() const { return signed_; }
constexpr bool is_integer() const { return exponent == 0; }
constexpr bool is_floating_point() const { return exponent > 0; }
constexpr bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false &&
nan_repr == NAN_IEEE_754;
}
constexpr bool has_nans() const {
return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
constexpr bool has_bias() const { return bias != 0; }
private:
double _floating_point_max() const {
TORCH_CHECK(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
TORCH_CHECK(exponent < 11,
"Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
// e is the number of exponent bits); there is some precedent for
// non-standard biases, e.g. `float8_e4m3b11fnuz` here:
// https://github.com/jax-ml/ml_dtypes, but to avoid premature
// complication we just assume the standard exponent bias until there is
// a need to support non-standard biases
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double =
max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw =
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
constexpr std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()),
            "Cannot represent max as an int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(),
"We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
TORCH_CHECK(!is_signed() || size_bits() <= 64,
"Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
* - if bias is not present it means it is zero
*/
if (is_floating_point()) {
auto ret = "float" + std::to_string(size_bits()) + "_e" +
std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
constexpr bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only &&
nan_repr == other.nan_repr;
}
};
using ScalarTypeId = ScalarType::Id;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE8M0fnu =
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
static inline constexpr auto kFloat16Id = kFloat16.id();
}; // namespace vllm
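// Usage sketch (illustrative only): every ScalarType packs its member fields
// into a single int64_t id at compile time, which is what lets these types be
// carried through C++17 templates. A few compile-time sanity checks on the
// presets above:
static_assert(vllm::kFE4M3fn.id() != vllm::kFE5M2.id());
static_assert(vllm::kU4B8.size_bits() == 4);
static_assert(vllm::kFE5M2.is_floating_point() && vllm::kFE5M2.is_ieee_754());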
#include "cpu_types.hpp"
namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(d % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
int start = i * d;
if constexpr (is_gated) {
start *= 2;
}
const scalar_vec_t x(input + start + j);
const vec_op::FP32Vec8 f32_x(x);
vec_op::FP32Vec8 f32_ans = func(f32_x);
if constexpr (is_gated) {
const scalar_vec_t y(input + start + d + j);
const vec_op::FP32Vec8 f32_y(y);
f32_ans = f32_y * f32_ans;
}
const scalar_vec_t result(f32_ans);
result.save(output + i * d + j);
}
}
}
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 x3 = x * x * x;
const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(1.702f);
return x / (ones + (zeros - w1 * x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5);
const vec_op::FP32Vec8 w3(0.044715);
const vec_op::FP32Vec8 x_3 = x * x * x;
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh());
}
}; // namespace
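// Scalar reference (illustrative only; the example_* helpers are
// hypothetical) of the two most common activations above; the vectorized
// kernels evaluate the same expressions elementwise.
#include <cmath>

inline float example_silu(float x) { return x / (1.0f + std::exp(-x)); }

inline float example_gelu_tanh(float x) {
  // sqrt(2/pi) * (x + 0.044715 * x^3), matching gelu_tanh_act above.
  const float inner =
      float(M_SQRT2 * M_2_SQRTPI * 0.5) * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(inner));
}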
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
activation_kernel<scalar_t, gelu_tanh_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
});
}
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_new_impl)
activation_kernel<scalar_t, gelu_new_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_new_impl)
});
}
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_fast_impl)
activation_kernel<scalar_t, gelu_fast_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_quick_impl)
activation_kernel<scalar_t, gelu_quick_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
});
}
#ifndef CPU_ARCH_MACROS_H
#define CPU_ARCH_MACROS_H
// x86_64
#ifdef __x86_64__
#define FAST_SPINNING _mm_pause();
#ifdef __AVX512F__
#define DEFINE_FAST_EXP \
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \
const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \
const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \
const __m512 vec_exp_log2ef = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \
const __m512 vec_half = _mm512_set1_ps(0.5f); \
const __m512 vec_one = _mm512_set1_ps(1.f); \
const __m512 vec_zero = _mm512_set1_ps(0.f); \
const __m512 vec_two = _mm512_set1_ps(2.f); \
const __m512 vec_ln2f = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \
const __m512 vec_ln_flt_min = \
_mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \
const __m512 vec_ln_flt_max = \
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \
auto vec_fx_i = _mm512_cvt_roundps_epi32( \
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \
vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \
auto vec_res = \
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \
vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \
vec_two_pow_n, vec_zero); \
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \
vec_res = _mm512_mul_ps(vec_res, vec_two); \
vec_op::FP32Vec16 res(vec_res); \
return res; \
};
#endif
#endif
#ifdef __aarch64__
// Implementation copied from Arm Optimized Routines (expf AdvSIMD)
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
#include <limits>
#define DEFINE_FAST_EXP \
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
const float ln2_hi = 0x1.62e4p-1f; \
const float ln2_lo = 0x1.7f7d1cp-20f; \
const float c0 = 0x1.0e4020p-7f; \
const float c2 = 0x1.555e66p-3f; \
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \
const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \
const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \
const float32x4_t inf = \
vdupq_n_f32(std::numeric_limits<float>::infinity()); \
const float32x4_t zero = vdupq_n_f32(0.0f); \
auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \
r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \
float32x4_t r2 = vmulq_f32(r, r); \
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \
q = vfmaq_f32(q, p, r2); \
p = vmulq_f32(c4, r); \
float32x4_t poly = vfmaq_f32(p, q, r2); \
poly = vfmaq_f32(scale, poly, scale); \
const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \
const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
auto fast_exp = [&](const vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
result.val[1] = neon_expf(vec.reg.val[1]); \
result.val[2] = neon_expf(vec.reg.val[2]); \
result.val[3] = neon_expf(vec.reg.val[3]); \
return vec_op::FP32Vec16(result); \
};
#endif // __aarch64__
#endif
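// Scalar reference (illustrative sketch; example_fast_exp is hypothetical and
// skips the clamping/masking the macros above perform) of the shared
// range-reduction scheme: write x = n*ln2 + r with |r| <= ln2/2, approximate
// e^r with a short polynomial, and rebuild 2^n by placing n in the exponent
// bits of a float. Valid roughly for |x| < 87.
#include <cmath>
#include <cstdint>
#include <cstring>

inline float example_fast_exp(float x) {
  const float n = std::nearbyint(x * 1.4426950409f);  // round(x / ln(2))
  const float r = x - n * 0.6931471806f;              // remainder
  // Short polynomial for e^r (Taylor up to r^4; illustrative accuracy only).
  const float p =
      1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6.0f + r * (1.0f / 24.0f))));
  const uint32_t bits =
      static_cast<uint32_t>(static_cast<int32_t>(n) + 127) << 23;
  float two_pow_n;
  std::memcpy(&two_pow_n, &bits, sizeof(two_pow_n));  // 2^n as a float
  return p * two_pow_n;
}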
#include "cpu_attn_dispatch_generated.h"
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split) {
cpu_attention::ISA isa;
if (isa_hint == "amx") {
isa = cpu_attention::ISA::AMX;
} else if (isa_hint == "vec") {
isa = cpu_attention::ISA::VEC;
} else if (isa_hint == "vec16") {
isa = cpu_attention::ISA::VEC16;
} else if (isa_hint == "neon") {
isa = cpu_attention::ISA::NEON;
} else if (isa_hint == "vxe") {
isa = cpu_attention::ISA::VXE;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
cpu_attention::AttentionScheduler::ScheduleInput input;
input.num_reqs = num_req;
input.num_heads_q = num_heads_q;
input.num_heads_kv = num_heads_kv;
input.head_dim = head_dim;
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
if (window_size != -1) {
input.left_sliding_window_size = window_size - 1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = window_size - 1;
}
} else {
input.left_sliding_window_size = -1;
if (casual) {
input.right_sliding_window_size = 0;
} else {
input.right_sliding_window_size = -1;
}
}
input.casual = casual;
input.isa = isa;
input.enable_kv_split = enable_kv_split;
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
input.output_buffer_elem_size =
sizeof(attn_impl::partial_output_buffer_t);
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
});
});
cpu_attention::AttentionScheduler scheduler;
torch::Tensor metadata = scheduler.schedule(input);
return metadata;
}
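// Compact sketch (illustrative only; example_window_to_left_right is
// hypothetical) of the window_size -> (left, right) sliding-window mapping
// performed above, where -1 means "unlimited" on that side:
#include <cstdint>
#include <utility>

inline std::pair<int64_t, int64_t> example_window_to_left_right(
    int64_t window_size, bool causal) {
  const int64_t left = window_size == -1 ? -1 : window_size - 1;
  const int64_t right = causal ? 0 : left;
  return {left, right};
}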
void cpu_attn_reshape_and_cache(
const torch::Tensor& key, // [token_num, head_num, head_size]
const torch::Tensor& value, // [token_num, head_num, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor& slot_mapping, const std::string& isa) {
TORCH_CHECK_EQ(key.dim(), 3);
TORCH_CHECK_EQ(value.dim(), 3);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
TORCH_CHECK_EQ(key.stride(2), 1);
TORCH_CHECK_EQ(value.stride(2), 1);
const int64_t token_num = key.size(0);
const int64_t key_token_num_stride = key.stride(0);
const int64_t value_token_num_stride = value.stride(0);
const int64_t head_num = value.size(1);
const int64_t key_head_num_stride = key.stride(1);
const int64_t value_head_num_stride = value.stride(1);
const int64_t num_blocks = key_cache.size(0);
const int64_t num_blocks_stride = key_cache.stride(0);
const int64_t cache_head_num_stride = key_cache.stride(1);
const int64_t block_size = key_cache.size(2);
const int64_t block_size_stride = key_cache.stride(2);
const int64_t head_dim = key.size(-1);
cpu_attention::ISA isa_tag = [&]() {
if (isa == "amx") {
return cpu_attention::ISA::AMX;
} else if (isa == "vec") {
return cpu_attention::ISA::VEC;
} else if (isa == "vec16") {
return cpu_attention::ISA::VEC16;
} else if (isa == "neon") {
return cpu_attention::ISA::NEON;
} else if (isa == "vxe") {
return cpu_attention::ISA::VXE;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}
}();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
value_token_num_stride, head_num, key_head_num_stride,
value_head_num_stride, num_blocks, num_blocks_stride,
cache_head_num_stride, block_size, block_size_stride);
});
});
}
void cpu_attention_with_kv_cache(
const torch::Tensor& query, // [num_tokens, num_heads, head_size]
const torch::Tensor&
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor& output, // [num_tokens, num_heads, head_size]
const torch::Tensor& query_start_loc, // [num_tokens + 1]
const torch::Tensor& seq_lens, // [num_tokens]
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes, // [num_heads]
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, // [num_tokens, max_block_num]
const double softcap, const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux // [num_heads]
) {
TORCH_CHECK_EQ(query.dim(), 3);
TORCH_CHECK_EQ(query.stride(2), 1);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
cpu_attention::AttentionInput input;
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
scheduler_metadata.data_ptr());
input.num_tokens = query.size(0);
input.num_heads = query.size(1);
input.num_kv_heads = key_cache.size(1);
input.block_size = key_cache.size(2);
input.query = query.data_ptr();
input.query_num_tokens_stride = query.stride(0);
input.query_num_heads_stride = query.stride(1);
input.cache_num_blocks_stride = key_cache.stride(0);
input.cache_num_kv_heads_stride = key_cache.stride(1);
input.blt_num_tokens_stride = block_table.stride(0);
input.key_cache = key_cache.data_ptr();
input.value_cache = value_cache.data_ptr();
input.output = output.data_ptr();
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
input.seq_lens = seq_lens.data_ptr<int32_t>();
input.block_table = block_table.data_ptr<int32_t>();
input.alibi_slopes =
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
// For now sink must be bf16
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
input.scale = scale;
input.causal = causal;
input.sliding_window_left = sliding_window_left;
input.sliding_window_right = sliding_window_right;
if (input.causal) {
// to make boundary calculation easier
input.sliding_window_right = 0;
}
float softcap_fp32 = softcap;
input.softcap = softcap_fp32;
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
});
}
#ifndef CPU_ATTN_AMX_HPP
#define CPU_ATTN_AMX_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// AMX specific
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
typedef struct __tile_config {
uint8_t palette_id = 1;
uint8_t start_row = 0;
uint8_t reserved_0[14] = {0};
uint16_t colsb[16] = {0};
uint8_t rows[16] = {0};
} __tilecfg;
// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16,
// m - 16
template <typename kv_cache_t>
class TileGemm224 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
void* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
}
};
template <>
class TileGemm224<c10::BFloat16> {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
c10::BFloat16* __restrict__ a_tile,
c10::BFloat16* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
const int32_t k_times =
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
const int64_t a_tile_stride = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return AMX_TILE_ROW_BYTES;
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return lda * sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// k_cache is prepacked
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// v_cache is prepacked
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
// k_cache, v_cache are prepacked
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
// logits_buffer, output_buffer are not prepacked
float* __restrict__ c_tile_4 = c_tile;
float* __restrict__ c_tile_5 =
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc;
float* __restrict__ c_tile_7 =
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
const int32_t c_tile_stride = ldc * sizeof(float);
if (accum_c) {
_tile_loadd(4, c_tile_4, c_tile_stride);
_tile_loadd(5, c_tile_5, c_tile_stride);
_tile_loadd(6, c_tile_6, c_tile_stride);
_tile_loadd(7, c_tile_7, c_tile_stride);
} else {
_tile_zero(4);
_tile_zero(5);
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
_tile_dpbf16ps(4, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
_tile_dpbf16ps(5, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_dpbf16ps(6, 1, 2);
_tile_dpbf16ps(7, 1, 3);
// update ptrs
if constexpr (phase == AttentionGemmPhase::QK) {
// Q buffer is prepacked
a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// P buffer is not prepacked
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
_tile_stored(4, c_tile_4, c_tile_stride);
_tile_stored(5, c_tile_5, c_tile_stride);
_tile_stored(6, c_tile_6, c_tile_stride);
_tile_stored(7, c_tile_7, c_tile_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
const int32_t m_0 = AMX_TILE_ROW_NUM;
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
config.rows[0] = m_0;
config.rows[1] = m_1;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = m_0;
config.rows[5] = m_0;
config.rows[6] = m_1;
config.rows[7] = m_1;
_tile_loadconfig(&config);
}
};
// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use 1 extra tile for prefetch, row num should
// be m, m
// TILE 2, 3, (4, 5): load B matrix, use 2 extra tiles for prefetch, row num
// should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be m
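// Illustrative schedule (derived from the gemm below): each loop iteration
// consumes two K steps, TILE0 @ TILE2/TILE3 followed by TILE1 @ TILE4/TILE5,
// both accumulating into TILE6/TILE7; an odd trailing K step is handled by
// the has_tail branch.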
template <typename kv_cache_t>
class TileGemm122 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
void* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
}
};
template <>
class TileGemm122<c10::BFloat16> {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
c10::BFloat16* __restrict__ a_tile,
c10::BFloat16* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
c10::BFloat16* __restrict__ a_tile_1 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
const int64_t a_tile_stride = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// q_buffer is prepacked
return AMX_TILE_ROW_BYTES;
} else if constexpr (phase == AttentionGemmPhase::PV) {
// logits_buffer is row-major
return lda * sizeof(c10::BFloat16);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
if constexpr (phase == AttentionGemmPhase::QK) {
// k_cache is prepacked
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// v_cache is prepacked
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
} else {
TORCH_CHECK(false, "Unreachable");
}
}();
c10::BFloat16* __restrict__ b_tile_4 =
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_5 =
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
int64_t b_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_6 = c_tile;
float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float);
int64_t c_stride = ldc * sizeof(float);
const int32_t k_times =
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
const int32_t k_group_times = k_times / 2;
const bool has_tail = (k_times % 2 == 1);
if (accum_c) {
_tile_loadd(6, c_tile_6, c_stride);
_tile_loadd(7, c_tile_7, c_stride);
} else {
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_group_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_stream_loadd(4, b_tile_4, b_stride);
_tile_dpbf16ps(6, 1, 4);
_tile_stream_loadd(5, b_tile_5, b_stride);
_tile_dpbf16ps(7, 1, 5);
// update ptrs
if constexpr (phase == AttentionGemmPhase::QK) {
// Q buffer is prepacked
a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
} else if constexpr (phase == AttentionGemmPhase::PV) {
// P buffer is not prepacked
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
}
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
if (has_tail) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
}
_tile_stored(6, c_tile_6, c_stride);
_tile_stored(7, c_tile_7, c_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
config.rows[0] = m;
config.rows[1] = m;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = AMX_TILE_ROW_NUM;
config.rows[5] = AMX_TILE_ROW_NUM;
config.rows[6] = m;
config.rows[7] = m;
_tile_loadconfig(&config);
}
};
} // namespace
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = scalar_t;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = scalar_t;
constexpr static int64_t BlockSizeAlignment =
AMX_TILE_ROW_BYTES /
sizeof(kv_cache_t); // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase
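// For a bf16 KV cache these evaluate to BlockSizeAlignment = 64 / 2 = 32
// tokens and HeadDimAlignment = 2 * (64 / 4) = 32 head-dim elements.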
constexpr static int64_t MaxQHeadNumPerIteration = 32;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::AMX;
constexpr static bool scale_on_logits = true;
public:
AttentionImpl() : current_q_head_num_(0) {
// Use all columns in AMX tiles
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
}
~AttentionImpl() { _tile_release(); }
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
if (q_head_num > AMX_TILE_ROW_NUM) {
if (q_head_num != current_q_head_num_) {
current_q_head_num_ = q_head_num;
TileGemm224<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
}
attention<TileGemm224<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
} else {
if (q_head_num != current_q_head_num_) {
current_q_head_num_ = q_head_num;
TileGemm122<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
}
attention<TileGemm122<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
}
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment * head_dim;
}
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
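// The V cache is packed per (block, head) as [head_dim / 16 groups]
// [block_size tokens][16 head elems] (see reshape_and_cache below), so
// advancing BlockSizeAlignment tokens within a group moves
// BlockSizeAlignment * (AMX_TILE_ROW_BYTES / 4) elements.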
return BlockSizeAlignment * (AMX_TILE_ROW_BYTES / 4);
}
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return block_size * HeadDimAlignment;
}
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
scalar_t* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, const float scale) {
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
constexpr int64_t head_elem_num_pre_block =
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
int32_t idx = 0;
int8_t* __restrict__ q_buffer_iter = reinterpret_cast<int8_t*>(q_buffer);
for (int32_t q_num_idx = 0; q_num_idx < q_num;
++q_num_idx, src += q_num_stride) {
scalar_t* __restrict__ src_iter = src;
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv;
++q_head_idx, src_iter += q_head_stride) {
vec_op::unroll_loop<int32_t, head_size_block_num>(
[&](int32_t head_size_block_idx) {
// Use INT8Vec64 for 64 bytes block
vec_op::INT8Vec64 vec(src_iter + head_size_block_idx *
head_elem_num_pre_block);
vec.save(q_buffer_iter + head_size_block_idx * AMX_TILE_BYTES);
});
++idx;
q_buffer_iter += AMX_TILE_ROW_BYTES;
if ((idx & (AMX_TILE_ROW_NUM - 1)) == 0) {
// head is in another amx tile
q_buffer_iter -= AMX_TILE_ROW_NUM * AMX_TILE_ROW_BYTES;
q_buffer_iter += head_size_block_num * AMX_TILE_BYTES;
}
}
}
}
// reshape KV to AMX friendly layout
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
// For AMX 2D tiles, size of each line is 64 bytes
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
// For AMX B matrix, N is always 16
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
// For now suppose block_size is divisible by amx_b_tile_k_size
TORCH_CHECK_EQ(block_size % amx_b_tile_k_size, 0);
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key
// Head elements are packed into 32-bit quadwords and stored in token
// groups of (amx_tile_row_size / 4) tokens
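// e.g. for a bf16 cache: two consecutive head elements form one quadword, a
// group holds 64 / 4 = 16 tokens, and the per-(block, head) layout is
// [group][quadword_idx][token_in_group].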
constexpr int64_t token_num_per_group = amx_tile_row_size / 4;
static_assert(head_dim % (4 / sizeof(scalar_t)) == 0);
constexpr int64_t quadword_num = head_dim / (4 / sizeof(scalar_t));
const int32_t* key_start_quadword_ptr =
reinterpret_cast<const int32_t*>(
key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride);
const int64_t group_idx = block_offset / token_num_per_group;
const int64_t group_offset = block_offset % token_num_per_group;
constexpr int64_t quadword_num_per_group =
token_num_per_group * quadword_num;
int32_t* key_cache_start_ptr =
reinterpret_cast<int32_t*>(key_cache +
block_idx * num_blocks_stride +
head_idx * cache_head_num_stride) +
group_idx * quadword_num_per_group + group_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; j < quadword_num;
i += token_num_per_group, ++j) {
key_cache_start_ptr[i] = key_start_quadword_ptr[j];
}
}
{
// Write Value
// Different from Key, the block_size dimension is packed into quadwords
// rather than the head_size dimension;
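// e.g. for bf16: two consecutive tokens along the block_size dimension form
// one quadword, and each group covers 16 head elements, giving a
// per-(block, head) layout of [head_dim / 16 groups][token pair]
// [16 head elems][token-in-pair].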
constexpr int64_t token_num_per_sub_group = 4 / sizeof(scalar_t);
const int64_t token_num_per_group = block_size;
constexpr int64_t head_elems_per_group = amx_b_tile_n_size;
const int64_t group_size = token_num_per_group * head_elems_per_group;
// For now suppose head_dim is divisible by amx_b_tile_n_size
static_assert(head_dim % head_elems_per_group == 0);
constexpr int64_t group_num = head_dim / head_elems_per_group;
const int64_t sub_group_idx = block_offset / token_num_per_sub_group;
const int64_t sub_group_offset =
block_offset % token_num_per_sub_group;
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride +
sub_group_idx * token_num_per_sub_group * amx_b_tile_n_size +
sub_group_offset;
for (int64_t i = 0; i < group_num; ++i) {
#pragma GCC unroll head_elems_per_group
for (int64_t j = 0, k = 0; j < head_elems_per_group;
++j, k += token_num_per_sub_group) {
value_cache_start_ptr[k] = value_start_ptr[j];
}
value_start_ptr += head_elems_per_group;
value_cache_start_ptr += group_size;
}
}
}
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t current_q_head_num_;
};
} // namespace cpu_attention
#endif
#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP
#include <type_traits>
#include <cstddef>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu/cpu_arch_macros.h"
#include "cpu/utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON, VXE };
template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};
struct AttentionWorkItemGroup {
int32_t req_id;
int32_t q_token_id_start;
int32_t q_token_num;
int32_t kv_split_pos_start;
int32_t kv_split_pos_end;
int64_t total_kv_len;
int32_t split_id;
int32_t local_split_id;
AttentionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
const int32_t kv_split_pos_start,
const int32_t kv_split_pos_end)
: req_id(req_id),
q_token_id_start(q_token_id_start),
q_token_num(0),
kv_split_pos_start(kv_split_pos_start),
kv_split_pos_end(kv_split_pos_end),
total_kv_len(0),
split_id(-1),
local_split_id(0) {}
std::string to_string() const {
std::stringstream ss;
ss << '[' << "req_id: " << req_id << ",\n";
ss << "q_token_id_start: " << q_token_id_start << ",\n";
ss << "q_token_num: " << q_token_num << ",\n";
ss << "kv_split_pos_start: " << kv_split_pos_start << ",\n";
ss << "kv_split_pos_end: " << kv_split_pos_end << ",\n";
ss << "total_kv_len: " << total_kv_len << ",\n";
ss << "split_id: " << split_id << ",\n";
ss << "local_split_id: " << local_split_id << ",\n";
ss << ']';
return ss.str();
}
};
struct ReductionWorkItemGroup {
int32_t req_id;
int32_t q_token_id_start;
int32_t q_token_id_num;
int32_t split_start_id;
int32_t split_num;
ReductionWorkItemGroup(const int32_t req_id, const int32_t q_token_id_start,
const int32_t q_token_id_num,
const int32_t split_start_id)
: req_id(req_id),
q_token_id_start(q_token_id_start),
q_token_id_num(q_token_id_num),
split_start_id(split_start_id),
split_num(0) {}
std::string to_string() const {
std::stringstream ss;
ss << '[' << "req_id: " << req_id << ",\n";
ss << "q_token_id_start: " << q_token_id_start << ",\n";
ss << "q_token_id_num: " << q_token_id_num << ",\n";
ss << "split_start_id: " << split_start_id << ",\n";
ss << "split_num: " << split_num << ",\n";
ss << ']';
return ss.str();
}
};
struct AttentionMetadata {
std::atomic_int64_t counter;
char _padding1[56];
ISA isa;
int32_t workitem_group_num;
int32_t reduction_item_num;
int32_t reduction_split_num;
int32_t thread_num;
int32_t effective_thread_num; // num of threads with non-zero workitem_num_per_thread entries
int32_t split_kv_q_token_num_threshold;
int64_t attention_scratchpad_size_per_thread;
int64_t reduction_scratchpad_size_per_kv_head;
AttentionWorkItemGroup* workitem_groups_ptr;
ReductionWorkItemGroup* reduction_items_ptr;
int32_t cu_workitem_num_per_thread[1025] = {
0}; // prefix sum of workitem_num_per_thread
char _padding2[56];
AttentionMetadata(ISA isa, int32_t workitem_group_num,
int32_t reduction_item_num, int32_t reduction_split_num,
int32_t split_kv_q_token_num_threshold)
: isa(isa),
workitem_group_num(workitem_group_num),
reduction_item_num(reduction_item_num),
reduction_split_num(reduction_split_num),
thread_num(omp_get_max_threads()),
effective_thread_num(thread_num),
split_kv_q_token_num_threshold(split_kv_q_token_num_threshold),
attention_scratchpad_size_per_thread(0),
reduction_scratchpad_size_per_kv_head(0),
workitem_groups_ptr(
(AttentionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata))),
reduction_items_ptr(
(ReductionWorkItemGroup*)((char*)this + sizeof(AttentionMetadata) +
workitem_group_num *
sizeof(AttentionWorkItemGroup))),
counter(0) {
TORCH_CHECK_LE(thread_num, 1024);
static_assert(sizeof(AttentionMetadata) % 64 == 0);
TORCH_CHECK(reinterpret_cast<size_t>(this) % 64 == 0);
}
void reset_counter() { counter.store(0); }
int64_t acquire_counter() { return counter++; }
void print() const {
std::stringstream ss;
ss << "ISA: ";
switch (isa) {
case ISA::AMX:
ss << "AMX, ";
break;
case ISA::VEC:
ss << "VEC, ";
break;
case ISA::VEC16:
ss << "VEC16, ";
break;
case ISA::NEON:
ss << "NEON, ";
break;
case ISA::VXE:
ss << "VXE, ";
break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
<< ", reduction_split_num: " << reduction_split_num
<< ", thread_num: " << thread_num
<< ", effective_thread_num: " << effective_thread_num
<< ", attention_scratchpad_size_per_thread: "
<< attention_scratchpad_size_per_thread
<< ", reduction_scratchpad_size_per_kv_head: "
<< reduction_scratchpad_size_per_kv_head << ", workitem groups:\n";
for (int32_t i = 0; i < workitem_group_num; ++i) {
ss << (workitem_groups_ptr + i)->to_string() << ",\n";
}
ss << "cu_workitem_num_per_thread: [";
for (int32_t i = 0; i < thread_num + 1; ++i) {
ss << cu_workitem_num_per_thread[i] << ", ";
}
ss << "]\n";
ss << "reduction items: \n";
for (int32_t i = 0; i < reduction_item_num; ++i) {
ss << (reduction_items_ptr + i)->to_string() << ",\n";
}
std::printf("%s", ss.str().c_str());
}
};
// Thread attention scratchpad contains:
// - Q: q_tile_size * head_dim * q_buffer_elem_size, gather Q heads, especially
// for GQA
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size, logits
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
// * q_tile_size * 4, partial output, max + sum (float)
// Reduction scratchpad contains:
// - flags: bool array to indicate whether the split is finished
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
// - max, sum: 2 * split_num * q_tile_size * 4
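// Illustrative sizes (assuming head_dim = 128, bf16 Q buffer, q_tile_size =
// max_num_q_per_iter = 32, kv_tile_size = 512): Q = 32 * 128 * 2 = 8 KiB,
// Q@K^T = 32 * 512 * 4 = 64 KiB, partial outputs = 32 * 128 * 4 = 16 KiB,
// max + sum = 2 * 32 * 4 = 256 B.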
class AttentionScratchPad {
public:
AttentionScratchPad(int64_t thread_id,
const AttentionMetadata& attention_metadata,
void* scratchpad_ptr)
: thread_scratchpad_ptr(
static_cast<int8_t*>(scratchpad_ptr) +
thread_id *
attention_metadata.attention_scratchpad_size_per_thread),
reduction_scratchpad_ptr(
static_cast<int8_t*>(scratchpad_ptr) +
attention_metadata.thread_num *
attention_metadata.attention_scratchpad_size_per_thread),
reduction_scratchpad_size_per_kv_head(
attention_metadata.reduction_scratchpad_size_per_kv_head) {}
// for attention
void update(const int64_t head_dim, const int64_t q_buffer_elem_size,
const int64_t logits_buffer_elem_size,
const int64_t output_buffer_elem_size,
const int64_t max_num_q_per_iter, const int64_t q_head_tile_size,
const int64_t kv_tile_size) {
int64_t buffer_offset = 0;
q_buffer_offset_ = buffer_offset;
buffer_offset +=
calcu_q_buffer_size(q_head_tile_size, head_dim, q_buffer_elem_size);
logits_buffer_offset_ = buffer_offset;
buffer_offset += calcu_logits_buffer_size(max_num_q_per_iter, kv_tile_size,
logits_buffer_elem_size);
output_buffer_offset_ = buffer_offset;
buffer_offset += calcu_partial_output_buffer_size(
q_head_tile_size, head_dim, output_buffer_elem_size);
max_buffer_offset_ = buffer_offset;
buffer_offset += calcu_partial_output_max_sum_buffer_size(q_head_tile_size);
sum_buffer_offset_ = buffer_offset;
}
// for reduction
void update(const int32_t kv_head_idx, const int32_t total_split_num,
const int64_t head_dim, const int64_t q_head_tile_size,
const int64_t output_buffer_elem_size) {
int64_t buffer_offset = kv_head_idx * reduction_scratchpad_size_per_kv_head;
reduce_flag_buffer_offset_ = buffer_offset;
buffer_offset += calcu_reduce_flag_buffer_size(total_split_num);
reduce_output_buffer_offset_ = buffer_offset;
buffer_offset += calcu_reduce_output_buffer_size(
total_split_num, q_head_tile_size, head_dim, output_buffer_elem_size);
reduce_max_buffer_offset_ = buffer_offset;
buffer_offset +=
calcu_reduce_max_sum_buffer_size(total_split_num, q_head_tile_size);
reduce_sum_buffer_offset_ = buffer_offset;
}
template <typename T>
T* get_q_buffer() {
return reinterpret_cast<T*>(thread_scratchpad_ptr + q_buffer_offset_);
}
float* get_logits_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr +
logits_buffer_offset_);
}
float* get_output_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr +
output_buffer_offset_);
}
float* get_max_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr + max_buffer_offset_);
}
float* get_sum_buffer() {
return reinterpret_cast<float*>(thread_scratchpad_ptr + sum_buffer_offset_);
}
volatile bool* get_reduce_flag_buffer() {
return reinterpret_cast<volatile bool*>(reduction_scratchpad_ptr +
reduce_flag_buffer_offset_);
}
float* get_reduce_output_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_output_buffer_offset_);
}
float* get_reduce_max_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_max_buffer_offset_);
}
float* get_reduce_sum_buffer() {
return reinterpret_cast<float*>(reduction_scratchpad_ptr +
reduce_sum_buffer_offset_);
}
int64_t get_thread_scratchpad_size() const {
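// The sum buffer is laid out last and has the same size as the max buffer,
// so the total size is sum_offset + (sum_offset - max_offset).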
return 2 * sum_buffer_offset_ - max_buffer_offset_;
}
int64_t get_reduction_scratchpad_size() const {
return 2 * reduce_sum_buffer_offset_ - reduce_max_buffer_offset_;
}
private:
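// Round a byte count up to the next multiple of 64 so each buffer starts
// cache-line aligned.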
static int64_t round_to_64(const int64_t num) {
return ((num + 63) >> 6) << 6;
}
static int64_t calcu_q_buffer_size(const int64_t q_tile_size,
const int64_t head_dim,
const int64_t elem_size) {
return round_to_64(q_tile_size * head_dim * elem_size);
}
static int64_t calcu_logits_buffer_size(const int64_t max_num_q_per_iter,
const int64_t k_tile_size,
const int64_t elem_size) {
return round_to_64(elem_size * max_num_q_per_iter * k_tile_size);
}
static int64_t calcu_partial_output_buffer_size(const int64_t q_tile_size,
const int64_t head_dim,
const int64_t elem_size) {
return round_to_64(q_tile_size * head_dim * elem_size);
}
static int64_t calcu_partial_output_max_sum_buffer_size(
const int64_t q_tile_size) {
return round_to_64(q_tile_size * sizeof(float));
}
static int64_t calcu_reduce_flag_buffer_size(const int64_t total_split_num) {
return round_to_64(total_split_num * sizeof(bool));
}
static int64_t calcu_reduce_max_sum_buffer_size(
const int64_t total_split_num, const int32_t q_head_tile_size) {
return round_to_64(total_split_num * q_head_tile_size * sizeof(float));
}
static int64_t calcu_reduce_output_buffer_size(
const int64_t total_split_num, const int64_t q_head_tile_size,
const int64_t head_dim, const int64_t output_buffer_elem_size) {
return round_to_64(total_split_num * q_head_tile_size * head_dim *
output_buffer_elem_size);
}
private:
int8_t* thread_scratchpad_ptr;
int8_t* reduction_scratchpad_ptr;
int64_t reduction_scratchpad_size_per_kv_head;
// attention buffers
int64_t q_buffer_offset_;
int64_t logits_buffer_offset_;
int64_t output_buffer_offset_;
int64_t max_buffer_offset_;
int64_t sum_buffer_offset_;
// reduction buffers
int64_t reduce_flag_buffer_offset_;
int64_t reduce_output_buffer_offset_;
int64_t reduce_max_buffer_offset_;
int64_t reduce_sum_buffer_offset_;
};
class AttentionScheduler {
public:
struct ScheduleInput {
int32_t num_reqs;
int32_t elem_size;
int32_t q_buffer_elem_size;
int32_t logits_buffer_elem_size;
int32_t output_buffer_elem_size;
int32_t num_heads_q;
int32_t num_heads_kv;
int32_t head_dim;
int32_t* query_start_loc;
int32_t* seq_lens;
int32_t left_sliding_window_size;
int32_t right_sliding_window_size;
bool casual;
cpu_attention::ISA isa;
int32_t max_num_q_per_iter; // max Q head num that can be held in registers
int32_t kv_block_alignment; // context length alignment requirement
bool enable_kv_split;
};
static constexpr int32_t MaxQTileIterNum = 128;
AttentionScheduler()
: available_cache_size_(cpu_utils::get_available_l2_size()) {}
torch::Tensor schedule(const ScheduleInput& input) const {
const bool casual = input.casual;
const int32_t thread_num = omp_get_max_threads();
const int64_t cache_size = cpu_utils::get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment;
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0);
if (!use_gqa) {
q_head_per_kv = 1; // fallback to MHA
}
const int32_t min_split_kv_len =
((max_num_q_per_iter * 4 + kv_len_alignment - 1) / kv_len_alignment) *
kv_len_alignment;
const int32_t max_num_q_token_per_iter = max_num_q_per_iter / q_head_per_kv;
const int64_t default_tile_size = calcu_default_tile_size(
cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, max_num_q_per_iter);
const int32_t default_tile_token_num = default_tile_size / q_head_per_kv;
const int32_t split_kv_q_token_num_threshold =
input.enable_kv_split ? 1 : 0;
const int32_t left_sliding_window_size = input.left_sliding_window_size;
const int32_t right_sliding_window_size = input.right_sliding_window_size;
TORCH_CHECK_LE(split_kv_q_token_num_threshold * q_head_per_kv, 16);
// get total kv len
int64_t total_kv_len = 0;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t seq_len = input.seq_lens[req_id];
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
const int32_t kv_start_pos = 0;
const int32_t kv_end_pos = seq_len;
for (int32_t token_id = 0; token_id < q_token_num;
token_id += max_num_q_token_per_iter) {
const int32_t q_tile_token_num =
std::min(max_num_q_token_per_iter, q_token_num - token_id);
const int32_t q_tile_pos_left = q_start_pos + token_id;
const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
left_sliding_window_size, right_sliding_window_size);
const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
kv_len_alignment);
int32_t curr_kv_len =
aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
total_kv_len += curr_kv_len;
}
}
const int64_t kv_len_per_thread =
(((total_kv_len / thread_num) + kv_len_alignment - 1) /
kv_len_alignment) *
kv_len_alignment * (use_gqa ? input.num_heads_kv : input.num_heads_q);
std::vector<AttentionWorkItemGroup> workitems;
std::vector<ReductionWorkItemGroup> reduce_workitems;
workitems.reserve(1024);
reduce_workitems.reserve(1024);
std::vector<int32_t> workitem_num_per_thread(thread_num, 0);
// split tasks
int32_t curr_thread_id = 0;
int64_t remaining_kv_len = kv_len_per_thread;
int32_t cum_split_num = 0;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t seq_len = input.seq_lens[req_id];
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
const int32_t q_start_pos = (casual ? (seq_len - q_token_num) : 0);
const int32_t kv_start_pos = 0;
const int32_t kv_end_pos = seq_len;
int32_t local_split_id = 0;
AttentionWorkItemGroup curr_workitem(req_id, 0, 0, seq_len);
for (int32_t token_id = 0; token_id < q_token_num;
token_id += max_num_q_token_per_iter) {
const int32_t q_tile_token_num =
std::min(max_num_q_token_per_iter, q_token_num - token_id);
const int32_t q_tile_pos_left = q_start_pos + token_id;
const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num;
const auto [kv_tile_pos_left, kv_tile_pos_right] = calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_pos_left, q_tile_pos_right,
left_sliding_window_size, right_sliding_window_size);
const auto [aligned_kv_tile_pos_left, aligned_kv_tile_pos_right] =
align_kv_tile_pos(kv_tile_pos_left, kv_tile_pos_right,
kv_len_alignment);
int32_t curr_kv_len =
aligned_kv_tile_pos_right - aligned_kv_tile_pos_left;
int32_t kv_token_pos_start = aligned_kv_tile_pos_left;
while (curr_kv_len > 0) {
if (curr_kv_len <= (remaining_kv_len + min_split_kv_len) ||
curr_thread_id == (thread_num - 1)) {
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += curr_kv_len;
remaining_kv_len -= curr_kv_len;
curr_kv_len = 0;
if (remaining_kv_len < 0) {
// stop accepting more workitems
remaining_kv_len -= min_split_kv_len;
}
if (curr_workitem.kv_split_pos_start != 0) {
// got a partial kv split, need to create a single workitem
curr_workitem.split_id = cum_split_num;
curr_workitem.local_split_id = local_split_id;
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
++reduce_workitems.back().split_num;
++cum_split_num;
curr_workitem = AttentionWorkItemGroup(
req_id, token_id + max_num_q_token_per_iter, 0, seq_len);
}
break;
}
if (remaining_kv_len < min_split_kv_len &&
(curr_workitem.total_kv_len > 0 ||
workitem_num_per_thread[curr_thread_id] > 0)) {
// remaining_kv_len is too short and workitems have already been
// allocated, so leave the rest to the next thread
if (curr_workitem.total_kv_len > 0) {
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
}
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
// retry this iteration
continue;
}
// only split tail splits with q_tile_token_num <=
// split_kv_q_token_num_threshold
if (token_id + max_num_q_token_per_iter < q_token_num ||
q_tile_token_num > split_kv_q_token_num_threshold) {
// if a new q tile iteration is required and workitems already exist,
// leave this workitem to the next thread
if (curr_workitem.q_token_num % default_tile_token_num == 0 &&
(curr_workitem.total_kv_len > 0 ||
workitem_num_per_thread[curr_thread_id] > 0)) {
if (curr_workitem.total_kv_len > 0) {
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, 0, seq_len);
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
}
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += curr_kv_len;
remaining_kv_len -= curr_kv_len;
curr_kv_len = 0;
break;
}
// split kv
if (curr_workitem.total_kv_len > 0) {
// write back curr workitem
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
if (kv_token_pos_start == aligned_kv_tile_pos_left) {
// first split, init the workitem
reduce_workitems.emplace_back(ReductionWorkItemGroup(
req_id, token_id, q_tile_token_num, cum_split_num));
}
int32_t spilt_size =
std::min(std::max(remaining_kv_len, (int64_t)min_split_kv_len),
(int64_t)curr_kv_len);
curr_workitem =
AttentionWorkItemGroup(req_id, token_id, kv_token_pos_start,
kv_token_pos_start + spilt_size);
curr_workitem.q_token_num += q_tile_token_num;
curr_workitem.total_kv_len += spilt_size;
curr_workitem.split_id = cum_split_num;
curr_workitem.local_split_id = local_split_id;
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
++reduce_workitems.back().split_num;
++cum_split_num;
++local_split_id;
kv_token_pos_start += spilt_size;
curr_kv_len -= spilt_size;
curr_workitem = AttentionWorkItemGroup(req_id, token_id,
kv_token_pos_start, seq_len);
// switch to next thread
++curr_thread_id;
remaining_kv_len = kv_len_per_thread;
}
}
if (curr_workitem.total_kv_len > 0) {
// write back curr workitem
workitems.emplace_back(curr_workitem);
++workitem_num_per_thread[curr_thread_id];
}
}
int64_t metadata_tensor_size =
sizeof(AttentionMetadata) +
workitems.size() * sizeof(AttentionWorkItemGroup) +
reduce_workitems.size() * sizeof(ReductionWorkItemGroup);
auto options =
torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
torch::Tensor metadata_tensor =
torch::empty({metadata_tensor_size}, options);
AttentionMetadata* metadata_ptr = new (metadata_tensor.data_ptr())
AttentionMetadata(input.isa, workitems.size(), reduce_workitems.size(),
cum_split_num, split_kv_q_token_num_threshold);
AttentionWorkItemGroup* workitem_groups_ptr =
metadata_ptr->workitem_groups_ptr;
ReductionWorkItemGroup* reduction_items_ptr =
metadata_ptr->reduction_items_ptr;
std::memcpy(workitem_groups_ptr, workitems.data(),
workitems.size() * sizeof(AttentionWorkItemGroup));
std::memcpy(reduction_items_ptr, reduce_workitems.data(),
reduce_workitems.size() * sizeof(ReductionWorkItemGroup));
int32_t effective_thread_num = 0;
for (; effective_thread_num < thread_num; ++effective_thread_num) {
if (workitem_num_per_thread[effective_thread_num] == 0) {
break;
}
}
std::memcpy(metadata_ptr->cu_workitem_num_per_thread + 1,
workitem_num_per_thread.data(),
workitem_num_per_thread.size() * sizeof(int32_t));
for (int32_t i = 1; i <= thread_num; ++i) {
metadata_ptr->cu_workitem_num_per_thread[i] +=
metadata_ptr->cu_workitem_num_per_thread[i - 1];
}
metadata_ptr->effective_thread_num = effective_thread_num;
{
// when q_tile_size = max_num_q_per_iter, the maximum
// attention_scratchpad_size is required
AttentionScratchPad sc(0, *metadata_ptr, 0x0);
int64_t n = AttentionScheduler::calcu_tile_size_with_constant_q(
cache_size, input.head_dim, input.elem_size, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, kv_len_alignment, max_num_q_per_iter, true);
sc.update(input.head_dim, input.q_buffer_elem_size,
input.logits_buffer_elem_size, input.output_buffer_elem_size,
max_num_q_per_iter, max_num_q_per_iter, n);
metadata_ptr->attention_scratchpad_size_per_thread =
((sc.get_thread_scratchpad_size() + 63) / 64) * 64;
sc.update(0, metadata_ptr->reduction_split_num, input.head_dim,
q_head_per_kv * split_kv_q_token_num_threshold,
input.output_buffer_elem_size);
metadata_ptr->reduction_scratchpad_size_per_kv_head =
((sc.get_reduction_scratchpad_size() + 63) / 64) * 64;
}
int64_t scratchpad_size =
metadata_ptr->attention_scratchpad_size_per_thread *
metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q);
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
scratchpad_size);
// metadata_ptr->print();
// test out of boundary access
// {
// float* cache_ptr =
// cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// }
// }
return metadata_tensor;
}
FORCE_INLINE static std::pair<int32_t, int32_t> calcu_kv_tile_pos(
int32_t kv_left_pos, int32_t kv_right_pos, int32_t q_left_pos,
int32_t q_right_pos, int32_t sliding_window_left,
int32_t sliding_window_right) {
if (sliding_window_left != -1) {
kv_left_pos = std::max(kv_left_pos, q_left_pos - sliding_window_left);
}
if (sliding_window_right != -1) {
kv_right_pos = std::min(kv_right_pos, q_right_pos + sliding_window_right);
}
return {kv_left_pos, kv_right_pos};
}
FORCE_INLINE static std::pair<int32_t, int32_t> align_kv_tile_pos(
int32_t kv_left_pos, int32_t kv_right_pos, int32_t align_factor) {
kv_left_pos = (kv_left_pos / align_factor) * align_factor;
kv_right_pos =
((kv_right_pos + align_factor - 1) / align_factor) * align_factor;
return {kv_left_pos, kv_right_pos};
}
static int64_t calcu_default_tile_size(int64_t cache_size, int64_t head_dim,
int64_t elem_size,
int64_t q_buffer_elem_size,
int64_t logits_buffer_elem_size,
int64_t output_buffer_elem_size,
int64_t max_num_q_per_iter,
int64_t round_size) {
// For CPU, different from CUDA, Q@K^T results should also be held in cache,
// using float32. Intermediate outputs should be float32 to be compatible
// with AMX. The cache then includes:
// - Q: q_tile_size * head_dim * q_buffer_elem_size
// - K, V: 2 * k_tile_size * head_dim * elem_size
// - Q@K^T: max_num_q_per_iter * k_tile_size * logits_buffer_elem_size
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size
// By default, let tile_size = q_tile_size = k_tile_size. To record
// is_first_iter states in a static array, the default tile size must be <=
// MaxQTileIterNum (128) * max_num_q_per_iter.
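// Illustrative numbers (assuming a 2 MiB cache budget, head_dim = 128, bf16
// KV cache and Q buffer, fp32 logits and outputs, max_num_q_per_iter = 32):
// tile_size = 2097152 / (128 * (2 + 2 * 2 + 4) + 32 * 4) = 1489, rounded
// down to 1472 with round_size = 32.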
int64_t tile_size =
cache_size / (head_dim * (q_buffer_elem_size + 2 * elem_size +
output_buffer_elem_size) +
max_num_q_per_iter * logits_buffer_elem_size);
tile_size = std::min(tile_size, MaxQTileIterNum * max_num_q_per_iter);
int64_t rounded_tile_size = (tile_size / round_size) * round_size;
return std::max(rounded_tile_size, round_size);
}
static int64_t calcu_tile_size_with_constant_q(
int64_t cache_size, int64_t head_dim, int64_t elem_size,
int64_t q_buffer_elem_size, int64_t logits_buffer_elem_size,
int64_t output_buffer_elem_size, int64_t max_num_q_per_iter,
int64_t round_size, int64_t q_tile_size, bool one_round) {
// calculate tile_size with known q_tile_size
// If one_round is true, the outer Q tile loop runs only once, so K and V
// will not be counted toward the cache budget
int64_t tile_size;
if (one_round) {
tile_size =
(cache_size - q_tile_size * head_dim *
(q_buffer_elem_size + output_buffer_elem_size)) /
(logits_buffer_elem_size * max_num_q_per_iter);
} else {
tile_size =
(cache_size - q_tile_size * head_dim *
(q_buffer_elem_size + output_buffer_elem_size)) /
(logits_buffer_elem_size * max_num_q_per_iter +
2 * head_dim * elem_size);
}
int64_t rounded_tile_size = (tile_size / round_size) * round_size;
return std::max(rounded_tile_size, round_size);
}
private:
int64_t available_cache_size_;
};
struct AttentionInput {
AttentionMetadata* metadata;
int32_t num_tokens;
int32_t num_heads;
int32_t num_kv_heads;
int32_t block_size;
void* query;
int64_t query_num_tokens_stride;
int64_t query_num_heads_stride;
int64_t cache_num_blocks_stride;
int64_t cache_num_kv_heads_stride;
int64_t blt_num_tokens_stride;
void* key_cache;
void* value_cache;
void* output;
int32_t* query_start_loc;
int32_t* seq_lens;
int32_t* block_table;
float* alibi_slopes;
c10::BFloat16* s_aux;
float scale;
bool causal;
int32_t sliding_window_left;
int32_t sliding_window_right;
float softcap;
};
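// Usage sketch (illustrative): the caller first runs
// AttentionScheduler::schedule() to build the AttentionMetadata tensor, then
// fills an AttentionInput whose metadata field points into that tensor, and
// finally invokes AttentionMainLoop<attn_impl>()(&input) under the
// CPU_ATTN_DISPATCH entry point.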
#define DEFINE_CPU_ATTENTION_PARAMS \
q_buffer_t *__restrict__ q_heads_buffer, \
kv_cache_t *__restrict__ k_head_cache_ptr, \
kv_cache_t *__restrict__ v_head_cache_ptr, \
logits_buffer_t *__restrict__ logits_buffer, \
float *__restrict__ partial_q_buffer, float *__restrict__ max_buffer, \
float *__restrict__ sum_buffer, int32_t *__restrict__ block_table, \
const int32_t kv_tile_start_pos, const int32_t kv_tile_end_pos, \
const int32_t kv_tile_token_num, \
const int64_t kv_cache_num_blocks_stride, const int32_t q_head_num, \
const int32_t q_token_num, const int32_t q_tile_start_pos, \
const int32_t q_heads_per_kv, const int32_t block_size, \
const int32_t left_window_size, const int32_t right_window_size, \
float scale, const float softcap_scale, \
const float *__restrict__ alibi_slopes, const bool is_first_iter, \
const bool use_sink, const bool debug_info
#define CPU_ATTENTION_PARAMS \
q_heads_buffer, k_head_cache_ptr, v_head_cache_ptr, logits_buffer, \
partial_q_buffer, max_buffer, sum_buffer, block_table, \
kv_tile_start_pos, kv_tile_end_pos, kv_tile_token_num, \
kv_cache_num_blocks_stride, q_head_num, q_token_num, q_tile_start_pos, \
q_heads_per_kv, block_size, left_window_size, right_window_size, scale, \
softcap_scale, alibi_slopes, is_first_iter, use_sink, debug_info
enum class AttentionGemmPhase { QK, PV };
template <typename T>
struct VecTypeTrait {
using vec_t = void;
};
template <>
struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#if !defined(__powerpc__)
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
#endif
template <typename T>
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
template <typename attention_impl_t>
class AttentionMainLoop {
public:
using query_t = typename attention_impl_t::query_t;
using q_buffer_t = typename attention_impl_t::q_buffer_t;
using kv_cache_t = typename attention_impl_t::kv_cache_t;
using logits_buffer_t = typename attention_impl_t::logits_buffer_t;
using partial_output_buffer_t =
typename attention_impl_t::partial_output_buffer_t;
using prob_buffer_t = typename attention_impl_t::prob_buffer_t;
static constexpr int64_t max_q_head_num_per_iter =
attention_impl_t::MaxQHeadNumPerIteration;
static constexpr int64_t blocksize_alignment =
attention_impl_t::BlockSizeAlignment;
static constexpr int64_t headdim_alignment =
attention_impl_t::HeadDimAlignment;
static constexpr int64_t head_dim = attention_impl_t::HeadDim;
static constexpr ISA ISAType = attention_impl_t::ISAType;
static constexpr bool scale_on_logits =
attention_impl_t::scale_on_logits; // apply scale on logits, otherwise
// apply scale on q_buffer
template <typename tile_gemm_t>
class Attention {
public:
// Args:
// - q_heads_buffer: [MaxQHeadNumPerIteration, head_dim]
// - k_head_cache_ptr: [num_blocks, block_size * head_dim]
// - v_head_cache_ptr: [num_blocks, block_size * head_dim]
// - logits_buffer: [MaxQHeadNumPerIteration, kv_tile_token_num], store Q@K
// logits
// - partial_q_buffer: [MaxQHeadNumPerIteration, head_dim], store partial
// output
// - max_buffer: [MaxQHeadNumPerIteration, 1], store max logits
// - sum_buffer: [MaxQHeadNumPerIteration, 1], store sum of exp
// - block_table
// - kv_tile_start_pos: start position of KV cache, aligned to
// BlockSizeAlignment
// - kv_tile_end_pos: end position of KV cache, aligned to
// BlockSizeAlignment
// - kv_tile_token_num: KV token num, aligned to BlockSizeAlignment
// - kv_cache_num_blocks_stride
// - q_head_num: head num of q_tile
// - q_token_num: token num of q_tile, should be q_head_num /
// q_heads_per_kv
// - q_tile_start_pos: start pos of the first token in q_heads_buffer
// - q_heads_per_kv
// - block_size
// - left_window_size
// - right_window_size
// - scale
// - softcap_scale
// - alibi_slopes
// - is_first_iter
// - use_sink
// - debug_info
void operator()(DEFINE_CPU_ATTENTION_PARAMS) {
// k_cache_token_group_stride: stride of K cache when move to next
// BlockSizeAlignment tokens in a block
const int64_t k_cache_token_group_stride =
attention_impl_t::k_cache_token_group_stride(block_size);
// v_cache_token_group_stride: stride of V cache when move to next
// BlockSizeAlignment tokens in a block
const int64_t v_cache_token_group_stride =
attention_impl_t::v_cache_token_group_stride(block_size);
// v_cache_head_group_stride: stride of V cache when move to next
// HeadDimAlignment head dims in a block
const int64_t v_cache_head_group_stride =
attention_impl_t::v_cache_head_group_stride(block_size);
const int32_t token_group_num = kv_tile_token_num / blocksize_alignment;
const int32_t token_group_num_per_block =
block_size / blocksize_alignment;
const int32_t start_block_idx = kv_tile_start_pos / block_size;
const int32_t start_block_offset = kv_tile_start_pos % block_size;
const int32_t start_block_group_offset =
start_block_offset / blocksize_alignment;
const int32_t end_block_idx =
(kv_tile_start_pos + kv_tile_token_num - 1) / block_size + 1;
// compute Q@K logits
{
int32_t curr_group_offset =
start_block_group_offset * k_cache_token_group_stride;
int32_t curr_group_num_in_block =
token_group_num_per_block - start_block_group_offset;
int32_t remaining_group_num = token_group_num;
logits_buffer_t* curr_logits_buffer = logits_buffer;
for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
++block_idx) {
int32_t physical_block_idx = block_table[block_idx];
kv_cache_t* k_cache_block_ptr =
k_head_cache_ptr +
physical_block_idx * kv_cache_num_blocks_stride +
curr_group_offset;
curr_group_num_in_block =
std::min(remaining_group_num, curr_group_num_in_block);
for (int32_t block_group_idx = 0;
block_group_idx < curr_group_num_in_block; ++block_group_idx) {
// logits_tile = q_tile @ k_tile, [MaxQHeadNumPerIteration,
// BlockSizeAlignment] = [MaxQHeadNumPerIteration, head_dim] @
// [head_dim, BlockSizeAlignment]
// By default, logits_buffer, q_buffer and k_cache are row-major,
// but may be packed by ISA implementation.
tile_gemm_t::template gemm<AttentionGemmPhase::QK, head_dim>(
q_head_num, q_heads_buffer, k_cache_block_ptr,
curr_logits_buffer, head_dim, block_size, kv_tile_token_num,
block_size, head_dim, false);
if constexpr (scale_on_logits) {
float* __restrict__ scale_curr_logits_buffer = curr_logits_buffer;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t i = 0; i < q_head_num; ++i) {
static_assert(blocksize_alignment % 16 == 0);
constexpr int32_t vec_num = blocksize_alignment / 16;
vec_op::unroll_loop<int32_t, vec_num>([&](int32_t vec_idx) {
vec_op::FP32Vec16 vec(scale_curr_logits_buffer +
vec_idx * 16);
vec = vec * scale_vec;
vec.save(scale_curr_logits_buffer + vec_idx * 16);
});
scale_curr_logits_buffer += kv_tile_token_num;
}
}
// Move buffer ptrs
k_cache_block_ptr += k_cache_token_group_stride;
curr_logits_buffer += blocksize_alignment;
}
// Update
remaining_group_num -= curr_group_num_in_block;
curr_group_offset = 0;
curr_group_num_in_block = token_group_num_per_block;
}
}
// process logits
{
// if (debug_info){
// print_logits("raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
// }
if (softcap_scale != 0.0f) {
apply_softcap(logits_buffer, kv_tile_token_num, q_head_num,
kv_tile_token_num, softcap_scale);
// print_logits("softcap raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
}
if (alibi_slopes != nullptr) {
apply_alibi_slopes(logits_buffer, alibi_slopes, kv_tile_token_num,
q_tile_start_pos, kv_tile_start_pos, q_token_num,
kv_tile_token_num, q_heads_per_kv);
// print_logits("alibi raw logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
}
apply_mask(logits_buffer, kv_tile_token_num, q_tile_start_pos,
kv_tile_start_pos, kv_tile_end_pos, q_token_num,
q_heads_per_kv, left_window_size, right_window_size);
// if (debug_info){
// print_logits("masked logits", logits_buffer, q_head_num,
// kv_tile_token_num, kv_tile_token_num);
// print_logits("old_max", max_buffer, 1, q_head_num, q_head_num);
// print_logits("old_sum", sum_buffer, 1, q_head_num, q_head_num);
// }
apply_softmax(logits_buffer, partial_q_buffer, max_buffer, sum_buffer,
kv_tile_token_num, q_head_num, kv_tile_token_num,
is_first_iter, use_sink);
// if (debug_info){
// print_logits("softmax logits",
// reinterpret_cast<prob_buffer_t*>(logits_buffer), q_head_num,
// kv_tile_token_num, kv_tile_token_num * sizeof(logits_buffer_t) /
// sizeof(prob_buffer_t));
// print_logits("new_max", max_buffer, 1, q_head_num, q_head_num);
// print_logits("new_sum", sum_buffer, 1, q_head_num, q_head_num);
// }
}
// compute P@V
{
int32_t curr_group_offset =
start_block_group_offset * v_cache_token_group_stride;
int32_t curr_group_num_in_block =
token_group_num_per_block - start_block_group_offset;
int32_t remaining_group_num = token_group_num;
int32_t head_dim_group_num = head_dim / headdim_alignment;
prob_buffer_t* curr_prob_buffer =
reinterpret_cast<prob_buffer_t*>(logits_buffer);
int64_t prob_buffer_stride =
kv_tile_token_num *
(sizeof(logits_buffer_t) / sizeof(prob_buffer_t));
partial_output_buffer_t* curr_partial_q_buffer = partial_q_buffer;
bool accum_c = !is_first_iter;
for (int32_t block_idx = start_block_idx; block_idx < end_block_idx;
++block_idx) {
int32_t physical_block_idx = block_table[block_idx];
kv_cache_t* v_cache_block_ptr =
v_head_cache_ptr +
physical_block_idx * kv_cache_num_blocks_stride +
curr_group_offset;
curr_group_num_in_block =
std::min(remaining_group_num, curr_group_num_in_block);
int32_t curr_token_num =
curr_group_num_in_block * blocksize_alignment;
for (int32_t head_dim_group_idx = 0;
head_dim_group_idx < head_dim_group_num; ++head_dim_group_idx) {
// output_tile = p_tile @ v_tile, [MaxQHeadNumPerIteration,
// HeadDimAlignment] = [MaxQHeadNumPerIteration, block_size] @
// [block_size, HeadDimAlignment]
tile_gemm_t::template gemm<AttentionGemmPhase::PV, -1>(
q_head_num, curr_prob_buffer, v_cache_block_ptr,
curr_partial_q_buffer, prob_buffer_stride, head_dim, head_dim,
block_size, curr_token_num, accum_c);
// Update
curr_partial_q_buffer += headdim_alignment;
v_cache_block_ptr += v_cache_head_group_stride;
}
// Update
remaining_group_num -= curr_group_num_in_block;
curr_group_offset = 0;
curr_group_num_in_block = token_group_num_per_block;
curr_prob_buffer += curr_token_num;
curr_partial_q_buffer = partial_q_buffer;
accum_c = true;
}
}
// if (debug_info) {
// print_logits("output", partial_q_buffer, q_head_num, head_dim,
// head_dim);
// }
}
void apply_mask(logits_buffer_t* __restrict__ logits_buffer,
const int64_t logits_buffer_stride,
const int32_t q_tile_start_pos,
const int32_t kv_tile_start_pos,
const int32_t kv_tile_end_pos, const int32_t q_token_num,
const int32_t q_heads_per_kv,
const int32_t sliding_window_left,
const int32_t sliding_window_right) {
// Apply mask
constexpr logits_buffer_t neg_inf =
-std::numeric_limits<logits_buffer_t>::infinity();
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
int32_t curr_token_pos = q_tile_start_pos;
for (int32_t token_idx = 0; token_idx < q_token_num; ++token_idx) {
int32_t left_kv_pos = [&]() {
int32_t pos = kv_tile_start_pos;
if (sliding_window_left != -1) {
pos = std::max(pos, curr_token_pos - sliding_window_left);
}
// Clamp to tile end to avoid OOB when window starts past the tile
return std::min(pos, kv_tile_end_pos);
}();
int32_t right_kv_pos = [&]() {
int32_t pos = kv_tile_end_pos;
if (sliding_window_right != -1) {
pos = std::min(pos,
std::max(kv_tile_start_pos,
curr_token_pos + sliding_window_right + 1));
}
return pos;
}();
int32_t left_invalid_token_num = left_kv_pos - kv_tile_start_pos;
int32_t right_invalid_token_num = kv_tile_end_pos - right_kv_pos;
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
logits_buffer_t* __restrict__ curr_logits_buffer_tail =
curr_logits_buffer + right_kv_pos - kv_tile_start_pos;
for (int32_t i = 0; i < left_invalid_token_num; ++i) {
curr_logits_buffer[i] = neg_inf;
}
for (int32_t i = 0; i < right_invalid_token_num; ++i) {
curr_logits_buffer_tail[i] = neg_inf;
}
curr_logits_buffer += logits_buffer_stride;
}
++curr_token_pos;
}
}
void apply_softmax(logits_buffer_t* __restrict__ logits_buffer,
float* __restrict__ partial_q_buffer,
float* __restrict__ max_buffer,
float* __restrict__ sum_buffer,
const int64_t logits_buffer_stride, int32_t q_head_num,
int32_t kv_tile_token_num, bool is_first_iter,
bool use_sink) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
using prob_buffer_vec_t = typename VecTypeTrait<prob_buffer_t>::vec_t;
static_assert(sizeof(prob_buffer_t) <= sizeof(logits_buffer_t));
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
float* __restrict__ curr_partial_q_buffer = partial_q_buffer;
const int32_t vec_num = kv_tile_token_num / 16;
const int32_t head_vec_num = head_dim / 16;
for (int32_t i = 0; i < q_head_num; ++i) {
float init_max_val = max_buffer[i];
float init_sum_val = sum_buffer[i];
// apply scale and compute max
vec_op::FP32Vec16 max_vec(init_max_val);
{
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
max_vec = vec.max(max_vec);
curr_logits_buffer_iter += 16;
}
}
float new_max_val = max_vec.reduce_max();
float rescale_factor = init_max_val - new_max_val;
// use the same rescale threshold as FA4.
// https://github.com/Dao-AILab/flash-attention/blob/1b8e1e641c6a179be9a0538b7f40fd595050b735/flash_attn/cute/flash_fwd_sm100.py#L1271
bool need_rescale = rescale_factor < -8.0;
if (!need_rescale) {
new_max_val = init_max_val;
} else {
max_buffer[i] = new_max_val;
}
// sub max, compute exp and sum
max_vec = vec_op::FP32Vec16(new_max_val);
vec_op::FP32Vec16 sum_vec(0.0);
{
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
prob_buffer_t* __restrict__ curr_prob_buffer_iter =
reinterpret_cast<prob_buffer_t*>(curr_logits_buffer);
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec - max_vec;
// compute exp
#ifdef DEFINE_FAST_EXP
vec = fast_exp(vec);
prob_buffer_vec_t output_vec(vec);
output_vec.save(curr_prob_buffer_iter);
#else
vec.save(curr_logits_buffer_iter);
for (int32_t k = 0; k < 16; ++k) {
curr_logits_buffer_iter[k] = std::exp(curr_logits_buffer_iter[k]);
}
vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
sum_vec = sum_vec + vec;
curr_logits_buffer_iter += 16;
curr_prob_buffer_iter += 16;
}
}
float new_sum_val = sum_vec.reduce_sum();
// rescale sum and partial outputs
if (need_rescale) {
// compute rescale factor
rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
// rescale sum
new_sum_val += rescale_factor * init_sum_val;
// rescale output
if (!is_first_iter) {
float* __restrict__ curr_partial_q_buffer_iter =
curr_partial_q_buffer;
for (int32_t j = 0; j < head_vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_partial_q_buffer_iter);
vec = vec * rescale_factor_vec;
vec.save(curr_partial_q_buffer_iter);
curr_partial_q_buffer_iter += 16;
}
}
} else {
new_sum_val += init_sum_val;
}
sum_buffer[i] = new_sum_val;
curr_logits_buffer += logits_buffer_stride;
curr_partial_q_buffer += head_dim;
}
}
void apply_softcap(logits_buffer_t* __restrict__ logits_buffer,
const int64_t logits_buffer_stride, int32_t q_head_num,
int32_t kv_tile_token_num, float softcap_scale) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
float inv_softcap_scale = 1.0 / softcap_scale;
vec_op::FP32Vec16 softcap_scale_vec(softcap_scale);
vec_op::FP32Vec16 inv_softcap_scale_vec(inv_softcap_scale);
vec_op::FP32Vec16 ones_vec(1.0);
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
const int32_t vec_num = kv_tile_token_num / 16;
for (int32_t i = 0; i < q_head_num; ++i) {
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t j = 0; j < vec_num; ++j) {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec * inv_softcap_scale_vec;
#ifdef DEFINE_FAST_EXP
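// tanh(x) = (e^x - e^-x) / (e^x + e^-x); reuse fast_exp(x) and its
// reciprocal instead of calling std::tanh.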
vec = fast_exp(vec);
vec_op::FP32Vec16 inv_vec = ones_vec / vec;
vec = (vec - inv_vec) / (vec + inv_vec);
#else
vec.save(curr_logits_buffer_iter);
for (int k = 0; k < 16; ++k) {
curr_logits_buffer_iter[k] = std::tanh(curr_logits_buffer_iter[k]);
}
vec = vec_op::FP32Vec16(curr_logits_buffer_iter);
#endif
vec = vec * softcap_scale_vec;
vec.save(curr_logits_buffer_iter);
curr_logits_buffer_iter += 16;
}
curr_logits_buffer += logits_buffer_stride;
}
}
void apply_alibi_slopes(logits_buffer_t* __restrict__ logits_buffer,
const float* __restrict__ alibi_slopes,
const int64_t logits_buffer_stride,
const int32_t q_tile_start_pos,
const int32_t kv_tile_start_pos,
const int32_t q_token_num,
const int32_t kv_tile_token_num,
const int32_t q_heads_per_kv) {
alignas(64) constexpr float initial_arange_vals[16] = {
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
const int32_t vec_num = kv_tile_token_num / 16;
vec_op::FP32Vec16 initial_arange_vals_vec(initial_arange_vals);
initial_arange_vals_vec =
initial_arange_vals_vec + vec_op::FP32Vec16((float)kv_tile_start_pos);
vec_op::FP32Vec16 pos_offset_vec(16.0);
logits_buffer_t* __restrict__ curr_logits_buffer = logits_buffer;
for (int32_t i = 0; i < q_token_num; ++i) {
vec_op::FP32Vec16 curr_q_pos_vec((float)(i + q_tile_start_pos));
for (int32_t j = 0; j < q_heads_per_kv; ++j) {
vec_op::FP32Vec16 alibi_scale_vec(alibi_slopes[j]);
vec_op::FP32Vec16 curr_kv_pos_vec(initial_arange_vals_vec);
logits_buffer_t* __restrict__ curr_logits_buffer_iter =
curr_logits_buffer;
for (int32_t k = 0; k < vec_num; ++k) {
vec_op::FP32Vec16 alibi_bias_vec =
alibi_scale_vec * (curr_kv_pos_vec - curr_q_pos_vec);
vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec + alibi_bias_vec;
vec.save(curr_logits_buffer_iter);
curr_kv_pos_vec = curr_kv_pos_vec + pos_offset_vec;
curr_logits_buffer_iter += 16;
}
curr_logits_buffer += logits_buffer_stride;
}
}
}
};
public:
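// Entry point. Work is distributed dynamically: every OpenMP thread pulls a
// task index from metadata.acquire_counter(); indices below
// workitem_groups_counter_num select an attention bundle keyed by
// (kv_head_idx, thread slot), while the remaining indices select reduction
// items that merge the partial results written by split-KV attention tasks.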
void operator()(const AttentionInput* input) {
const int thread_num = omp_get_max_threads();
TORCH_CHECK_EQ(input->metadata->thread_num, thread_num);
std::atomic<int32_t> guard_counter(0);
std::atomic<int32_t>* guard_counter_ptr = &guard_counter;
#pragma omp parallel for schedule(static, 1)
for (int thread_id = 0; thread_id < thread_num; ++thread_id) {
AttentionMetadata& metadata = *input->metadata;
if (metadata.workitem_group_num == 0) {
continue;
}
attention_impl_t attn_impl;
// general information
const int32_t q_head_num = input->num_heads;
const int32_t kv_head_num = input->num_kv_heads;
const int32_t q_heads_per_kv = q_head_num / kv_head_num;
const bool use_gqa = (max_q_head_num_per_iter % q_heads_per_kv == 0);
const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num;
const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1;
TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter);
const int32_t max_q_token_num_per_iter =
max_q_head_num_per_iter / actual_q_heads_per_kv;
const int64_t q_token_num_stride = input->query_num_tokens_stride;
const int64_t q_head_num_stride = input->query_num_heads_stride;
const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride;
const int64_t kv_cache_block_num_stride = input->cache_num_blocks_stride;
const int32_t sliding_window_left = input->sliding_window_left;
const int32_t sliding_window_right = input->sliding_window_right;
const int32_t block_size = input->block_size;
const float scale = input->scale;
const float softcap_scale = input->softcap;
const float* alibi_slopes = input->alibi_slopes;
const c10::BFloat16* s_aux = input->s_aux;
const bool causal = input->causal;
int32_t* const block_table = input->block_table;
const int64_t block_table_stride = input->blt_num_tokens_stride;
// init buffers
void* scratchpad_ptr =
cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<void>();
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
const int32_t total_reduction_split_num = metadata.reduction_split_num;
if (metadata.reduction_split_num > 0) {
// reset split flag
for (int32_t head_idx = thread_id; head_idx < actual_kv_head_num;
head_idx += thread_num) {
buffer_manager.update(head_idx, total_reduction_split_num, head_dim,
0, sizeof(partial_output_buffer_t));
volatile bool* __restrict__ curr_flag_ptr =
buffer_manager.get_reduce_flag_buffer();
for (int32_t split_idx = 0; split_idx < total_reduction_split_num;
++split_idx) {
curr_flag_ptr[split_idx] = false;
}
}
}
const int64_t available_cache_size = cpu_utils::get_available_l2_size();
const int32_t default_tile_size =
AttentionScheduler::calcu_default_tile_size(
available_cache_size, head_dim, sizeof(kv_cache_t),
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
max_q_head_num_per_iter);
const int32_t default_q_tile_token_num =
default_tile_size / actual_q_heads_per_kv;
AttentionWorkItemGroup* const workitem_groups =
metadata.workitem_groups_ptr;
const int32_t* cu_workitem_num_per_thread =
metadata.cu_workitem_num_per_thread;
ReductionWorkItemGroup* const reduction_items =
metadata.reduction_items_ptr;
const int32_t effective_thread_num = metadata.effective_thread_num;
const int32_t reduction_item_num = metadata.reduction_item_num;
const int32_t split_kv_q_token_num_threshold =
metadata.split_kv_q_token_num_threshold;
const int32_t workitem_groups_counter_num =
actual_kv_head_num * effective_thread_num;
const int32_t reduction_items_counter_num =
actual_kv_head_num * reduction_item_num;
const int32_t total_counter_num =
workitem_groups_counter_num + reduction_items_counter_num;
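// Barrier: with split-KV reduction enabled, every thread must finish clearing
// its share of the reduce flags above before any thread may produce or consume
// split results, so spin here until all threads have arrived.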
if (metadata.reduction_split_num > 0) {
++(*guard_counter_ptr);
while (guard_counter_ptr->load() != thread_num) {
#ifdef FAST_SPINNING
FAST_SPINNING
#else
std::this_thread::yield();
#endif
}
}
// main loop
for (;;) {
int64_t task_idx = metadata.acquire_counter();
if (task_idx >= total_counter_num) {
// no more tasks, leave loop
break;
}
if (task_idx < workitem_groups_counter_num) {
// attention task
// map task_idx to workitem_groups
const int32_t kv_head_idx = task_idx / effective_thread_num;
const int32_t thread_offset = task_idx % effective_thread_num;
AttentionWorkItemGroup* const curr_workitem_groups =
workitem_groups + cu_workitem_num_per_thread[thread_offset];
const int32_t curr_workitem_groups_num =
cu_workitem_num_per_thread[thread_offset + 1] -
cu_workitem_num_per_thread[thread_offset];
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
for (int32_t workitem_group_idx = 0;
workitem_group_idx < curr_workitem_groups_num;
++workitem_group_idx) {
AttentionWorkItemGroup* const current_workitem_group =
&curr_workitem_groups[workitem_group_idx];
const int32_t current_group_idx = current_workitem_group->req_id;
const int32_t kv_start_pos =
current_workitem_group->kv_split_pos_start;
const int32_t kv_end_pos = current_workitem_group->kv_split_pos_end;
const int32_t curr_split_id = current_workitem_group->split_id;
const int32_t q_token_id_start =
current_workitem_group->q_token_id_start;
const int32_t q_token_num = current_workitem_group->q_token_num;
// taskgroup general information
const int32_t q_end = input->query_start_loc[current_group_idx + 1];
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t seq_len = input->seq_lens[current_group_idx];
const int32_t q_start_pos =
(causal ? seq_len - (q_end - q_start) : 0);
const int32_t block_num = (seq_len + block_size - 1) / block_size;
// Only apply sink for the first KV split
bool use_sink = (s_aux != nullptr &&
current_workitem_group->local_split_id == 0);
for (int32_t q_token_offset = 0; q_token_offset < q_token_num;
q_token_offset += default_q_tile_token_num) {
bool first_iter_flag[AttentionScheduler::MaxQTileIterNum];
for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum;
++i) {
first_iter_flag[i] = true;
}
const int32_t q_token_start_idx =
q_start + q_token_offset + q_token_id_start;
const int32_t actual_q_token_num = std::min(
default_q_tile_token_num, q_token_num - q_token_offset);
const int32_t q_head_tile_size =
actual_q_token_num * actual_q_heads_per_kv;
const int32_t rounded_q_head_tile_size =
((q_head_tile_size + max_q_head_num_per_iter - 1) /
max_q_head_num_per_iter) *
max_q_head_num_per_iter;
const int32_t kv_tile_size =
AttentionScheduler::calcu_tile_size_with_constant_q(
available_cache_size, head_dim, sizeof(kv_cache_t),
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
blocksize_alignment, rounded_q_head_tile_size,
rounded_q_head_tile_size <= max_q_head_num_per_iter);
// update buffers
buffer_manager.update(
head_dim, sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
rounded_q_head_tile_size, kv_tile_size);
q_buffer_t* q_buffer = buffer_manager.get_q_buffer<q_buffer_t>();
float* logits_buffer = buffer_manager.get_logits_buffer();
float* partial_q_buffer = buffer_manager.get_output_buffer();
float* max_buffer = buffer_manager.get_max_buffer();
float* sum_buffer = buffer_manager.get_sum_buffer();
const int32_t q_tile_start_pos =
q_start_pos + q_token_offset + q_token_id_start;
const int32_t q_tile_end_pos =
q_tile_start_pos + actual_q_token_num;
const auto [kv_tile_start_pos, kv_tile_end_pos] =
AttentionScheduler::calcu_kv_tile_pos(
kv_start_pos, kv_end_pos, q_tile_start_pos,
q_tile_end_pos, sliding_window_left,
sliding_window_right);
const auto [rounded_kv_tile_start_pos, rounded_kv_tile_end_pos] =
AttentionScheduler::align_kv_tile_pos(
kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment);
int32_t curr_kv_head_idx =
use_gqa ? kv_head_idx
: (kv_head_idx /
q_heads_per_kv); // for GQA disabled case
// std::printf("thread_id: %d, req_id: %d, q_token_start: %d,
// q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx:
// %d, kv_pos_start: %d, kv_pos_end: %d\n",
// thread_id, current_group_idx,
// q_token_start_idx, q_token_start_idx +
// actual_q_token_num, q_head_start_idx,
// q_head_start_idx + actual_q_heads_per_kv,
// curr_kv_head_idx, kv_tile_start_pos,
// kv_tile_end_pos);
// move buffers
kv_cache_t* curr_k_cache =
reinterpret_cast<kv_cache_t*>(input->key_cache) +
curr_kv_head_idx * kv_cache_head_num_stride;
kv_cache_t* curr_v_cache =
reinterpret_cast<kv_cache_t*>(input->value_cache) +
curr_kv_head_idx * kv_cache_head_num_stride;
query_t* const q_tile_ptr =
reinterpret_cast<query_t*>(input->query) +
q_token_start_idx * q_token_num_stride +
q_head_start_idx * q_head_num_stride;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
int32_t* curr_block_table =
block_table + current_group_idx * block_table_stride;
const float* curr_alibi_slopes =
(alibi_slopes != nullptr ? alibi_slopes + q_head_start_idx
: nullptr);
const c10::BFloat16* curr_s_aux =
(s_aux != nullptr ? s_aux + q_head_start_idx : nullptr);
// Copy the Q tile to q_buffer; the logical layout of q_buffer is
// [actual_q_token_num, actual_q_heads_per_kv, head_dim]
{
attn_impl.copy_q_heads_tile(
q_tile_ptr, q_buffer, actual_q_token_num,
actual_q_heads_per_kv, q_token_num_stride,
q_head_num_stride, scale);
}
if (use_sink) {
alignas(64) float s_aux_fp32[16];
// All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 1.0f;
curr_max_buffer[head_idx] = s_aux_fp32[head_idx];
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
}
} else {
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 0.0f;
curr_max_buffer[head_idx] =
std::numeric_limits<float>::lowest();
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
}
}
// compute loop
for (int32_t kv_tile_pos = rounded_kv_tile_start_pos;
kv_tile_pos < rounded_kv_tile_end_pos;
kv_tile_pos += kv_tile_size) {
const int32_t kv_tile_pos_left = kv_tile_pos;
const int32_t kv_tile_pos_right = std::min(
kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos);
for (int32_t q_head_tile_token_offset = 0;
q_head_tile_token_offset < actual_q_token_num;
q_head_tile_token_offset += max_q_token_num_per_iter) {
const int32_t q_tile_pos_left =
q_tile_start_pos + q_head_tile_token_offset;
const int32_t q_tile_token_num =
std::min(max_q_token_num_per_iter,
actual_q_token_num - q_head_tile_token_offset);
const int32_t q_tile_head_offset =
q_head_tile_token_offset * actual_q_heads_per_kv;
const int32_t q_tile_head_num =
q_tile_token_num * actual_q_heads_per_kv;
const int32_t q_tile_pos_right =
q_tile_pos_left + q_tile_token_num;
const auto [actual_kv_tile_pos_left,
actual_kv_tile_pos_right] =
AttentionScheduler::calcu_kv_tile_pos(
kv_tile_pos_left, kv_tile_pos_right, q_tile_pos_left,
q_tile_pos_right, sliding_window_left,
sliding_window_right);
const int32_t q_iter_idx =
q_head_tile_token_offset / max_q_token_num_per_iter;
if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) {
continue;
}
// align kv_pos to blocksize_alignment
const auto [aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right] =
AttentionScheduler::align_kv_tile_pos(
actual_kv_tile_pos_left, actual_kv_tile_pos_right,
blocksize_alignment);
const int32_t actual_kv_token_num =
aligned_actual_kv_tile_pos_right -
aligned_actual_kv_tile_pos_left;
// std::printf("\tq_iter_idx: %d, q_token_start: %d,
// q_token_end: %d, q_token_num: %d, q_head_num: %d,
// q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,
// kv_pos_end: %d\n",
// q_iter_idx, q_token_start_idx +
// q_head_tile_token_offset, q_token_start_idx +
// q_head_tile_token_offset + q_tile_token_num,
// q_tile_token_num, q_tile_head_num,
// q_tile_pos_left, q_tile_pos_right,
// aligned_actual_kv_tile_pos_left,
// aligned_actual_kv_tile_pos_right);
// Move buffers
q_buffer_t* curr_q_heads_buffer =
q_buffer + q_tile_head_offset * head_dim;
float* curr_partial_q_buffer =
partial_q_buffer + q_tile_head_offset * head_dim;
float* curr_max_buffer = max_buffer + q_tile_head_offset;
float* curr_sum_buffer = sum_buffer + q_tile_head_offset;
bool debug_info = false;
// bool debug_info = (
// q_head_start_idx == 4 &&
// (q_token_start_idx + q_head_tile_token_offset) <=
// 4
// && (q_token_start_idx + q_head_tile_token_offset +
// q_tile_token_num) > 4
// );
// if (debug_info) {
// std::printf("\tq_iter_idx: %d, q_token_start: %d,"
// "q_token_end: %d, q_token_num: %d, q_head_num: %d,"
// "q_pos_start: %d, q_pos_end: %d, kv_pos_start: %d,"
// "kv_pos_end: %d\n",
// q_iter_idx, q_token_start_idx +
// q_head_tile_token_offset, q_token_start_idx
// + q_head_tile_token_offset +
// q_tile_token_num, q_tile_token_num,
// q_tile_head_num, q_tile_pos_left,
// q_tile_pos_right,
// aligned_actual_kv_tile_pos_left,
// aligned_actual_kv_tile_pos_right);
// }
attn_impl.template execute_attention<Attention>(
curr_q_heads_buffer, curr_k_cache, curr_v_cache,
logits_buffer, curr_partial_q_buffer, curr_max_buffer,
curr_sum_buffer, curr_block_table,
aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right, actual_kv_token_num,
kv_cache_block_num_stride, q_tile_head_num,
q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv,
block_size, sliding_window_left, sliding_window_right,
scale, softcap_scale, curr_alibi_slopes,
first_iter_flag[q_iter_idx], use_sink, debug_info);
first_iter_flag[q_iter_idx] = false;
}
}
// write back partial results to output buffer or reduction buffer
{
if (curr_split_id == -1) {
final_output(partial_q_buffer,
reinterpret_cast<query_t*>(input->output) +
output_buffer_offset,
sum_buffer, actual_q_heads_per_kv,
actual_q_token_num, q_head_num);
} else {
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
buffer_manager.get_reduce_flag_buffer() + curr_split_id;
float* split_output_buffer =
buffer_manager.get_reduce_output_buffer() +
curr_split_id * stride * head_dim;
float* split_max_buffer =
buffer_manager.get_reduce_max_buffer() +
curr_split_id * stride;
float* split_sum_buffer =
buffer_manager.get_reduce_sum_buffer() +
curr_split_id * stride;
partial_output(partial_q_buffer, max_buffer, sum_buffer,
q_head_tile_size, split_output_buffer,
split_max_buffer, split_sum_buffer,
split_flag_buffer);
}
}
}
}
} else {
task_idx -= workitem_groups_counter_num;
const int32_t kv_head_idx = task_idx / reduction_item_num;
const int32_t item_offset = task_idx % reduction_item_num;
ReductionWorkItemGroup* const curr_workitem_groups =
reduction_items + item_offset;
const int32_t curr_output_token_idx =
curr_workitem_groups->q_token_id_start;
const int32_t curr_output_token_num =
curr_workitem_groups->q_token_id_num;
const int32_t curr_split_id = curr_workitem_groups->split_start_id;
const int32_t curr_split_num = curr_workitem_groups->split_num;
const int32_t current_group_idx = curr_workitem_groups->req_id;
const int32_t curr_output_head_num =
curr_output_token_num * actual_q_heads_per_kv;
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t q_token_start_idx = q_start + curr_output_token_idx;
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
buffer_manager.get_reduce_flag_buffer() + curr_split_id;
float* split_output_buffer =
buffer_manager.get_reduce_output_buffer() +
curr_split_id * stride * head_dim;
float* split_max_buffer =
buffer_manager.get_reduce_max_buffer() + curr_split_id * stride;
float* split_sum_buffer =
buffer_manager.get_reduce_sum_buffer() + curr_split_id * stride;
reduce_splits(split_output_buffer, split_max_buffer, split_sum_buffer,
split_flag_buffer, stride, curr_output_head_num,
curr_split_num);
final_output(
split_output_buffer,
reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
q_head_num);
}
}
}
// Reset counter for next call
input->metadata->reset_counter();
}
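// Merge split-KV partials with the usual numerically stable softmax combine:
// given a running (max_a, sum_a, out_a) and an incoming split
// (max_b, sum_b, out_b), let m = max(max_a, max_b) and
// r = exp(min(max_a, max_b) - m); then
//   sum <- sum_of_larger_max + r * sum_of_smaller_max
//   out <- out_of_larger_max + r * out_of_smaller_max
// which is what the larger/smaller pointer selection below implements.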
void reduce_splits(float* __restrict__ split_output_buffer,
float* __restrict__ split_max_buffer,
float* __restrict__ split_sum_buffer,
volatile bool* __restrict__ flags,
const int32_t head_num_per_split,
const int32_t curr_head_num, const int32_t split_num) {
#ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP
#endif
// the scheduler restricts curr_head_num <= 16
// elements in split_max_buffer / split_sum_buffer are not cache-line aligned;
// use local buffers to reduce false sharing
alignas(64) float local_max[16];
alignas(64) float local_sum[16];
float* __restrict__ curr_split_output_buffer = split_output_buffer;
float* __restrict__ curr_split_max_buffer = split_max_buffer;
float* __restrict__ curr_split_sum_buffer = split_sum_buffer;
constexpr int32_t head_dim_group_num = head_dim / 16;
for (int32_t split_idx = 0; split_idx < split_num; ++split_idx) {
while (!flags[split_idx]) {
#ifdef FAST_SPINNING
FAST_SPINNING
#else
std::this_thread::yield();
#endif
}
std::atomic_thread_fence(std::memory_order_acquire);
if (split_idx > 0) {
float* __restrict__ curr_output_buffer = split_output_buffer;
float* __restrict__ curr_split_output_buffer_iter =
curr_split_output_buffer;
for (int32_t head_idx = 0; head_idx < curr_head_num; ++head_idx) {
float final_max = local_max[head_idx];
float curr_max = curr_split_max_buffer[head_idx];
float final_sum = local_sum[head_idx];
float curr_sum = curr_split_sum_buffer[head_idx];
float* __restrict__ non_scale_output_iter =
final_max > curr_max ? curr_output_buffer
: curr_split_output_buffer_iter;
float* __restrict__ scale_output_iter =
final_max > curr_max ? curr_split_output_buffer_iter
: curr_output_buffer;
float rescale_factor = final_max > curr_max ? curr_max - final_max
: final_max - curr_max;
rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
local_sum[head_idx] = final_max > curr_max
? final_sum + rescale_factor * curr_sum
: rescale_factor * final_sum + curr_sum;
final_max = std::max(final_max, curr_max);
local_max[head_idx] = final_max;
for (int32_t i = 0; i < head_dim_group_num; ++i) {
vec_op::FP32Vec16 non_scale_vec(non_scale_output_iter);
vec_op::FP32Vec16 scale_vec(scale_output_iter);
vec_op::FP32Vec16 final_vec =
non_scale_vec + scale_vec * rescale_factor_vec;
final_vec.save(curr_output_buffer);
non_scale_output_iter += 16;
scale_output_iter += 16;
curr_output_buffer += 16;
}
curr_split_output_buffer_iter += head_dim;
}
} else {
vec_op::FP32Vec16 final_max(split_max_buffer);
final_max.save(local_max);
vec_op::FP32Vec16 final_sum(split_sum_buffer);
final_sum.save(local_sum);
}
curr_split_output_buffer += head_num_per_split * head_dim;
curr_split_max_buffer += head_num_per_split;
curr_split_sum_buffer += head_num_per_split;
}
// write back final max and sum
for (int32_t i = 0; i < curr_head_num; ++i) {
split_max_buffer[i] = local_max[i];
split_sum_buffer[i] = local_sum[i];
}
}
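// Publish one split's partial results: copy per-head max/sum and the partial
// output into the shared reduction buffers, then set the per-split flag after
// a release fence so the reducing thread (which spins on the flag and issues
// an acquire fence) observes consistent data.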
void partial_output(float* __restrict__ partial_output_buffer,
float* __restrict__ partial_max_buffer,
float* __restrict__ partial_sum_buffer,
int32_t curr_head_num,
float* __restrict__ split_output_buffer,
float* __restrict__ split_max_buffer,
float* __restrict__ split_sum_buffer,
volatile bool* __restrict__ flag) {
float* __restrict__ curr_partial_output_buffer = partial_output_buffer;
float* __restrict__ curr_split_output_buffer = split_output_buffer;
constexpr int32_t head_dim_group_num = head_dim / 16;
for (int32_t i = 0; i < curr_head_num; ++i) {
split_max_buffer[i] = partial_max_buffer[i];
split_sum_buffer[i] = partial_sum_buffer[i];
for (int32_t j = 0; j < head_dim_group_num; ++j) {
vec_op::FP32Vec16 vec(curr_partial_output_buffer);
vec.save(curr_split_output_buffer);
curr_partial_output_buffer += 16;
curr_split_output_buffer += 16;
}
}
std::atomic_thread_fence(std::memory_order_release);
*flag = true;
}
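// Final normalization: divide each accumulated head vector by its softmax
// denominator (sum_buffer) and cast the result back to the query dtype while
// writing it to the output tensor.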
void final_output(float* __restrict__ partial_q_buffer,
query_t* __restrict__ curr_output_buffer,
float* __restrict__ sum_buffer,
const int32_t q_heads_per_kv,
const int32_t actual_q_token_num,
const int32_t q_head_num) {
// final output
using output_vec_t = typename VecTypeTrait<query_t>::vec_t;
float* __restrict__ curr_partial_output_buffer = partial_q_buffer;
float* __restrict__ curr_sum_buffer = sum_buffer;
constexpr int32_t group_num_per_head = head_dim / 16;
const int32_t partial_q_buffer_stride = q_heads_per_kv * head_dim;
const int32_t output_buffer_stride = q_head_num * head_dim;
for (int32_t token_idx = 0; token_idx < actual_q_token_num; ++token_idx) {
float* __restrict__ curr_partial_output_buffer_iter =
curr_partial_output_buffer;
query_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
vec_op::FP32Vec16 inv_sum_scale_vec(1.0 / *curr_sum_buffer);
for (int32_t i = 0; i < group_num_per_head; ++i) {
vec_op::FP32Vec16 vec(curr_partial_output_buffer_iter);
// divide the final sum val of softmax here
vec = inv_sum_scale_vec * vec;
// cast to query type
output_vec_t output_vec(vec);
output_vec.save(curr_output_buffer_iter);
// update
curr_partial_output_buffer_iter += 16;
curr_output_buffer_iter += 16;
}
// update
curr_sum_buffer += 1;
}
// update
curr_partial_output_buffer += partial_q_buffer_stride;
curr_output_buffer += output_buffer_stride;
}
}
};
} // namespace cpu_attention
#endif
#ifndef CPU_ATTN_NEON_HPP
#define CPU_ATTN_NEON_HPP
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <type_traits>
#ifdef ARM_BF16_SUPPORT
#include "cpu_attn_neon_bfmmla.hpp"
#endif
namespace cpu_attention {
namespace {
#define BLOCK_SIZE_ALIGNMENT 32
#define HEAD_SIZE_ALIGNMENT 32
#define MAX_Q_HEAD_NUM_PER_ITER 16
// These do not use the vectorized classes for loading / converting
// because csrc/cpu/cpu_types_arm.hpp has no fallback for the
// vec_op::BF16Vec* types on Arm HW that doesn't support BF16.
// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* either, for consistency.
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0,
float32x4_t& b1);
template <>
FORCE_INLINE void load_row8_B_as_f32<float>(const float* p, float32x4_t& b0,
float32x4_t& b1) {
b0 = vld1q_f32(p + 0);
b1 = vld1q_f32(p + 4);
}
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::Half>(const c10::Half* p,
float32x4_t& b0,
float32x4_t& b1) {
const float16_t* h = reinterpret_cast<const float16_t*>(p);
float16x8_t v = vld1q_f16(h);
b0 = vcvt_f32_f16(vget_low_f16(v));
b1 = vcvt_f32_f16(vget_high_f16(v));
}
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
float32x4_t& b0,
float32x4_t& b1) {
const uint16_t* u = reinterpret_cast<const uint16_t*>(p);
#ifdef ARM_BF16_SUPPORT
uint16x8_t u0 = vld1q_u16(u);
bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0);
b0 = vcvtq_low_f32_bf16(bf0);
b1 = vcvtq_high_f32_bf16(bf0);
#else
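// No native BF16 support: widen each 16-bit value to 32 bits and shift it into
// the high half of the word; since BF16 is the upper half of an FP32 bit
// pattern, reinterpreting the result as float32 recovers the value exactly.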
uint16x8_t x0 = vld1q_u16(u);
uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16);
uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16);
b0 = vreinterpretq_f32_u32(lo);
b1 = vreinterpretq_f32_u32(hi);
#endif
}
// Mx8, with 1 <= M <= 8, K streamed, unroll-by-4 with ASIMD FMLAs
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #FMLAs = (K // 4) * (4 * 2 * M)
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4(
const float* __restrict A, // [M x K],
const kv_cache_t* __restrict B, // [K x 8],
float* __restrict C, // [M x 8],
int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
// kernel supports max M of 8, as it'd spill for larger M
static_assert(1 <= M && M <= 8, "M must be in [1,8]");
// helpers for per-M codegen
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
#define IF_M(i) if constexpr (M > (i))
// A row base pointers
#define DECL_A(i) const float* a##i = A + (i) * lda;
ROWS_APPLY(DECL_A)
#undef DECL_A
// declare 2 accumulators per row of M
#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1;
ROWS_APPLY(DECL_ACC)
#undef DECL_ACC
// initialize accumulators
#define INIT_ACC(i) \
IF_M(i) { \
if (accumulate) { \
acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \
acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \
} else { \
acc##i##_0 = vdupq_n_f32(0.f); \
acc##i##_1 = vdupq_n_f32(0.f); \
} \
}
ROWS_APPLY(INIT_ACC)
#undef INIT_ACC
int32_t k = 0;
// K unrolled by 4
for (; k + 3 < K; k += 4) {
// load A[k..k+3] for each active row (M)
#define LOAD_A4(i) \
float32x4_t a##i##v; \
IF_M(i) a##i##v = vld1q_f32(a##i + k);
ROWS_APPLY(LOAD_A4)
#undef LOAD_A4
// helper: FMA lane L from aiv
#define FMAS_LANE(i, aiv, L) \
IF_M(i) { \
acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \
acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \
}
// k + 0
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb, b0, b1);
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
ROWS_APPLY(STEP_K0)
#undef STEP_K0
}
// k + 1
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb, b0, b1);
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
ROWS_APPLY(STEP_K1)
#undef STEP_K1
}
// k + 2
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb, b0, b1);
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
ROWS_APPLY(STEP_K2)
#undef STEP_K2
}
// k + 3
{
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb, b0, b1);
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
ROWS_APPLY(STEP_K3)
#undef STEP_K3
}
#undef FMAS_LANE
}
// K tail
for (; k < K; ++k) {
float32x4_t b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb, b0, b1);
#define TAIL_ROW(i) \
IF_M(i) { \
float32x4_t ai = vdupq_n_f32(*(a##i + k)); \
acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \
acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \
}
ROWS_APPLY(TAIL_ROW)
#undef TAIL_ROW
}
// store accumulators to C
#define STORE_ROW(i) \
IF_M(i) { \
vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \
vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \
}
ROWS_APPLY(STORE_ROW)
#undef STORE_ROW
#undef ROWS_APPLY
#undef IF_M
}
template <int32_t N, typename kv_cache_t>
FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A,
const kv_cache_t* __restrict B,
float* __restrict C, int32_t M,
int32_t K, int64_t lda,
int64_t ldb, int64_t ldc,
bool accumulate) {
// micro kernel is Mx8
static_assert(N % 8 == 0, "N must be a multiple of 8");
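// M is consumed greedily in micro-tiles of 8/4/2/1 rows (e.g. M = 13 becomes
// 8 + 4 + 1); each micro-tile sweeps the full N extent in steps of 8 columns.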
for (int32_t m = 0; m < M;) {
int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
const float* Ab = A + m * lda;
float* Cb = C + m * ldc;
for (int32_t n = 0; n < N; n += 8) {
const kv_cache_t* Bn = B + n;
float* Cn = Cb + n;
switch (mb) {
case 8:
gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
case 4:
gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
case 2:
gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
default:
gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
K, accumulate);
break;
}
}
// no tail loop for N as it's guaranteed to be a multiple of 8
m += mb;
}
}
template <typename kv_cache_t>
class TileGemmNeonFMLA {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
if constexpr (phase == AttentionGemmPhase::QK) {
gemm_macro_neon_fmla_Mx8_Ku4<BLOCK_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
} else {
gemm_macro_neon_fmla_Mx8_Ku4<HEAD_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
accum_c);
}
}
};
} // namespace
// this is similar to "ISA::VEC" at the moment
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::NEON;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
static_assert(HeadDim % HeadDimAlignment == 0);
// the gemm micro kernel is Mx8
static_assert(HeadDimAlignment % 8 == 0);
static_assert(BlockSizeAlignment % 8 == 0);
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemmNeonFMLA<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of the K cache when moving to the next
// BlockSizeAlignment tokens within a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of the V cache when moving to the next
// BlockSizeAlignment tokens within a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of the V cache when moving to the next
// HeadDimAlignment head dims within a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
// Copy q to q_buffer and cast it to fp32
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
float* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
static_assert(head_dim % 16 == 0);
constexpr int32_t unroll_size = head_dim / 16;
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
scalar_t* __restrict__ curr_q =
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
float* __restrict__ curr_q_buffer =
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
q_head_idx * head_dim;
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
load_vec_t vec(curr_q);
vec_op::FP32Vec16 fp32_vec(vec);
fp32_vec = fp32_vec * scale_vec;
fp32_vec.save(curr_q_buffer);
curr_q += 16;
curr_q_buffer += 16;
});
}
}
}
// reshape K as column-major and V as row-major
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key
const scalar_t* key_start_ptr = key +
token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_cache_start_ptr =
key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_cache_start_ptr[j] = key_start_ptr[i];
}
}
{
// Write Value
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset * head_dim;
std::memcpy(value_cache_start_ptr, value_start_ptr,
sizeof(scalar_t) * head_dim);
}
}
}
}
};
#ifdef ARM_BF16_SUPPORT
// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
template <int64_t head_dim>
class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim>
: public AttentionImplNEONBFMMLA<BLOCK_SIZE_ALIGNMENT, ISA::NEON,
head_dim> {};
#endif
} // namespace cpu_attention
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif  // CPU_ATTN_NEON_HPP
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#ifndef CPU_ATTN_NEON_BFMMLA_HPP
#define CPU_ATTN_NEON_BFMMLA_HPP
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <cstdint>
#include <vector>
namespace cpu_attention {
namespace {
// BFMMLA tile dimensions
constexpr int32_t TILE_ROWS = 2; // M dimension
constexpr int32_t TILE_K = 4; // K reduction
constexpr int32_t TILE_COLS = 2; // N dimension (column-pair)
// Derived constants
constexpr int32_t OUTPUT_COLS_PER_BLOCK = 8; // 4 column-pairs
constexpr int32_t K_TOKENS_PER_GROUP = 8; // Tokens grouped in K cache
constexpr int32_t V_TOKENS_PER_ROW_BLOCK = 4; // Tokens per V cache row block
constexpr int32_t K_INNER_STRIDE = K_TOKENS_PER_GROUP * TILE_K;
constexpr int32_t V_INNER_STRIDE = V_TOKENS_PER_ROW_BLOCK * TILE_COLS;
constexpr int32_t PACK_ELEMENTS_PER_K_CHUNK = TILE_ROWS * TILE_K; // A packing
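// Each BFMMLA instruction accumulates one 2x2 FP32 tile:
//   C(2x2) += A(2x4 BF16) * B(2x4 BF16)^T
// i.e. one call consumes TILE_ROWS x TILE_K of A and TILE_COLS x TILE_K of B.
// The packing helpers below lay both operands out in this row-pair-interleaved
// form.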
// Matrix Packing and Accumulator
// Reshape two rows of Q into BFMMLA-friendly interleaved
// Input: row0 = [a0,a1,a2,a3], row1 = [b0,b1,b2,b3]
// Output: [a0,a1,a2,a3,b0,b1,b2,b3, a4,a5,a6,a7,b4,b5,b6,b7]
// For K tail (K % TILE_K != 0): pads with zeros to complete the final chunk
FORCE_INLINE void reshape_Q_2xK_for_bfmmla(const c10::BFloat16* __restrict r0,
const c10::BFloat16* __restrict r1,
c10::BFloat16* __restrict dst,
int32_t K) {
const uint16_t* s0 = reinterpret_cast<const uint16_t*>(r0);
const uint16_t* s1 = reinterpret_cast<const uint16_t*>(r1);
uint16_t* d = reinterpret_cast<uint16_t*>(dst);
// Process TILE_K elements at a time (PACK_ELEMENTS_PER_K_CHUNK output)
int32_t k = 0;
for (; k + TILE_K <= K; k += TILE_K, d += PACK_ELEMENTS_PER_K_CHUNK) {
vst1q_u16(d, vcombine_u16(vld1_u16(s0 + k), vld1_u16(s1 + k)));
}
// Handle K tail: pack remaining elements with zero-padding
const int32_t tail = K - k;
if (tail > 0) {
// Pack remaining tail elements: [r0[k..k+tail-1], pad, r1[k..k+tail-1],
// pad]
for (int32_t t = 0; t < tail; ++t) {
d[t] = s0[k + t];
d[t + TILE_K] = s1[k + t];
}
// Zero-pad the rest
for (int32_t t = tail; t < TILE_K; ++t) {
d[t] = 0;
d[t + TILE_K] = 0;
}
}
}
// 2x2 accumulator load/store with compile-time row count
template <int32_t m_rows>
FORCE_INLINE float32x4_t load_acc_2x2(float* base, int64_t ldc, int col_off) {
static_assert(m_rows == 1 || m_rows == 2);
float32x2_t row0 = vld1_f32(base + col_off);
float32x2_t row1 =
(m_rows == 2) ? vld1_f32(base + ldc + col_off) : vdup_n_f32(0.f);
return vcombine_f32(row0, row1);
}
template <int32_t m_rows>
FORCE_INLINE void store_acc_2x2(float32x4_t acc, float* base, int64_t ldc,
int col_off) {
static_assert(m_rows == 1 || m_rows == 2);
vst1_f32(base + col_off, vget_low_f32(acc));
if constexpr (m_rows == 2) {
vst1_f32(base + ldc + col_off, vget_high_f32(acc));
}
}
// Initialize 4 column-pair accumulators for 2 rows (8 columns total)
#define INIT_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows, accum) \
do { \
if (accum) { \
if (m_rows == 2) { \
a0 = load_acc_2x2<2>(Crow, ldc, 0); \
a1 = load_acc_2x2<2>(Crow, ldc, 2); \
a2 = load_acc_2x2<2>(Crow, ldc, 4); \
a3 = load_acc_2x2<2>(Crow, ldc, 6); \
} else { \
a0 = load_acc_2x2<1>(Crow, ldc, 0); \
a1 = load_acc_2x2<1>(Crow, ldc, 2); \
a2 = load_acc_2x2<1>(Crow, ldc, 4); \
a3 = load_acc_2x2<1>(Crow, ldc, 6); \
} \
} else { \
a0 = a1 = a2 = a3 = vdupq_n_f32(0.f); \
} \
} while (0)
// Store 4 column-pair accumulators back to C matrix
#define STORE_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows) \
do { \
if (m_rows == 2) { \
store_acc_2x2<2>(a0, Crow, ldc, 0); \
store_acc_2x2<2>(a1, Crow, ldc, 2); \
store_acc_2x2<2>(a2, Crow, ldc, 4); \
store_acc_2x2<2>(a3, Crow, ldc, 6); \
} else { \
store_acc_2x2<1>(a0, Crow, ldc, 0); \
store_acc_2x2<1>(a1, Crow, ldc, 2); \
store_acc_2x2<1>(a2, Crow, ldc, 4); \
store_acc_2x2<1>(a3, Crow, ldc, 6); \
} \
} while (0)
// Perform 4 BFMMLA operations: acc += A @ B for 4 column-pairs
#define BFMMLA_COMPUTE_4(r0, r1, r2, r3, a, b0, b1, b2, b3) \
do { \
r0 = vbfmmlaq_f32(r0, a, b0); \
r1 = vbfmmlaq_f32(r1, a, b1); \
r2 = vbfmmlaq_f32(r2, a, b2); \
r3 = vbfmmlaq_f32(r3, a, b3); \
} while (0)
// Micro-kernel: updates a small fixed tile using BFMMLA.
// RP = number of row-pairs (1,2,4)
// Computes C[TILE_ROWS*RP, OUTPUT_COLS_PER_BLOCK] += A_packed @ B.
// A_packed interleaves RP row-pairs; B layout is driven by the attention phase:
// - AttentionGemmPhase::QK -> token-column layout (Q @ K^T)
// - AttentionGemmPhase::PV -> token-row layout (P @ V)
// K_static < 0 enables runtime K (PV only)
template <int32_t RP, int32_t K_static, AttentionGemmPhase phase>
FORCE_INLINE void gemm_rowpairs_x8_bfmmla_neon(
const bfloat16_t* const* __restrict A_packed_rp,
const int32_t* __restrict m_rows_rp, const bfloat16_t* __restrict B_blk,
float* __restrict C, int64_t ldc, bool accumulate, int64_t b_stride,
int32_t K_runtime = 0) {
static_assert(RP == 1 || RP == 2 || RP == 4, "RP must be 1,2,4");
static_assert(K_static < 0 || K_static % TILE_K == 0,
"K must be divisible by TILE_K");
static_assert(K_static >= 0 || phase == AttentionGemmPhase::PV,
"Runtime K only supported for PV");
constexpr bool runtime_k = (K_static < 0);
const int32_t K_iters =
runtime_k ? (K_runtime / TILE_K) : (K_static / TILE_K);
const int32_t K_tail = runtime_k ? (K_runtime % TILE_K) : 0;
if (!runtime_k) {
// Help the compiler fold away unused K_runtime when K is compile-time
(void)K_runtime;
}
auto* C_al = C;
const auto* B_al = B_blk;
// Setup A pointers
const bfloat16_t* a_ptr[4] = {
A_packed_rp[0],
(RP >= 2) ? A_packed_rp[1] : nullptr,
(RP >= 4) ? A_packed_rp[2] : nullptr,
(RP >= 4) ? A_packed_rp[3] : nullptr,
};
// Setup B pointers based on layout
const bfloat16_t* b_ptr[4];
if constexpr (phase == AttentionGemmPhase::PV) {
b_ptr[0] = B_blk + 0 * b_stride;
b_ptr[1] = B_blk + 1 * b_stride;
b_ptr[2] = B_blk + 2 * b_stride;
b_ptr[3] = B_blk + 3 * b_stride;
}
float32x4_t acc[4][4];
// Initialize accumulators
#define INIT_RP(rp) \
if constexpr (RP > rp) { \
INIT_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \
C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp], accumulate); \
}
INIT_RP(0);
INIT_RP(1);
INIT_RP(2);
INIT_RP(3);
#undef INIT_RP
// Main compute loop
for (int32_t ki = 0; ki < K_iters; ++ki) {
bfloat16x8_t b0, b1, b2, b3;
if constexpr (phase == AttentionGemmPhase::PV) {
b0 = vld1q_bf16(b_ptr[0] + ki * V_INNER_STRIDE);
b1 = vld1q_bf16(b_ptr[1] + ki * V_INNER_STRIDE);
b2 = vld1q_bf16(b_ptr[2] + ki * V_INNER_STRIDE);
b3 = vld1q_bf16(b_ptr[3] + ki * V_INNER_STRIDE);
} else {
const bfloat16_t* b_base = B_al + ki * b_stride;
b0 = vld1q_bf16(b_base + 0 * V_INNER_STRIDE);
b1 = vld1q_bf16(b_base + 1 * V_INNER_STRIDE);
b2 = vld1q_bf16(b_base + 2 * V_INNER_STRIDE);
b3 = vld1q_bf16(b_base + 3 * V_INNER_STRIDE);
}
#define COMPUTE_RP(rp) \
if constexpr (RP > rp) { \
bfloat16x8_t a = vld1q_bf16(a_ptr[rp] + ki * PACK_ELEMENTS_PER_K_CHUNK); \
BFMMLA_COMPUTE_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], a, b0, \
b1, b2, b3); \
}
COMPUTE_RP(0);
COMPUTE_RP(1);
COMPUTE_RP(2);
COMPUTE_RP(3);
#undef COMPUTE_RP
}
// K tail for runtime PV: fallback path
if constexpr (runtime_k) {
if (K_tail > 0) {
const int32_t tail_offset = K_iters * V_INNER_STRIDE;
const int32_t a_tail_offset = K_iters * PACK_ELEMENTS_PER_K_CHUNK;
for (int32_t kt = 0; kt < K_tail; ++kt) {
float32x4_t b_vecs[4];
for (int32_t p = 0; p < 4; ++p) {
const bfloat16_t* bp = b_ptr[p] + tail_offset + kt * TILE_COLS;
const float b0 = vcvtah_f32_bf16(bp[0]);
const float b1 = vcvtah_f32_bf16(bp[1]);
const float32x2_t b_pair = vset_lane_f32(b1, vdup_n_f32(b0), 1);
b_vecs[p] = vcombine_f32(b_pair, b_pair);
}
#define TAIL_RP(rp) \
if constexpr (RP > rp) { \
const bfloat16_t* ap = A_packed_rp[rp] + a_tail_offset; \
float a_row0 = vcvtah_f32_bf16(ap[kt]); \
float a_row1 = \
(m_rows_rp[rp] == 2) ? vcvtah_f32_bf16(ap[kt + TILE_K]) : 0.0f; \
const float32x4_t a_vec = \
vcombine_f32(vdup_n_f32(a_row0), vdup_n_f32(a_row1)); \
for (int32_t p = 0; p < 4; ++p) { \
acc[rp][p] = vmlaq_f32(acc[rp][p], a_vec, b_vecs[p]); \
} \
}
TAIL_RP(0);
TAIL_RP(1);
TAIL_RP(2);
TAIL_RP(3);
#undef TAIL_RP
}
}
}
// Store results
#define STORE_RP(rp) \
if constexpr (RP > rp) { \
STORE_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \
C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp]); \
}
STORE_RP(0);
STORE_RP(1);
STORE_RP(2);
STORE_RP(3);
#undef STORE_RP
}
// Meso-kernel: packs a small MBxK slice of A, then tiles over N and calls the
// micro-kernel for each OUTPUT_COLS_PER_BLOCK chunk. K_static < 0 enables
// runtime K (PV only).
template <int32_t MB, int32_t N, int32_t K_static, AttentionGemmPhase phase>
FORCE_INLINE void gemm_packA_compute_MB_xN(
const c10::BFloat16* __restrict A, const c10::BFloat16* __restrict B,
float* __restrict C, int32_t K_runtime, int64_t lda, int64_t ldc,
int64_t b_layout_stride, int64_t b_reduction_stride, bool accumulate) {
static_assert(MB >= 1 && MB <= 8, "MB must be in [1,8]");
static_assert(N % OUTPUT_COLS_PER_BLOCK == 0,
"N must be a multiple of OUTPUT_COLS_PER_BLOCK");
static_assert(K_static < 0 || K_static % TILE_K == 0,
"K must be divisible by TILE_K");
static_assert(K_static >= 0 || phase == AttentionGemmPhase::PV,
"Runtime K only supported for PV");
constexpr bool runtime_k = (K_static < 0);
const int32_t K_val = runtime_k ? K_runtime : K_static;
// Keep small packs on-stack to avoid heap churn
constexpr int32_t STACK_PACK_STRIDE =
(1024 / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK;
constexpr int32_t ROW_PAIRS = (MB + 1) / TILE_ROWS;
const int32_t pack_stride =
runtime_k ? ((K_val + TILE_K - 1) / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK
: (K_static / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK;
alignas(64) c10::BFloat16 A_packed_stack[ROW_PAIRS * STACK_PACK_STRIDE];
std::vector<c10::BFloat16> A_packed_heap;
c10::BFloat16* A_packed =
(pack_stride <= STACK_PACK_STRIDE)
? A_packed_stack
: (A_packed_heap.resize(ROW_PAIRS * pack_stride),
A_packed_heap.data());
for (int32_t rp = 0; rp < ROW_PAIRS; ++rp) {
const int32_t m = rp * TILE_ROWS;
const int32_t m_rows = (m + 1 < MB) ? TILE_ROWS : 1;
const c10::BFloat16* A0 = A + m * lda;
const c10::BFloat16* A1 = (m_rows == TILE_ROWS) ? (A + (m + 1) * lda) : A0;
reshape_Q_2xK_for_bfmmla(A0, A1, A_packed + rp * pack_stride, K_val);
}
for (int32_t n = 0; n < N; n += OUTPUT_COLS_PER_BLOCK) {
const c10::BFloat16* B_blk_c10 =
(phase == AttentionGemmPhase::PV)
? (B + (n / TILE_COLS) * b_layout_stride)
: (B + (n / OUTPUT_COLS_PER_BLOCK) * b_layout_stride);
const bfloat16_t* B_blk = reinterpret_cast<const bfloat16_t*>(B_blk_c10);
// Process row-pairs in groups of 4, 2, then 1
int32_t row_pair_idx = 0;
#define PROCESS_RP_GROUP(group_size) \
for (; row_pair_idx + (group_size - 1) < ROW_PAIRS; \
row_pair_idx += group_size) { \
const bfloat16_t* Ap[group_size]; \
int32_t mr[group_size]; \
for (int32_t i = 0; i < group_size; ++i) { \
Ap[i] = reinterpret_cast<const bfloat16_t*>( \
A_packed + (row_pair_idx + i) * pack_stride); \
mr[i] = (((row_pair_idx + i) * TILE_ROWS + 1) < MB) ? TILE_ROWS : 1; \
} \
float* C_blk = C + (row_pair_idx * TILE_ROWS) * ldc + n; \
if constexpr (runtime_k) { \
gemm_rowpairs_x8_bfmmla_neon<group_size, -1, phase>( \
Ap, mr, B_blk, C_blk, ldc, accumulate, b_layout_stride, K_val); \
} else { \
gemm_rowpairs_x8_bfmmla_neon<group_size, K_static, phase>( \
Ap, mr, B_blk, C_blk, ldc, accumulate, \
(phase == AttentionGemmPhase::PV) ? b_layout_stride \
: b_reduction_stride); \
} \
}
PROCESS_RP_GROUP(4);
PROCESS_RP_GROUP(2);
PROCESS_RP_GROUP(1);
#undef PROCESS_RP_GROUP
}
}
// Macro-kernel: iterates over M in MB={8,4,2,1} chunks.
// Supports compile-time K specialization when K >= 0; otherwise uses runtime K
// (runtime K path is only supported for PV).
template <AttentionGemmPhase phase, int32_t N, int32_t K = -1>
FORCE_INLINE void gemm_macro_neon_bfmmla(
const c10::BFloat16* __restrict A, const c10::BFloat16* __restrict B,
float* __restrict C, int32_t M, int32_t K_runtime, int64_t lda, int64_t ldc,
int64_t b_layout_stride, int64_t b_reduction_stride, bool accumulate) {
static_assert(N % OUTPUT_COLS_PER_BLOCK == 0,
"N must be a multiple of OUTPUT_COLS_PER_BLOCK");
if constexpr (K >= 0) {
static_assert(K % TILE_K == 0, "K must be divisible by TILE_K");
for (int32_t m = 0; m < M;) {
const int32_t rem = M - m;
const c10::BFloat16* A_blk = A + m * lda;
float* C_blk = C + m * ldc;
#define DISPATCH_MB(mb) \
gemm_packA_compute_MB_xN<mb, N, K, phase>(A_blk, B, C_blk, 0, lda, ldc, \
b_layout_stride, \
b_reduction_stride, accumulate)
if (rem >= 8) {
DISPATCH_MB(8);
m += 8;
} else if (rem >= 4) {
DISPATCH_MB(4);
m += 4;
} else if (rem >= 2) {
DISPATCH_MB(2);
m += 2;
} else {
DISPATCH_MB(1);
m += 1;
}
#undef DISPATCH_MB
}
} else {
static_assert(phase == AttentionGemmPhase::PV,
"Runtime K specialization only supported for PV.");
const int32_t K_val = K_runtime;
for (int32_t m = 0; m < M;) {
const int32_t rem = M - m;
const c10::BFloat16* A_blk = A + m * lda;
float* C_blk = C + m * ldc;
#define DISPATCH_MB_RUNTIME(mb) \
gemm_packA_compute_MB_xN<mb, N, -1, phase>(A_blk, B, C_blk, K_val, lda, ldc, \
b_layout_stride, \
b_reduction_stride, accumulate)
if (rem >= 8) {
DISPATCH_MB_RUNTIME(8);
m += 8;
} else if (rem >= 4) {
DISPATCH_MB_RUNTIME(4);
m += 4;
} else if (rem >= 2) {
DISPATCH_MB_RUNTIME(2);
m += 2;
} else {
DISPATCH_MB_RUNTIME(1);
m += 1;
}
#undef DISPATCH_MB_RUNTIME
}
}
}
#undef INIT_ACC_ROWPAIR_4
#undef STORE_ACC_ROWPAIR_4
#undef BFMMLA_COMPUTE_4
} // namespace
// TileGemm Adapter for Attention
template <typename kv_cache_t, int32_t BlockTokens, int32_t HeadDim>
class TileGemmNEONBFMMLA {
public:
template <AttentionGemmPhase phase, int32_t head_dim_ct>
FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
[[maybe_unused]] const int64_t ldb,
const int64_t ldc,
[[maybe_unused]] const int32_t block_size,
[[maybe_unused]] const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(BlockTokens % OUTPUT_COLS_PER_BLOCK == 0);
// BFMMLA kernels require compile-time head_dim; keep head_dim_ct only for
// API parity with other tile_gemm implementations.
if constexpr (head_dim_ct >= 0) {
static_assert(head_dim_ct == HeadDim,
"BFMMLA expects head_dim_ct to match HeadDim; PV passes "
"-1 for API parity.");
}
if constexpr (phase == AttentionGemmPhase::QK) {
const int64_t b_reduction_stride = K_INNER_STRIDE;
const int64_t b_token_block_stride = (HeadDim / TILE_K) * K_INNER_STRIDE;
gemm_macro_neon_bfmmla<AttentionGemmPhase::QK, BlockTokens, HeadDim>(
reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
m_size, 0, lda, ldc, b_token_block_stride, b_reduction_stride,
accum_c);
} else {
const int64_t b_pair_stride =
(block_size / V_TOKENS_PER_ROW_BLOCK) * V_INNER_STRIDE;
// PV gemm with runtime K specialization
switch (dynamic_k_size) {
case 32:
gemm_macro_neon_bfmmla<AttentionGemmPhase::PV, HeadDim, 32>(
reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
m_size, 32, lda, ldc, b_pair_stride, 0, accum_c);
break;
case 128:
gemm_macro_neon_bfmmla<AttentionGemmPhase::PV, HeadDim, 128>(
reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
m_size, 128, lda, ldc, b_pair_stride, 0, accum_c);
break;
case 256:
gemm_macro_neon_bfmmla<AttentionGemmPhase::PV, HeadDim, 256>(
reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
m_size, 256, lda, ldc, b_pair_stride, 0, accum_c);
break;
default:
gemm_macro_neon_bfmmla<AttentionGemmPhase::PV, HeadDim>(
reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
m_size, dynamic_k_size, lda, ldc, b_pair_stride, 0, accum_c);
break;
}
}
}
};
// Shared ASIMD BFMMLA implementation (BF16 only). The block size alignment and
// ISA tag are template parameters so we can reuse the same kernels for
// different NEON configurations.
template <int64_t block_size_alignment, ISA isa_type, int64_t head_dim>
class AttentionImplNEONBFMMLA {
public:
using query_t = c10::BFloat16;
using q_buffer_t = c10::BFloat16;
using kv_cache_t = c10::BFloat16;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = c10::BFloat16;
static constexpr int64_t BlockSizeAlignment = block_size_alignment;
// HeadDimAlignment equals head_dim so that the PV phase processes
// the full head dimension in a single gemm call.
static constexpr int64_t HeadDimAlignment = head_dim;
static constexpr int64_t MaxQHeadNumPerIteration = 16;
static constexpr int64_t HeadDim = head_dim;
static constexpr ISA ISAType = isa_type;
static constexpr bool scale_on_logits = false;
static_assert(HeadDim % OUTPUT_COLS_PER_BLOCK == 0);
static_assert(BlockSizeAlignment % OUTPUT_COLS_PER_BLOCK == 0);
static_assert(HeadDim % TILE_K == 0, "HeadDim must be a multiple of TILE_K");
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<
TileGemmNEONBFMMLA<kv_cache_t, static_cast<int32_t>(BlockSizeAlignment),
static_cast<int32_t>(HeadDim)>>
attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// Key cache stride per token group (TokenColumn layout; QK)
static constexpr int64_t k_cache_token_group_stride(
[[maybe_unused]] const int32_t block_size) {
static_assert(BlockSizeAlignment % K_TOKENS_PER_GROUP == 0);
return (BlockSizeAlignment / K_TOKENS_PER_GROUP) *
((head_dim / TILE_K) * K_INNER_STRIDE);
}
// Value cache stride per token group (TokenRow layout; PV)
static constexpr int64_t v_cache_token_group_stride(
[[maybe_unused]] const int32_t block_size) {
static_assert(BlockSizeAlignment % V_TOKENS_PER_ROW_BLOCK == 0);
return (BlockSizeAlignment / V_TOKENS_PER_ROW_BLOCK) * V_INNER_STRIDE;
}
// The stride to move to the "next" head_dim group
// is the full V cache size per head, since HeadDimAlignment == head_dim.
// Hence, the stride is not used in this case
static constexpr int64_t v_cache_head_group_stride(
[[maybe_unused]] const int32_t block_size) {
return head_dim * block_size;
}
// Convert Q heads to BF16 and apply scale factor using native BF16 intrinsics
static void copy_q_heads_tile(c10::BFloat16* __restrict__ src,
c10::BFloat16* __restrict__ q_buffer,
const int32_t q_num,
const int32_t q_heads_per_kv,
const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
constexpr int32_t dim = static_cast<int32_t>(head_dim);
const float32x4_t scale_vec = vdupq_n_f32(scale);
for (int32_t qi = 0; qi < q_num; ++qi) {
for (int32_t hi = 0; hi < q_heads_per_kv; ++hi) {
c10::BFloat16* __restrict__ curr_q =
src + qi * q_num_stride + hi * q_head_stride;
c10::BFloat16* __restrict__ dst =
q_buffer + qi * q_heads_per_kv * head_dim + hi * head_dim;
for (int32_t i = 0; i < dim; i += OUTPUT_COLS_PER_BLOCK) {
bfloat16x8_t in8 =
vld1q_bf16(reinterpret_cast<const bfloat16_t*>(curr_q + i));
float32x4_t lo = vmulq_f32(vcvtq_low_f32_bf16(in8), scale_vec);
float32x4_t hi = vmulq_f32(vcvtq_high_f32_bf16(in8), scale_vec);
bfloat16x4_t lo_b = vcvt_bf16_f32(lo);
bfloat16x4_t hi_b = vcvt_bf16_f32(hi);
bfloat16x8_t out = vcombine_bf16(lo_b, hi_b);
vst1q_bf16(reinterpret_cast<bfloat16_t*>(dst + i), out);
}
}
}
}
public:
// Reshape and cache K/V into BFMMLA-optimized layouts
// K cache:
// [block_size/K_TOKENS_PER_GROUP][head_dim/TILE_K][K_INNER_STRIDE]
// - TokenColumn
// V cache:
// [head_dim/TILE_COLS][block_size/V_TOKENS_PER_ROW_BLOCK][V_INNER_STRIDE]
// - TokenRows
static void reshape_and_cache(
const c10::BFloat16* __restrict__ key,
const c10::BFloat16* __restrict__ value,
c10::BFloat16* __restrict__ key_cache,
c10::BFloat16* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride,
[[maybe_unused]] const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size,
[[maybe_unused]] const int64_t block_size_stride) {
const int64_t k_block_stride = (head_dim / TILE_K) * K_INNER_STRIDE;
const int64_t v_pair_stride =
(block_size / V_TOKENS_PER_ROW_BLOCK) * V_INNER_STRIDE;
#pragma omp parallel for
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) continue;
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
// Key cache: TokenColumn QK
{
const c10::BFloat16* __restrict key_src =
key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
c10::BFloat16* __restrict key_base = key_cache +
block_idx * num_blocks_stride +
head_idx * cache_head_num_stride;
const int64_t block_in_block = block_offset / K_TOKENS_PER_GROUP;
const int64_t pair_in_block =
(block_offset % K_TOKENS_PER_GROUP) / TILE_COLS;
const int64_t lane_base = (block_offset & 1) ? TILE_K : 0;
c10::BFloat16* __restrict block_base =
key_base + block_in_block * k_block_stride;
for (int64_t hd4 = 0; hd4 < head_dim / TILE_K; ++hd4) {
uint16_t* dst_u16 = reinterpret_cast<uint16_t*>(
block_base + hd4 * K_INNER_STRIDE +
pair_in_block * V_INNER_STRIDE + lane_base);
const uint16_t* src_u16 =
reinterpret_cast<const uint16_t*>(key_src + hd4 * TILE_K);
vst1_u16(dst_u16, vld1_u16(src_u16));
}
}
// Value cache: TokenRow PV
{
const c10::BFloat16* __restrict value_src =
value + token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
c10::BFloat16* __restrict value_base =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride;
const int64_t row_block = block_offset / V_TOKENS_PER_ROW_BLOCK;
const int64_t lane = block_offset & (V_TOKENS_PER_ROW_BLOCK - 1);
c10::BFloat16* __restrict row_block_base =
value_base + row_block * V_INNER_STRIDE;
for (int64_t hd2 = 0; hd2 < head_dim / TILE_COLS; ++hd2) {
c10::BFloat16* __restrict dst_val =
row_block_base + hd2 * v_pair_stride;
const uint16_t* src_u16 =
reinterpret_cast<const uint16_t*>(value_src);
uint16_t* dst_u16 = reinterpret_cast<uint16_t*>(dst_val);
dst_u16[lane] = src_u16[hd2 * TILE_COLS + 0];
dst_u16[lane + V_TOKENS_PER_ROW_BLOCK] =
src_u16[hd2 * TILE_COLS + 1];
}
}
}
}
}
};
} // namespace cpu_attention
#endif  // CPU_ATTN_NEON_BFMMLA_HPP
#ifndef CPU_ATTN_VEC_HPP
#define CPU_ATTN_VEC_HPP
#include "cpu_attn_impl.hpp"
namespace cpu_attention {
namespace {
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [K, 32]
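// Per k step the micro kernel broadcasts one A scalar per active row and does
// two FMAs against the converted B vectors:
//   C[m, 0:16]  += A[m, k] * B[k, 0:16]
//   C[m, 16:32] += A[m, k] * B[k, 16:32]
// with B widened from kv_cache_t to FP32 on load.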
template <typename kv_cache_t>
class TileGemm82 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M && M <= 8);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
float* __restrict__ curr_c_0 = c_tile;
float* __restrict__ curr_c_1 = c_tile + 16;
vec_op::FP32Vec16 c_regs[M * 2];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
float* __restrict__ curr_m_c_1 = curr_c_1;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
// update
curr_m_c_0 += ldc;
curr_m_c_1 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
load_vec_t b_1_reg(curr_b_1);
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
curr_b_1 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2].save(curr_c_0);
c_regs[i * 2 + 1].save(curr_c_1);
// update
curr_c_0 += ldc;
curr_c_1 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
32; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
32; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 8;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm82<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of the K cache when moving to the next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of the V cache when moving to the next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of the V cache when moving to the next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
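// Worked example of the strides above (illustrative): with the row-major
// [head_dim, block_size] K block, stepping to the next group of
// BlockSizeAlignment (32) tokens moves 32 columns to the right, i.e. an offset
// of 32 elements within each head-dim row. With the row-major
// [block_size, head_dim] V block, the same token step skips 32 full rows
// (32 * head_dim elements), while stepping to the next HeadDimAlignment (32)
// head dims stays within a row and advances by 32 elements.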
// Copy q to q_buffer and cast it to fp32
static void copy_q_heads_tile(
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
float* __restrict__ q_buffer, const int32_t q_num,
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
static_assert(head_dim % 16 == 0);
constexpr int32_t unroll_size = head_dim / 16;
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
vec_op::FP32Vec16 scale_vec(scale);
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
scalar_t* __restrict__ curr_q =
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
float* __restrict__ curr_q_buffer =
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
q_head_idx * head_dim;
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
load_vec_t vec(curr_q);
vec_op::FP32Vec16 fp32_vec(vec);
fp32_vec = fp32_vec * scale_vec;
fp32_vec.save(curr_q_buffer);
curr_q += 16;
curr_q_buffer += 16;
});
}
}
}
// reshape K as column-major and V as row-major
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) {
// skip
continue;
}
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
// Write Key as column-major
const scalar_t* key_start_ptr = key +
token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_cache_start_ptr =
key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
#pragma GCC unroll 8
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_cache_start_ptr[j] = key_start_ptr[i];
}
}
{
// Write Value as row-major
const scalar_t* value_start_ptr = value +
token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* value_cache_start_ptr =
value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset * head_dim;
std::memcpy(value_cache_start_ptr, value_start_ptr,
sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#endif  // CPU_ATTN_VEC_HPP
#ifndef CPU_ATTN_VEC16_HPP
#define CPU_ATTN_VEC16_HPP
#include "cpu_attn_vec.hpp"
namespace cpu_attention {
namespace {
// 16-1-16 pattern, 16 regs for A, 1 reg for B, 16 regs for C,
// [16, K] @ [K, 16]
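// One k step here follows the same scheme as TileGemm82, but with a single
// 16-wide B register per step (illustrative only):
//   b = B[k, 0:16]
//   for each row i < M: C[i, 0:16] += broadcast(A[i, k]) * b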
template <typename kv_cache_t>
class TileGemm161 {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
switch (m_size) {
case 1:
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 2:
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 3:
case 4:
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 5:
case 6:
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 7:
case 8:
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 9:
case 10:
case 11:
case 12:
gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
case 13:
case 14:
case 15:
case 16:
gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
dynamic_k_size, accum_c);
break;
}
}
template <int32_t M>
static void gemm_micro(float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size, const int32_t dynamic_k_size,
const bool accum_c) {
static_assert(0 < M && M <= 16);
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
kv_cache_t* __restrict__ curr_b_0 = b_tile;
float* __restrict__ curr_c_0 = c_tile;
vec_op::FP32Vec16 c_regs[M];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
// update
curr_m_c_0 += ldc;
});
}
float* __restrict__ curr_a = a_tile;
for (int32_t k = 0; k < dynamic_k_size; ++k) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
float* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
float v = *curr_m_a;
vec_op::FP32Vec16 a_reg(v);
c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += ldb;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i].save(curr_c_0);
// update
curr_c_0 += ldc;
});
}
};
} // namespace
// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment =
16; // KV token num unit of QK and PV phases
constexpr static int64_t HeadDimAlignment =
16; // headdim num unit of PV phase
constexpr static int64_t MaxQHeadNumPerIteration = 16;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VEC16;
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
public:
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemm161<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// k_cache_token_group_stride: stride of the K cache when moving to the next
// BlockSizeAlignment tokens in a block
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
// block_size], row-major
}
// v_cache_token_group_stride: stride of the V cache when moving to the next
// BlockSizeAlignment tokens in a block
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
// head_dim], row-major
}
// v_cache_head_group_stride: stride of the V cache when moving to the next
// HeadDimAlignment head dims in a block
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
// row-major
}
};
} // namespace cpu_attention
#endif  // CPU_ATTN_VEC16_HPP
#ifndef CPU_ATTN_VXE_HPP
#define CPU_ATTN_VXE_HPP
#include "cpu_attn_impl.hpp"
#include <vecintrin.h>
#include <type_traits>
namespace cpu_attention {
namespace {
// s390x Vector = 16 bytes (128 bits)
#define BLOCK_SIZE_ALIGNMENT 32
#define HEAD_SIZE_ALIGNMENT 32
#define MAX_Q_HEAD_NUM_PER_ITER 16
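// Note (illustrative): one 128-bit VXE register holds 4 floats, so the
// 8-column micro tile below keeps two accumulator vectors per row of C and
// loads each 8-wide row of B as a pair of vectors (b0, b1).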
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, __vector float& b0,
__vector float& b1);
// [1] Float Specialization
template <>
FORCE_INLINE void load_row8_B_as_f32<float>(const float* p, __vector float& b0,
__vector float& b1) {
// Explicitly cast to long long for offset, and float* for pointer
b0 = vec_xl((long long)0, const_cast<float*>(p));
b1 = vec_xl((long long)0, const_cast<float*>(p + 4));
}
// [2] BFloat16 Specialization (Big Endian Fix)
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
__vector float& b0,
__vector float& b1) {
// 1. Load 8 BF16s (16 bytes) into one vector
// Explicit cast to unsigned short* for vec_xl to return vector unsigned short
__vector unsigned short raw = vec_xl((long long)0, (unsigned short*)p);
// 2. Prepare Zero vector
__vector unsigned short zeros = vec_splat_u16(0);
// 3. Merge High/Low to expand BF16 -> Float32
// On Big Endian, a float is [BF16_bits | 16_zero_bits]
b0 = (__vector float)vec_mergeh(raw, zeros);
b1 = (__vector float)vec_mergel(raw, zeros);
}
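// Worked example of the widening above (big-endian, illustrative): the BF16
// bit pattern 0x3F80 (1.0) merged with a zero halfword yields the 32-bit word
// 0x3F800000, which is exactly 1.0f, so vec_mergeh/vec_mergel expand 8 BF16
// values into two float vectors without any arithmetic.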
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::Half>(const c10::Half* p,
__vector float& b0,
__vector float& b1) {
alignas(16) float tmp[8];
// Manual unroll / conversion
tmp[0] = static_cast<float>(p[0]);
tmp[1] = static_cast<float>(p[1]);
tmp[2] = static_cast<float>(p[2]);
tmp[3] = static_cast<float>(p[3]);
tmp[4] = static_cast<float>(p[4]);
tmp[5] = static_cast<float>(p[5]);
tmp[6] = static_cast<float>(p[6]);
tmp[7] = static_cast<float>(p[7]);
// Explicit arguments for intrinsic: (long long offset, float* ptr)
b0 = vec_xl((long long)0, (float*)tmp);
b1 = vec_xl((long long)0, (float*)(tmp + 4));
}
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_s390x_Mx8_Ku4(
const float* __restrict A, // [M x K]
const kv_cache_t* __restrict B, // [K x 8]
float* __restrict C, // [M x 8]
int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
static_assert(1 <= M && M <= 8, "M must be in [1,8]");
// Helper macros to unroll codegen for M rows
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
#define IF_M(i) if constexpr (M > (i))
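// Expansion sketch (illustrative): ROWS_APPLY(DECL_A) emits
//   const float* a0 = A + 0 * lda; ... const float* a7 = A + 7 * lda;
// unconditionally, while IF_M(i) wraps each row's work in
// `if constexpr (M > i)`, so computation for rows beyond M is compiled out.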
// 1. Define A pointers
#define DECL_A(i) const float* a##i = A + (i) * lda;
ROWS_APPLY(DECL_A)
#undef DECL_A
// 2. Define Accumulators (2 vectors covers 8 columns)
#define DECL_ACC(i) __vector float acc##i##_0, acc##i##_1;
ROWS_APPLY(DECL_ACC)
#undef DECL_ACC
// 3. Initialize Accumulators (Load C or Zero)
#define INIT_ACC(i) \
IF_M(i) { \
if (accumulate) { \
acc##i##_0 = \
vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 0)); \
acc##i##_1 = \
vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 4)); \
} else { \
acc##i##_0 = vec_splats(0.0f); \
acc##i##_1 = vec_splats(0.0f); \
} \
}
ROWS_APPLY(INIT_ACC)
#undef INIT_ACC
int32_t k = 0;
for (; k + 3 < K; k += 4) {
// Load 4 values of A for each Row M: A[k...k+3]
#define LOAD_A4(i) \
__vector float a##i##v; \
IF_M(i) a##i##v = vec_xl((long long)0, const_cast<float*>(a##i + k));
ROWS_APPLY(LOAD_A4)
#undef LOAD_A4
// Helper: FMA for specific lane L of A
// s390x: vec_madd(b, vec_splat(a, lane), acc)
#define FMAS_LANE(i, aiv, L) \
IF_M(i) { \
__vector float a_broad = vec_splat(aiv, L); \
acc##i##_0 = vec_madd(b0, a_broad, acc##i##_0); \
acc##i##_1 = vec_madd(b1, a_broad, acc##i##_1); \
}
// Unroll K=0..3
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb, b0, b1);
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
ROWS_APPLY(STEP_K0)
#undef STEP_K0
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb, b0, b1);
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
ROWS_APPLY(STEP_K1)
#undef STEP_K1
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb, b0, b1);
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
ROWS_APPLY(STEP_K2)
#undef STEP_K2
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb, b0, b1);
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
ROWS_APPLY(STEP_K3)
#undef STEP_K3
}
#undef FMAS_LANE
}
for (; k < K; ++k) {
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb, b0, b1);
#define TAIL_ROW(i) \
IF_M(i) { \
__vector float ai = vec_splats(*(a##i + k)); \
acc##i##_0 = vec_madd(b0, ai, acc##i##_0); \
acc##i##_1 = vec_madd(b1, ai, acc##i##_1); \
}
ROWS_APPLY(TAIL_ROW)
#undef TAIL_ROW
}
#define STORE_ROW(i) \
IF_M(i) { \
vec_xst(acc##i##_0, 0, C + (i) * ldc + 0); \
vec_xst(acc##i##_1, 0, C + (i) * ldc + 4); \
}
ROWS_APPLY(STORE_ROW)
#undef STORE_ROW
#undef ROWS_APPLY
#undef IF_M
}
template <int32_t N, typename kv_cache_t>
FORCE_INLINE void gemm_macro_s390x_Mx8_Ku4(const float* __restrict A,
const kv_cache_t* __restrict B,
float* __restrict C, int32_t M,
int32_t K, int64_t lda, int64_t ldb,
int64_t ldc, bool accumulate) {
static_assert(N % 8 == 0, "N must be a multiple of 8");
for (int32_t m = 0; m < M;) {
int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
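// e.g. M == 7 is processed as micro tiles of 4, then 2, then 1 rows.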
const float* Ab = A + m * lda;
float* Cb = C + m * ldc;
for (int32_t n = 0; n < N; n += 8) {
const kv_cache_t* Bn = B + n;
float* Cn = Cb + n;
switch (mb) {
case 8:
gemm_micro_s390x_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
case 4:
gemm_micro_s390x_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
case 2:
gemm_micro_s390x_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
default:
gemm_micro_s390x_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
}
}
m += mb;
}
}
template <typename kv_cache_t>
class TileGemmS390X {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
if constexpr (phase == AttentionGemmPhase::QK) {
gemm_macro_s390x_Mx8_Ku4<BLOCK_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
} else {
gemm_macro_s390x_Mx8_Ku4<HEAD_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
accum_c);
}
}
};
} // namespace
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VXE, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment = BLOCK_SIZE_ALIGNMENT;
constexpr static int64_t HeadDimAlignment = HEAD_SIZE_ALIGNMENT;
constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VXE;
constexpr static bool scale_on_logits =
false; // Scale is applied to Q during copy
public:
AttentionImpl() {}
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemmS390X<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// Strides for Memory Layout
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // [head_dim, block_size] layout
}
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment;
}
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment;
}
static void copy_q_heads_tile(scalar_t* __restrict__ src,
float* __restrict__ q_buffer,
const int32_t q_num,
const int32_t q_heads_per_kv,
const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
__vector float scale_vec = vec_splats(scale);
constexpr bool is_bf16 = std::is_same<scalar_t, c10::BFloat16>::value;
// Process 8 elements at a time (32 bytes of float output)
for (int32_t i = 0; i < q_num; ++i) {
for (int32_t h = 0; h < q_heads_per_kv; ++h) {
scalar_t* curr_src = src + i * q_num_stride + h * q_head_stride;
float* curr_dst =
q_buffer + i * q_heads_per_kv * head_dim + h * head_dim;
int32_t d = 0;
for (; d <= head_dim - 8; d += 8) {
if constexpr (is_bf16) {
__vector float v0, v1;
// Reuse our Big-Endian-Safe loader
load_row8_B_as_f32<scalar_t>(curr_src + d, v0, v1);
v0 = vec_mul(v0, scale_vec);
v1 = vec_mul(v1, scale_vec);
vec_xst(v0, 0, curr_dst + d);
vec_xst(v1, 0, curr_dst + d + 4);
} else {
__vector float v0 = vec_xl((long long)0, (float*)curr_src + d);
__vector float v1 = vec_xl((long long)0, (float*)curr_src + d + 4);
v0 = vec_mul(v0, scale_vec);
v1 = vec_mul(v1, scale_vec);
vec_xst(v0, 0, curr_dst + d);
vec_xst(v1, 0, curr_dst + d + 4);
}
}
for (; d < head_dim; ++d) {
float val = static_cast<float>(curr_src[d]);
curr_dst[d] = val * scale;
}
}
}
}
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) continue;
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
const scalar_t* key_src = key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_dst = key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_dst[j] = key_src[i];
}
}
{
const scalar_t* val_src = value + token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* val_dst = value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride +
block_offset * head_dim;
std::memcpy(val_dst, val_src, sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif  // CPU_ATTN_VXE_HPP
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
#include "cpu/cpu_arch_macros.h"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_utils::ISA::AMX: { \
using gemm_t = cpu_micro_gemm::MicroGemm<cpu_utils::ISA::AMX, scalar_t>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_utils::ISA::AMX:
#endif
#define CPU_ISA_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
case cpu_utils::ISA::VEC: { \
using gemm_t = \
cpu_micro_gemm::MicroGemm<cpu_utils::ISA::VEC, scalar_t>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU ISA type."); \
} \
} \
}()
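// Usage sketch (illustrative): the dispatcher is invoked as
//   CPU_ISA_DISPATCH_IMPL(isa_type, [&]() { /* body using gemm_t */ });
// and expands to a switch that defines gemm_t as the MicroGemm specialization
// for the selected ISA before running the lambda, as done in
// prepack_moe_weight and cpu_fused_moe below.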
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
}
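// Gated activations implemented below (descriptive sketch):
//   silu_and_mul:      out[d] = up[d] * gate[d] * sigmoid(gate[d]),
//                      with gate = input[0:dim], up = input[dim:2*dim]
//   swigluoai_and_mul: out[d] = (up[d] + 1) * gate[d] * sigmoid(1.702 * gate[d]),
//                      with gate/up interleaved as even/odd columns and
//                      clamped (gate <= 7, -7 <= up <= 7) before use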
template <typename scalar_t>
void swigluoai_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride,
const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
// For GPT-OSS interleaved gate-up weights
alignas(64) static int32_t index[16] = {0, 2, 4, 6, 8, 10, 12, 14,
16, 18, 20, 22, 24, 26, 28, 30};
vec_op::INT32Vec16 index_vec(index);
vec_op::FP32Vec16 gate_up_max_vec(7.0);
vec_op::FP32Vec16 up_min_vec(-7.0);
vec_op::FP32Vec16 alpha_vec(1.702);
vec_op::FP32Vec16 one_vec(1.0);
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < n_size; n += 32) {
vec_op::FP32Vec16 gate_vec(input + n, index_vec);
vec_op::FP32Vec16 up_vec(input + n + 1, index_vec);
gate_vec = gate_vec.min(gate_up_max_vec);
up_vec = up_vec.clamp(up_min_vec, gate_up_max_vec);
auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec * alpha_vec));
auto glu = gate_vec * sigmoid_vec;
auto gated_output_fp32 = (one_vec + up_vec) * glu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n / 2);
}
input += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec));
auto silu = gate_vec * sigmoid_vec;
auto gated_output_fp32 = up_vec * silu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
scalar_t* __restrict__ output,
const int32_t m, const int32_t n,
const int32_t input_stride,
const int32_t output_stride) {
switch (act) {
case FusedMOEAct::SwigluOAIAndMul:
swigluoai_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}
}
template <typename scalar_t, typename gemm_t>
void prepack_moe_weight_impl(scalar_t* __restrict__ weight_ptr,
scalar_t* __restrict__ packed_weight_ptr,
const int32_t expert_num,
const int32_t output_size,
const int32_t input_size,
const int64_t expert_stride) {
#pragma omp parallel for
for (int32_t e_idx = 0; e_idx < expert_num; ++e_idx) {
gemm_t::pack_weight(weight_ptr + expert_stride * e_idx,
packed_weight_ptr + expert_stride * e_idx, output_size,
input_size);
}
}
template <typename scalar_t, typename w_t, typename gemm_t>
void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
w_t* __restrict__ w13, w_t* __restrict__ w2,
w_t* __restrict__ w13_bias, w_t* __restrict__ w2_bias,
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_id, FusedMOEAct act_type,
const int32_t token_num, const int32_t expert_num,
const int32_t topk_num, const int32_t input_size_13,
const int32_t output_size_13, const int32_t input_size_2,
const int32_t output_size_2, const bool skip_weighted) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
constexpr int32_t min_w13_n_tile_size = 2 * gemm_n_tile_size;
static_assert(gemm_n_tile_size % 16 == 0);
TORCH_CHECK_EQ(output_size_13 % min_w13_n_tile_size, 0);
TORCH_CHECK_EQ(output_size_2 % gemm_n_tile_size, 0);
TORCH_CHECK_EQ(output_size_13 / 2, input_size_2);
const int32_t thread_num = omp_get_max_threads();
const int32_t w13_input_buffer_size = cpu_utils::round_up<64>(
gemm_m_tile_size * input_size_13 * sizeof(scalar_t));
const int32_t w13_n_tile_size = [&]() {
const int64_t cache_size = cpu_utils::get_available_l2_size();
// input buffer + output buffer + weight
const int32_t n_size_cache_limit =
(cache_size - w13_input_buffer_size) /
(gemm_m_tile_size * sizeof(float) + input_size_13 * sizeof(scalar_t));
const int32_t n_size_thread_limit =
output_size_13 / std::max(1, thread_num / topk_num);
const int32_t n_size = cpu_utils::round_down<min_w13_n_tile_size>(
std::min(n_size_cache_limit, n_size_thread_limit));
return std::max(n_size, min_w13_n_tile_size);
}();
const int32_t w2_input_tile_size = cpu_utils::round_up<64>(
gemm_m_tile_size * input_size_2 * sizeof(scalar_t));
const int32_t w2_n_tile_size = [&]() {
const int64_t cache_size = cpu_utils::get_available_l2_size();
// input tile + weight
const int32_t n_size_cache_limit =
(cache_size - w2_input_tile_size) / (input_size_2 * sizeof(scalar_t));
const int32_t n_size_thread_limit =
output_size_2 / std::max(1, thread_num / topk_num);
const int32_t n_size = cpu_utils::round_down<gemm_n_tile_size>(
std::min(n_size_cache_limit, n_size_thread_limit));
return std::max(n_size, gemm_n_tile_size);
}();
// allocate buffers
int32_t common_buffer_offset = 0;
int32_t w13_thread_buffer_offset = 0;
int32_t ws_thread_buffer_offset = 0;
// common buffers
const int32_t token_num_per_group_buffer_size =
cpu_utils::round_up<64>(expert_num * sizeof(int32_t));
const int32_t token_num_per_group_buffer_offset = common_buffer_offset;
common_buffer_offset += token_num_per_group_buffer_size;
const int32_t cu_token_num_per_group_buffer_size =
cpu_utils::round_up<64>((expert_num + 1) * sizeof(int32_t));
const int32_t cu_token_num_per_group_buffer_offset = common_buffer_offset;
common_buffer_offset += cu_token_num_per_group_buffer_size;
const int32_t expand_token_id_buffer_size =
cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t));
const int32_t expand_token_id_buffer_offset = common_buffer_offset;
common_buffer_offset += expand_token_id_buffer_size;
const int32_t expand_token_id_index_buffer_size =
cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t));
const int32_t expand_token_id_index_buffer_offset = common_buffer_offset;
common_buffer_offset += expand_token_id_index_buffer_size;
const int32_t w13_gemm_output_buffer_size = cpu_utils::round_up<64>(
token_num * topk_num * (output_size_13 / 2) * sizeof(scalar_t));
const int32_t w13_gemm_output_buffer_offset = common_buffer_offset;
common_buffer_offset += w13_gemm_output_buffer_size;
const int32_t w2_gemm_output_buffer_size = cpu_utils::round_up<64>(
token_num * topk_num * output_size_2 * sizeof(float));
const int32_t w2_gemm_output_buffer_offset = common_buffer_offset;
common_buffer_offset += w2_gemm_output_buffer_size;
// w13 GEMM thread buffers
const int32_t w13_input_buffer_offset = w13_thread_buffer_offset;
w13_thread_buffer_offset += w13_input_buffer_size;
const int32_t w13_output_buffer_size = cpu_utils::round_up<64>(
gemm_m_tile_size * w13_n_tile_size * sizeof(float));
const int32_t w13_output_buffer_offset = w13_thread_buffer_offset;
w13_thread_buffer_offset += w13_output_buffer_size;
// Weighted sum thread buffer
const int32_t ws_output_buffer_size =
cpu_utils::round_up<64>(output_size_2 * sizeof(float));
const int32_t ws_output_buffer_offset = ws_thread_buffer_offset;
ws_thread_buffer_offset += ws_output_buffer_size;
const int32_t buffer_size =
common_buffer_offset +
std::max(w13_thread_buffer_offset, ws_thread_buffer_offset) * thread_num;
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size);
uint8_t* common_buffer_start =
cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<uint8_t>();
uint8_t* thread_buffer_start = common_buffer_start + common_buffer_offset;
int32_t* __restrict__ token_num_per_group_buffer = reinterpret_cast<int32_t*>(
common_buffer_start + token_num_per_group_buffer_offset);
int32_t* __restrict__ cu_token_num_per_group_buffer =
reinterpret_cast<int32_t*>(common_buffer_start +
cu_token_num_per_group_buffer_offset);
int32_t* __restrict__ expand_token_id_buffer = reinterpret_cast<int32_t*>(
common_buffer_start + expand_token_id_buffer_offset);
int32_t* __restrict__ expand_token_id_index_buffer =
reinterpret_cast<int32_t*>(common_buffer_start +
expand_token_id_index_buffer_offset);
// prepare token-expert mappings
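// Worked example (illustrative): token_num = 2, topk_num = 2, expert_num = 3,
// topk_id = {{0, 2}, {1, 0}} gives
//   token_num_per_group    = [2, 1, 1]
//   cu_token_num_per_group = [0, 2, 3, 4]    (after the in-place increments)
//   expand_token_id        = [0, 1, 1, 0]    (token ids grouped by expert)
//   expand_token_id_index  = [[0, 3], [2, 1]] (slot of each (token, k) pair)
// so cu_token_num_per_group[e] is the start of expert e's token group.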
{
std::memset(token_num_per_group_buffer, 0, expert_num * sizeof(int32_t));
for (int32_t i = 0; i < token_num * topk_num; ++i) {
int32_t curr_expert_id = topk_id[i];
++token_num_per_group_buffer[curr_expert_id];
}
int32_t token_num_sum = 0;
cu_token_num_per_group_buffer[0] = 0;
int32_t* token_index_buffer = cu_token_num_per_group_buffer + 1;
for (int32_t i = 0; i < expert_num; ++i) {
token_index_buffer[i] = token_num_sum;
token_num_sum += token_num_per_group_buffer[i];
}
for (int32_t i = 0; i < token_num; ++i) {
int32_t* curr_topk_id = topk_id + i * topk_num;
int32_t* curr_index_buffer = expand_token_id_index_buffer + i * topk_num;
for (int32_t j = 0; j < topk_num; ++j) {
int32_t curr_expert_id = curr_topk_id[j];
int32_t curr_index = token_index_buffer[curr_expert_id];
++token_index_buffer[curr_expert_id];
expand_token_id_buffer[curr_index] = i;
curr_index_buffer[j] = curr_index;
}
}
}
// w13 GEMM + act
{
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
const int32_t task_num_per_expert =
(output_size_13 + w13_n_tile_size - 1) / w13_n_tile_size;
const int32_t task_num = task_num_per_expert * expert_num;
uint8_t* __restrict__ thread_buffer =
thread_buffer_start + thread_id * w13_thread_buffer_offset;
scalar_t* __restrict__ w13_input_buffer =
reinterpret_cast<scalar_t*>(thread_buffer + w13_input_buffer_offset);
float* __restrict__ w13_output_buffer =
reinterpret_cast<float*>(thread_buffer + w13_output_buffer_offset);
scalar_t* __restrict__ w13_gemm_output_buffer =
reinterpret_cast<scalar_t*>(common_buffer_start +
w13_gemm_output_buffer_offset);
gemm_t gemm;
const int32_t input_size_13_bytes = input_size_13 * sizeof(scalar_t);
const int32_t w13_n_group_stride = 16 * input_size_13;
const int32_t w13_n_tile_stride = gemm_n_tile_size * input_size_13;
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
const int32_t curr_expert_id = task_id / task_num_per_expert;
const int32_t curr_output_group_id = task_id % task_num_per_expert;
const int32_t curr_token_num =
token_num_per_group_buffer[curr_expert_id];
if (curr_token_num == 0) {
continue;
}
const int32_t actual_n_tile_size =
std::min(w13_n_tile_size,
output_size_13 - curr_output_group_id * w13_n_tile_size);
const int32_t* __restrict__ curr_expand_token_id_buffer =
expand_token_id_buffer +
cu_token_num_per_group_buffer[curr_expert_id];
scalar_t* __restrict__ curr_w13_gemm_output_buffer =
w13_gemm_output_buffer +
cu_token_num_per_group_buffer[curr_expert_id] *
(output_size_13 / 2) +
curr_output_group_id * w13_n_tile_size / 2;
w_t* __restrict__ w13_weight_ptr_0 = nullptr;
w_t* __restrict__ w13_weight_ptr_1 = nullptr;
w_t* __restrict__ w13_bias_ptr_0 = nullptr;
w_t* __restrict__ w13_bias_ptr_1 = nullptr;
if (act_type == FusedMOEAct::SwigluOAIAndMul) {
// For SwigluOAIAndMul, gate and up weights are interleaved
w13_weight_ptr_0 =
w13 + curr_expert_id * input_size_13 * output_size_13 +
curr_output_group_id * w13_n_tile_size * input_size_13;
w13_weight_ptr_1 =
w13_weight_ptr_0 + actual_n_tile_size / 2 * input_size_13;
if (w13_bias != nullptr) {
w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 +
curr_output_group_id * w13_n_tile_size;
w13_bias_ptr_1 = w13_bias_ptr_0 + actual_n_tile_size / 2;
}
} else {
w13_weight_ptr_0 =
w13 + curr_expert_id * input_size_13 * output_size_13 +
curr_output_group_id * (w13_n_tile_size / 2) * input_size_13;
w13_weight_ptr_1 =
w13_weight_ptr_0 + output_size_13 / 2 * input_size_13;
if (w13_bias != nullptr) {
w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 +
curr_output_group_id * (w13_n_tile_size / 2);
w13_bias_ptr_1 = w13_bias_ptr_0 + output_size_13 / 2;
}
}
scalar_t* __restrict__ curr_w13_input_buffer = w13_input_buffer;
for (int32_t token_idx = 0; token_idx < curr_token_num;
token_idx += gemm_m_tile_size) {
const int32_t actual_token_num =
std::min(gemm_m_tile_size, curr_token_num - token_idx);
// copy inputs
{
scalar_t* __restrict__ curr_w13_input_buffer_iter =
curr_w13_input_buffer;
for (int32_t i = 0; i < actual_token_num; ++i) {
const int32_t curr_token_id = curr_expand_token_id_buffer[i];
int8_t* __restrict__ curr_input_iter = reinterpret_cast<int8_t*>(
input + curr_token_id * input_size_13);
int8_t* __restrict__ curr_output_iter =
reinterpret_cast<int8_t*>(curr_w13_input_buffer_iter);
int32_t j = 0;
for (; j < input_size_13_bytes - 64; j += 64) {
vec_op::INT8Vec64 vec(curr_input_iter);
vec.save(curr_output_iter);
curr_input_iter += 64;
curr_output_iter += 64;
}
vec_op::INT8Vec64 vec(curr_input_iter);
vec.save(curr_output_iter, input_size_13_bytes - j);
// update
curr_w13_input_buffer_iter += input_size_13;
}
// update
curr_expand_token_id_buffer += actual_token_num;
}
// gemm + act
{
scalar_t* __restrict__ w13_weight_ptr_0_iter = w13_weight_ptr_0;
scalar_t* __restrict__ w13_weight_ptr_1_iter = w13_weight_ptr_1;
scalar_t* __restrict__ w13_bias_ptr_0_iter = w13_bias_ptr_0;
scalar_t* __restrict__ w13_bias_ptr_1_iter = w13_bias_ptr_1;
scalar_t* __restrict__ curr_w13_input_buffer_iter =
curr_w13_input_buffer;
float* __restrict__ w13_output_buffer_0_iter = w13_output_buffer;
float* __restrict__ w13_output_buffer_1_iter =
w13_output_buffer + actual_n_tile_size / 2;
for (int32_t i = 0; i < actual_n_tile_size;
i += min_w13_n_tile_size) {
gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_0_iter,
w13_output_buffer_0_iter, actual_token_num,
input_size_13, input_size_13, w13_n_group_stride,
actual_n_tile_size, false);
if (w13_bias != nullptr) {
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
w13_output_buffer_0_iter, w13_output_buffer_0_iter,
w13_bias_ptr_0_iter, actual_token_num, actual_n_tile_size,
actual_n_tile_size);
w13_bias_ptr_0_iter += gemm_n_tile_size;
}
gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_1_iter,
w13_output_buffer_1_iter, actual_token_num,
input_size_13, input_size_13, w13_n_group_stride,
actual_n_tile_size, false);
if (w13_bias != nullptr) {
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
w13_output_buffer_1_iter, w13_output_buffer_1_iter,
w13_bias_ptr_1_iter, actual_token_num, actual_n_tile_size,
actual_n_tile_size);
w13_bias_ptr_1_iter += gemm_n_tile_size;
}
// update
w13_weight_ptr_0_iter += w13_n_tile_stride;
w13_weight_ptr_1_iter += w13_n_tile_stride;
w13_output_buffer_0_iter += gemm_n_tile_size;
w13_output_buffer_1_iter += gemm_n_tile_size;
}
apply_gated_act(act_type, w13_output_buffer,
curr_w13_gemm_output_buffer, actual_token_num,
actual_n_tile_size, actual_n_tile_size,
output_size_13 / 2);
// update
curr_w13_gemm_output_buffer +=
gemm_m_tile_size * (output_size_13 / 2);
}
}
}
}
}
// w2 GEMM
{
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
const int32_t task_num_per_expert =
(output_size_2 + w2_n_tile_size - 1) / w2_n_tile_size;
const int32_t task_num = task_num_per_expert * expert_num;
scalar_t* __restrict__ w13_gemm_output_buffer =
reinterpret_cast<scalar_t*>(common_buffer_start +
w13_gemm_output_buffer_offset);
float* __restrict__ w2_gemm_output_buffer = reinterpret_cast<float*>(
common_buffer_start + w2_gemm_output_buffer_offset);
gemm_t gemm;
const int32_t w2_n_tile_stride = gemm_n_tile_size * input_size_2;
const int32_t w2_n_group_stride = 16 * input_size_2;
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
const int32_t curr_expert_id = task_id / task_num_per_expert;
const int32_t curr_output_group_id = task_id % task_num_per_expert;
const int32_t curr_token_num =
token_num_per_group_buffer[curr_expert_id];
if (curr_token_num == 0) {
continue;
}
const int32_t actual_n_tile_size =
std::min(w2_n_tile_size,
output_size_2 - curr_output_group_id * w2_n_tile_size);
scalar_t* __restrict__ curr_w13_gemm_output_buffer =
w13_gemm_output_buffer +
cu_token_num_per_group_buffer[curr_expert_id] * input_size_2;
float* __restrict__ curr_w2_gemm_output_buffer =
w2_gemm_output_buffer +
cu_token_num_per_group_buffer[curr_expert_id] * output_size_2 +
curr_output_group_id * w2_n_tile_size;
scalar_t* __restrict__ w2_weight_ptr =
w2 + curr_expert_id * output_size_2 * input_size_2 +
curr_output_group_id * w2_n_tile_size * input_size_2;
scalar_t* __restrict__ w2_bias_ptr = nullptr;
if (w2_bias != nullptr) {
w2_bias_ptr = w2_bias + curr_expert_id * output_size_2 +
curr_output_group_id * w2_n_tile_size;
}
for (int32_t token_idx = 0; token_idx < curr_token_num;
token_idx += gemm_m_tile_size) {
const int32_t actual_token_num =
std::min(gemm_m_tile_size, curr_token_num - token_idx);
scalar_t* __restrict__ w2_weight_ptr_iter = w2_weight_ptr;
scalar_t* __restrict__ w2_bias_ptr_iter = w2_bias_ptr;
float* __restrict__ curr_w2_gemm_output_buffer_iter =
curr_w2_gemm_output_buffer;
for (int32_t i = 0; i < actual_n_tile_size; i += gemm_n_tile_size) {
gemm.gemm(curr_w13_gemm_output_buffer, w2_weight_ptr_iter,
curr_w2_gemm_output_buffer_iter, actual_token_num,
input_size_2, input_size_2, w2_n_group_stride,
output_size_2, false);
if (w2_bias != nullptr) {
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
curr_w2_gemm_output_buffer_iter,
curr_w2_gemm_output_buffer_iter, w2_bias_ptr_iter,
actual_token_num, output_size_2, output_size_2);
w2_bias_ptr_iter += gemm_n_tile_size;
}
w2_weight_ptr_iter += w2_n_tile_stride;
curr_w2_gemm_output_buffer_iter += gemm_n_tile_size;
}
// update
curr_w13_gemm_output_buffer += gemm_m_tile_size * input_size_2;
curr_w2_gemm_output_buffer += gemm_m_tile_size * output_size_2;
}
}
}
}
// weighted sum
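// Descriptive sketch: for topk_num > 1 the first expert's output is scaled
// into ws_output_buffer, the middle experts are accumulated into it, and the
// last expert's contribution is added while converting to scalar_t directly
// into the final output row; for topk_num == 1 the single expert output is
// scaled and converted in one pass.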
{
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
const int32_t task_num = token_num;
uint8_t* __restrict__ thread_buffer =
thread_buffer_start + thread_id * ws_thread_buffer_offset;
float* __restrict__ ws_output_buffer =
reinterpret_cast<float*>(thread_buffer + ws_output_buffer_offset);
float* __restrict__ w2_gemm_output_buffer = reinterpret_cast<float*>(
common_buffer_start + w2_gemm_output_buffer_offset);
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
int32_t token_id = task_id;
int32_t* __restrict__ curr_expand_token_id_index_buffer =
expand_token_id_index_buffer + token_id * topk_num;
float* __restrict__ curr_weight = topk_weights + token_id * topk_num;
scalar_t* __restrict__ curr_output_buffer =
output + token_id * output_size_2;
if (skip_weighted) {
// Only for topk_num == 1
*curr_weight = 1.0f;
}
if (topk_num > 1) {
{
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[0]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec = vec * weight_vec;
vec.save(ws_output_buffer_iter);
// update
w2_output_iter += 16;
ws_output_buffer_iter += 16;
}
}
{
for (int32_t idx = 1; idx < topk_num - 1; ++idx) {
int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[idx]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec_op::FP32Vec16 sum(ws_output_buffer_iter);
sum = sum + vec * weight_vec;
sum.save(ws_output_buffer_iter);
// update
w2_output_iter += 16;
ws_output_buffer_iter += 16;
}
}
}
{
int32_t idx = topk_num - 1;
int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[idx]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec_op::FP32Vec16 sum(ws_output_buffer_iter);
sum = sum + vec * weight_vec;
scalar_vec_t out_vec(sum);
out_vec.save(curr_output_buffer_iter);
// update
w2_output_iter += 16;
ws_output_buffer_iter += 16;
curr_output_buffer_iter += 16;
}
}
} else {
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[0]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec = vec * weight_vec;
scalar_vec_t out_vec(vec);
out_vec.save(curr_output_buffer_iter);
// update
w2_output_iter += 16;
curr_output_buffer_iter += 16;
}
}
}
}
}
}
} // namespace
void prepack_moe_weight(
const torch::Tensor& weight, // [expert_num, output_size, input_size]
torch::Tensor& packed_weight, const std::string& isa) {
TORCH_CHECK(weight.is_contiguous());
const int32_t expert_num = weight.size(0);
const int32_t output_size = weight.size(1);
const int32_t input_size = weight.size(2);
TORCH_CHECK_EQ(output_size % 32, 0);
const int64_t expert_stride = weight.stride(0);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
VLLM_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "prepack_moe_weight", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
scalar_t* weight_ptr = weight.data_ptr<scalar_t>();
scalar_t* packed_weight_ptr = packed_weight.data_ptr<scalar_t>();
prepack_moe_weight_impl<scalar_t, gemm_t>(
weight_ptr, packed_weight_ptr, expert_num, output_size,
input_size, expert_stride);
});
});
}
void cpu_fused_moe(
torch::Tensor& output, // [token_num, output_size_2]
const torch::Tensor& input, // [token_num, input_size_13]
const torch::Tensor&
w13, // [expert_num, output_size_13, input_size_13], packed
const torch::Tensor&
w2, // [expert_num, output_size_2, input_size_2], packed
const std::optional<torch::Tensor>&
w13_bias, // [expert_num, output_size_13]
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
const torch::Tensor& topk_weights, // [token_num, k], float32
const torch::Tensor& topk_id, // [token_num, k], int32
const bool skip_weighted, const std::string& act, const std::string& isa) {
const int32_t token_num = input.size(0);
const int32_t input_size_13 = input.size(1);
const int64_t input_stride = input.stride(0);
TORCH_CHECK_EQ(input_stride, input_size_13);
const int32_t expert_num = w13.size(0);
const int32_t output_size_13 = w13.size(1);
const int32_t input_size_2 = w2.size(2);
const int32_t output_size_2 = w2.size(1);
const int32_t topk_num = topk_id.size(1);
const FusedMOEAct act_type = get_act_type(act);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
TORCH_CHECK(!skip_weighted || topk_num == 1,
"skip_weighted is only supported for topk=1 on CPU");
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
fused_moe_impl<scalar_t, scalar_t, gemm_t>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
w13.data_ptr<scalar_t>(), w2.data_ptr<scalar_t>(),
w13_bias.has_value() ? w13_bias->data_ptr<scalar_t>() : nullptr,
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
token_num, expert_num, topk_num, input_size_13, output_size_13,
input_size_2, output_size_2, skip_weighted);
});
});
}
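// Illustrative call sketch (assumptions: weights already packed with
// prepack_moe_weight, and `isa_str` is a string accepted by
// cpu_utils::get_isa; the exact accepted strings are not shown in this file):
//
//   auto out = torch::empty({token_num, output_size_2}, input.options());
//   cpu_fused_moe(out, input, packed_w13, packed_w2,
//                 /*w13_bias=*/std::nullopt, /*w2_bias=*/std::nullopt,
//                 topk_weights, topk_id,
//                 /*skip_weighted=*/false, /*act=*/"silu", /*isa=*/isa_str);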