Commit 0da7b5db authored by PanZezhong's avatar PanZezhong

issue/97 Attention and KVCache: support a batch dimension

parent a1f6e517
#pragma once
#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include <algorithm>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
namespace infinilm::cache {
@@ -11,26 +15,20 @@ namespace infinilm::cache {
/**
* @brief Simple KV cache structure for incremental decoding
*
* Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
* Similar to DynamicLayer in Python cache_utils.py
*
* This is a common component that can be used by any model architecture
* that needs KV caching for attention mechanisms.
*/
struct KVCache {
infinicore::Tensor k_cache; // [batch_size, n_kv_head, capacity, head_dim]
infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim]
std::vector<size_t> cache_positions; // Current position in cache, one entry per batch item
size_t max_capacity; // Maximum capacity of cache
bool initialized; // Whether cache has been initialized
KVCache() : max_capacity(0), initialized(false) {}
/**
* @brief Initialize or update cache capacity
@@ -40,34 +38,44 @@ struct KVCache {
* @param dtype Data type
* @param device Device
*/
void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
infinicore::DataType dtype, const infinicore::Device &device) {
// Capacity must cover the furthest-advanced batch item plus the incoming sequence
size_t max_position = cache_positions.empty() ? 0 : *std::max_element(cache_positions.begin(), cache_positions.end());
size_t required_capacity = max_position + seq_len;
// Lazy initialization
if (!initialized) {
max_capacity = std::max(required_capacity, size_t(4096)); // Start with at least 4096
k_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
v_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
cache_positions = std::vector<size_t>(batch_size, 0);
initialized = true;
}
// Grow cache if needed (similar to DynamicLayer in Python)
else if (required_capacity > max_capacity) {
// Grow geometrically: at least double, and keep headroom of one old capacity beyond what is required
size_t new_capacity = std::max(max_capacity * 2, required_capacity + max_capacity);
size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
}
if (new_batch_size > cache_positions.size()) {
cache_positions.resize(new_batch_size, 0);
}
auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
// Copy existing cache data
for (size_t b = 0; b < new_batch_size; ++b) {
size_t cache_position = cache_positions[b];
if (cache_position > 0) {
auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
}
}
k_cache = k_new;
@@ -76,10 +84,16 @@
}
}
KVCache(size_t max_batch_size, size_t n_kv_head, size_t head_dim, infinicore::DataType dtype, size_t max_seqlen = 4096, infinicore::Device device = infinicore::context::getDevice())
: max_capacity(max_seqlen), initialized(false) {
cache_positions = std::vector<size_t>(max_batch_size, 0);
ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
}
/**
* @brief Update cache with new key and value states
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*
* Note: This method writes to the cache. If using with attention op, the attention op
@@ -88,28 +102,42 @@ struct KVCache {
std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) {
if (k_new->ndim() != 4 || v_new->ndim() != 4) {
throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form.");
}
size_t batch_size = k_new->shape()[0];
size_t num_kv_heads = k_new->shape()[1];
size_t seq_len = k_new->shape()[2];
size_t head_dim = k_new->shape()[3];
// Ensure capacity
ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
k_new->dtype(), k_new->device());
// Copy new k/v into cache at the current position.
// All batch items must share the same cache position so the write can be done with a single copy.
bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
if (all_equal) {
auto cache_position = cache_positions[0];
auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
k_dst->copy_from(k_new);
v_dst->copy_from(v_new);
// Update position
cache_position += seq_len;
for (size_t b = 0; b < batch_size; ++b) {
cache_positions[b] = cache_position;
}
// Return the total cache up to the current position
auto k_total = k_cache->narrow({{2, 0, cache_position}});
auto v_total = v_cache->narrow({{2, 0, cache_position}});
return std::make_pair(k_total, v_total);
} else {
throw std::runtime_error("KVCache update: cache positions must be equal among a batch.");
}
}
};
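For orientation, a minimal usage sketch of the batched cache follows. It is not part of the commit: the decode-step function and the illustrative sizes are assumptions, while `Tensor::empty`, `DataType::F32`, `Device`, and `KVCache::update` are taken from the code above.

```cpp
// Minimal usage sketch of the batched KVCache (illustrative; not part of the commit).
#include <cstddef>
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
// #include "kv_cache.hpp"  // hypothetical path to the header shown above

void decode_step_example() {
    const size_t batch_size = 2, n_kv_head = 8, seq_len = 1, head_dim = 128;
    infinicore::Device device(infinicore::Device::Type::CPU, 0);

    // Lazily initialized cache; capacity grows on demand inside update()
    infinilm::cache::KVCache cache;

    // New key/value states for this step: [batch_size, n_kv_head, seq_len, head_dim]
    auto k_new = infinicore::Tensor::empty({batch_size, n_kv_head, seq_len, head_dim},
                                           infinicore::DataType::F32, device);
    auto v_new = infinicore::Tensor::empty({batch_size, n_kv_head, seq_len, head_dim},
                                           infinicore::DataType::F32, device);

    // Appends at the shared cache position and returns the whole cache so far:
    // [batch_size, n_kv_head, total_seq_len, head_dim]
    auto [k_total, v_total] = cache.update(k_new, v_new);
    (void)k_total;
    (void)v_total;
}
```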
@@ -72,15 +72,9 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
throw std::runtime_error("Unexpected position_ids shape");
}
// 4. Apply RoPE to full batch
auto q_for_rope = q_reshaped->view({batch_size * seq_len, num_attention_heads_, head_dim_});
auto k_for_rope = k_reshaped->view({batch_size * seq_len, num_key_value_heads_, head_dim_});
// Call RoPE on full batch (matching Python pattern)
auto q_rope_out = rotary_emb_->forward(q_for_rope, pos_ids_for_rope);
@@ -92,27 +86,16 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
infinilm::cache::KVCache *external_cache = static_cast<infinilm::cache::KVCache *>(kv_cache);
// Convert to [batch, n_head, seq_len, head_dim] for cache
// Ensure contiguous after permute for F16 compatibility with cache operations
auto q_rope = q_rope_out->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
auto k_rope = k_rope_out->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
// 5. Prepare KV caches
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
if (external_cache != nullptr) {
auto [k_total_tmp, v_total_tmp] = external_cache->update(k_rope, v_permuted);
k_total = k_total_tmp;
@@ -122,78 +105,30 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
k_total = k_total_tmp;
v_total = v_total_tmp;
}
auto total_seq_len = k_total->shape()[2];
// 6. Compute attention
size_t ngroup = num_attention_heads_ / num_key_value_heads_;
// Fold the query-head groups into the sequence dimension so grouped-query attention
// runs as one batched GEMM per KV head, without materializing repeated K/V
auto Q = q_rope->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
// Use GEMM with alpha=scaling to combine scaling with matrix multiplication
// This is more efficient than doing matmul followed by mul
float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ngroup * seq_len, total_seq_len]
// Re-view as [bs * n_q_head, seq_len, total_seq_len] so the causal mask is applied per query head
auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ngroup * seq_len, head_dim]
auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
// 8. Apply output projection to all batches
auto output = o_proj_->forward(attn_output);
return output;
}
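The views above fold the query-head groups into the GEMM batch. A small standalone sketch of the shape arithmetic follows, with illustrative sizes chosen here rather than taken from the commit.

```cpp
// Shape arithmetic behind the batched grouped-query attention (illustrative sizes).
#include <cassert>
#include <cstddef>

int main() {
    const std::size_t bs = 2, n_q_head = 32, n_kv_head = 8;
    const std::size_t seq_len = 5, total_seq_len = 17, head_dim = 128;
    const std::size_t ngroup = n_q_head / n_kv_head; // 4 query heads share each KV head

    // Q: [bs, n_q_head, seq_len, head_dim] viewed as [bs * n_kv_head, ngroup * seq_len, head_dim]
    assert(bs * n_q_head * seq_len * head_dim
           == (bs * n_kv_head) * (ngroup * seq_len) * head_dim);

    // attn_weight = matmul(Q, K^T): [bs * n_kv_head, ngroup * seq_len, total_seq_len],
    // re-viewed as [bs * n_q_head, seq_len, total_seq_len] for the causal softmax
    assert((bs * n_kv_head) * (ngroup * seq_len) * total_seq_len
           == (bs * n_q_head) * seq_len * total_seq_len);
    return 0;
}
```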