Commit ff00b5c8 authored by PanZezhong

issue/125 Unify the Cache interface

parent 13a4154a
#pragma once
#include "../engine/distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
namespace infinilm::cache {
class Cache {
public:
Cache() = default;
virtual ~Cache() {}
};
class CacheConfig {
public:
CacheConfig() = default;
virtual ~CacheConfig() {}
virtual std::unique_ptr<CacheConfig> unique_copy() const = 0;
};
} // namespace infinilm::cache
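// --- Illustrative sketch (not part of this commit) ---
// A concrete config implements unique_copy() by copy-constructing its own
// type, so callers can clone configs through the abstract CacheConfig
// interface. The class name below is hypothetical.
namespace infinilm::cache {
class ExampleCacheConfig final : public CacheConfig {
public:
std::unique_ptr<CacheConfig> unique_copy() const override {
return std::make_unique<ExampleCacheConfig>(*this);
}
};
} // namespace infinilm::cache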
#pragma once
#include "cache_config.hpp"
#include "base_cache.hpp"
#include "kv_cache.hpp"
#pragma once
#include <cstddef>
#include <string>
#include <cstdint>
namespace infinilm::cache {
/**
* @enum CacheType
* @brief Enumeration of supported cache types
*/
enum class CacheType {
DYNAMIC, ///< Dynamic KV cache (grows as needed)
PAGED, ///< Paged KV cache (for paged attention)
};
enum class CacheResetMode {
PRESERVE, // Keep cache memory, only reset positions
RECREATE // Recreate cache with new configuration
};
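// Example: with PRESERVE, reset() only rewinds the per-layer cache positions and
// the existing tensors are overwritten on the next update; RECREATE marks layers
// uninitialized so they are reallocated on next use (see DynamicCache::update_config
// later in this diff).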
struct CacheConfig {
CacheType type = CacheType::DYNAMIC;
size_t num_layers = 0;
size_t max_kv_cache_length = SIZE_MAX;
size_t initial_capacity = 1024; // Initial cache capacity in tokens
size_t initial_batch_size = 1; // Initial batch size for cache allocation
float growth_factor = 2.0f; // Cache growth factor when resizing
bool allow_expand = true; // Whether to allow cache expansion
CacheResetMode reset_mode = CacheResetMode::PRESERVE;
// Constructor
CacheConfig() = default;
CacheConfig(CacheType type, size_t num_layers = 32, size_t max_kv_cache_length = 4096)
: type(type), num_layers(num_layers), max_kv_cache_length(max_kv_cache_length) {}
bool operator==(const CacheConfig &other) const {
return type == other.type
&& num_layers == other.num_layers
&& max_kv_cache_length == other.max_kv_cache_length
&& initial_capacity == other.initial_capacity
&& initial_batch_size == other.initial_batch_size
&& growth_factor == other.growth_factor;
}
bool operator!=(const CacheConfig &other) const {
return !(*this == other);
}
};
} // namespace infinilm::cache
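// --- Illustrative usage (not part of this commit) ---
// operator== compares the type, layer count, and capacity fields; reset_mode
// and allow_expand are not part of equality, so two configs that differ only
// in reset policy compare equal:
// CacheConfig a(CacheType::DYNAMIC, /*num_layers=*/32, /*max_kv_cache_length=*/4096);
// CacheConfig b = a;
// b.reset_mode = CacheResetMode::RECREATE;
// assert(a == b); // reset policy does not affect equality
// b.growth_factor = 1.5f;
// assert(a != b); // growth_factor does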
#include "kv_cache.hpp"
#include "../utils.hpp"
#include <stdexcept>
namespace infinilm::cache {
// ==========================
// StaticKVCache
// ==========================
StaticKVCache::StaticKVCache(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
k_dim_(k_dim),
v_dim_(v_dim),
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_batch_size_(config.max_batch_size()),
cache_len_(std::min(config.max_cache_len(), max_positional_embedding)),
rank_num_layers_(num_layers),
dtype_(dtype) {
// Allocate K cache
k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
rank_batch_size_,
num_rank_k_heads_,
cache_len_,
k_dim_},
dtype_,
rank_info.device);
// Allocate V cache
v_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
rank_batch_size_,
num_rank_v_heads_,
cache_len_,
v_dim_},
dtype_,
rank_info.device);
spdlog::info("Created Static KV Cache: K[{}] V[{}]", k_caches_->info(), v_caches_->info());
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
const infinicore::Tensor &cache_positions) {
ASSERT(layer_idx < rank_num_layers_);
auto batch_size = k->size(0);
auto update_len = k->size(2);
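// cache_positions is assumed to hold int64 offsets; stage it on the CPU and read entry 0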
size_t cache_pos = reinterpret_cast<int64_t *>(cache_positions->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
ASSERT_EQ(batch_size, rank_batch_size_);
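// Select this layer's [batch, heads, cache_len, dim] slab, overlay the new
// tokens at cache_pos along the sequence axis, then return views covering
// [0, cache_pos + update_len) for attention.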
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
k_cache_update->copy_from(k);
v_cache_update->copy_from(v);
auto k_total = k_cache_layer->narrow({{2, 0, result_len}});
auto v_total = v_cache_layer->narrow({{2, 0, result_len}});
return {k_total, v_total};
}
// ==========================
// StaticKVCacheConfig
// ==========================
StaticKVCacheConfig::StaticKVCacheConfig(
infinicore::Size _max_batch_size,
infinicore::Size _max_cache_len)
: max_batch_size_(_max_batch_size),
max_cache_len_(_max_cache_len) {
}
std::unique_ptr<CacheConfig>
StaticKVCacheConfig::unique_copy() const {
return std::make_unique<StaticKVCacheConfig>(*this);
}
infinicore::Size
StaticKVCacheConfig::max_batch_size() const {
return max_batch_size_;
}
infinicore::Size
StaticKVCacheConfig::max_cache_len() const {
return max_cache_len_;
}
} // namespace infinilm::cache
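// --- Illustrative sketch (not part of this commit) ---
// A typical per-layer decode step against StaticKVCache::update(). Shapes
// follow the constructor above; cache_positions is assumed to be an int64
// tensor of per-request write offsets. The helper name is hypothetical.
namespace {
std::tuple<infinicore::Tensor, infinicore::Tensor>
decode_step(infinilm::cache::StaticKVCache &cache,
size_t layer_idx,
const infinicore::Tensor &k_new, // [batch, heads, seq_len, k_dim]
const infinicore::Tensor &v_new, // [batch, heads, seq_len, v_dim]
const infinicore::Tensor &cache_positions) {
// Returns views of the cache covering [0, cache_pos + seq_len) for attention.
return cache.update(layer_idx, k_new, v_new, cache_positions);
}
} // namespace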
#pragma once
#include "base_cache.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "cache_config.hpp"
#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <stdexcept>
@@ -15,355 +16,70 @@
#include <spdlog/spdlog.h>
namespace infinilm::cache {
class StaticKVCacheConfig final : public CacheConfig {
public:
StaticKVCacheConfig(
infinicore::Size _max_batch_size = 1,
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
/**
* @brief Single layer's KV cache for incremental decoding
*
* Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
* Similar to DynamicLayer in Python cache_utils.py
*
* This represents a single layer's cache within a model-level cache container.
*/
struct KVCacheLayer {
infinicore::Tensor k_cache; // [batch_size, n_kv_head, capacity, head_dim]
infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim]
std::vector<size_t> cache_positions; // Current position in cache
size_t max_capacity; // Maximum capacity of cache
size_t initial_capacity; // Initial capacity from config
size_t initial_batch_size; // Initial batch size from config
float growth_factor; // Growth factor for dynamic resizing
bool initialized; // Whether cache has been initialized
KVCacheLayer() : max_capacity(0), initial_capacity(4096), initial_batch_size(1),
growth_factor(2.0f), initialized(false) {}
/**
* @brief Initialize or update cache capacity with config parameters
* @param batch_size Current batch size
* @param num_kv_heads Number of key-value heads
* @param head_dim Head dimension
* @param seq_len Sequence length of new tokens
* @param dtype Data type
* @param device Device
* @param cache_config Cache configuration parameters
*/
void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
infinicore::DataType dtype, const infinicore::Device &device,
const CacheConfig &cache_config) {
size_t max_pos = cache_positions.empty() ? 0 : *std::max_element(cache_positions.begin(), cache_positions.end());
size_t required_capacity = seq_len + max_pos;
// VALIDATION: Verify input parameters
if (num_kv_heads == 0 || head_dim == 0 || seq_len == 0) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Invalid parameters - num_kv_heads: {}, head_dim: {}, seq_len: {}",
num_kv_heads, head_dim, seq_len);
throw std::runtime_error("KV cache ensure_capacity: invalid parameters");
}
// Lazy initialization: capture config parameters and allocate on first use
if (!initialized) {
initial_capacity = cache_config.initial_capacity;
initial_batch_size = cache_config.initial_batch_size;
growth_factor = cache_config.growth_factor;
// Use max of required capacity and initial capacity from config
max_capacity = std::max(required_capacity, initial_capacity);
// Use max of current batch size and initial batch size from config
size_t alloc_batch_size = std::max(batch_size, initial_batch_size);
k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
cache_positions = std::vector<size_t>(alloc_batch_size, 0);
initialized = true;
spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})",
alloc_batch_size, max_capacity, initial_batch_size, initial_capacity);
// VALIDATION: Verify cache was created correctly
if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization");
throw std::runtime_error("KV cache initialization: shape mismatch");
}
}
// Grow cache if needed using growth factor from config
else if (required_capacity > max_capacity) {
if (!cache_config.allow_expand) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config");
throw std::runtime_error("KV cache expansion not allowed");
}
// Calculate new capacity using growth factor
size_t new_capacity = static_cast<size_t>(
std::max(static_cast<float>(max_capacity) * growth_factor,
static_cast<float>(required_capacity + max_capacity)));
// Clamp to the configured max_kv_cache_length limit, if one is set
if (cache_config.max_kv_cache_length != 0) {
new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length);
}
// Ensure we grow by at least some minimum amount
size_t min_growth = 256;
if (new_capacity - max_capacity < min_growth) {
new_capacity = max_capacity + min_growth;
}
size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
}
if (new_batch_size > cache_positions.size()) {
cache_positions.resize(new_batch_size, 0);
}
auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})",
max_capacity, new_capacity, growth_factor);
// Copy existing cache data
for (size_t b = 0; b < new_batch_size; ++b) {
size_t cache_position = cache_positions[b];
if (cache_position > 0) {
auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
}
}
k_cache = k_new;
v_cache = v_new;
max_capacity = new_capacity;
// VALIDATION: Verify cache was grown correctly
if (k_cache->shape()[2] != new_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch");
throw std::runtime_error("KV cache growth: capacity mismatch");
}
}
// VALIDATION: Final check that capacity is sufficient
if (required_capacity > max_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth");
throw std::runtime_error("KV cache ensure_capacity: capacity insufficient");
}
}
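// Worked example of the growth rule above (illustrative numbers): with
// max_capacity = 1024, growth_factor = 2.0 and required_capacity = 2600,
// new_capacity = max(1024 * 2.0, 2600 + 1024) = 3624; it is then clamped by
// max_kv_cache_length and bumped if the step would be smaller than min_growth.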
/**
* @brief Update cache with new key and value states
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @param cache_config Cache configuration for capacity management
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new,
const CacheConfig &cache_config) {
if (k_new->ndim() != 4 || v_new->ndim() != 4) {
throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors");
}
size_t batch_size = k_new->shape()[0];
size_t num_kv_heads = k_new->shape()[1];
size_t seq_len = k_new->shape()[2];
size_t head_dim = k_new->shape()[3];
// Ensure capacity with cache config
ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
k_new->dtype(), k_new->device(), cache_config);
// Copy new k/v into cache at current position
// ensure_capacity() guarantees cache_positions is non-empty here
bool all_equal = std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
if (all_equal) {
auto cache_position = cache_positions[0];
auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
k_dst->copy_from(k_new);
v_dst->copy_from(v_new);
// Update position
cache_position += seq_len;
for (size_t b = 0; b < batch_size; ++b) {
cache_positions[b] = cache_position;
}
// Return the total cache up to current position
auto k_total = k_cache->narrow({{2, 0, cache_position}});
auto v_total = v_cache->narrow({{2, 0, cache_position}});
std::unique_ptr<CacheConfig> unique_copy() const override;
infinicore::Size max_batch_size() const;
infinicore::Size max_cache_len() const;
return std::make_pair(k_total, v_total);
} else {
throw std::runtime_error("KVCache update: cache positions must be equal among a batch.");
}
}
private:
infinicore::Size max_batch_size_;
infinicore::Size max_cache_len_;
};
/**
* @brief Model-level KV cache container (similar to DynamicCache in Python)
*
* Stores a list of KVCacheLayer objects, one per model layer.
* This aligns with Python backend's DynamicCache architecture.
*/
class DynamicCache {
class StaticKVCache final : public Cache {
public:
/**
* @brief Construct DynamicCache with cache configuration
* @param cache_config Cache configuration parameters
*/
DynamicCache(const CacheConfig &cache_config)
: cache_config_(cache_config), layers_(cache_config.num_layers) {
if (cache_config.num_layers == 0) {
throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig");
}
}
/**
* @brief Construct DynamicCache with specified number of layers
*
* @param num_layers Number of model layers (creates one cache layer per model layer)
* @param max_position_embeddings Maximum position embeddings (used for initial capacity)
*/
DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096)
: cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), layers_(num_layers) {}
/**
* @brief Update cache with new key and value states for a specific layer
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
size_t layer_idx,
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) {
if (layer_idx >= layers_.size()) {
SPDLOG_ERROR("DynamicCache::update: layer_idx {} out of range (num_layers: {})",
layer_idx, layers_.size());
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
// Update the cache for this layer with cache config
return layers_[layer_idx].update(k_new, v_new, cache_config_);
}
StaticKVCache(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);
/**
* @brief Update cache with new key and value states (convenience method without layer_idx)
* This is used when the cache is accessed directly without layer information
* @brief Update KV cache at a given layer and cache position.
*
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
* @param layer_idx Which transformer layer
* @param k [batch, num_rank_k_heads, seq_len, k_dim]
* @param v [batch, num_rank_v_heads, seq_len, v_dim]
* @param cache_positions Int64 tensor of sequence positions to write at
*
* Note: This assumes layer_idx=0. For multi-layer models, use update(layer_idx, k_new, v_new) instead.
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) {
return update(0, k_new, v_new);
}
/**
* @brief Get cache configuration
*/
const CacheConfig &get_config() const { return cache_config_; }
/**
* @brief Update cache configuration (for dynamic reconfiguration)
*/
void update_config(const CacheConfig &new_config) {
// Check if we need to rebuild
bool need_rebuild = false;
// Rebuild if number of layers changed
if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) {
need_rebuild = true;
layers_.resize(new_config.num_layers);
}
// Rebuild if reset mode is RECREATE
if (new_config.reset_mode == CacheResetMode::RECREATE) {
need_rebuild = true;
}
// Update configuration
cache_config_ = new_config;
if (need_rebuild) {
// Clear all layers to force reinitialization on next use
for (auto &layer : layers_) {
layer.initialized = false;
layer.max_capacity = 0;
// Tensors will be recreated when ensure_capacity is called
}
spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use");
} else {
spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}",
new_config.num_layers, new_config.initial_capacity, new_config.growth_factor);
}
}
/**
* @brief Get the number of layers in this cache
*/
size_t num_layers() const { return layers_.size(); }
/**
* @brief Get cache position for a specific layer
*/
size_t cache_position(size_t layer_idx) const {
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
if (layers_[layer_idx].cache_positions.empty()) {
return 0;
}
return layers_[layer_idx].cache_positions[0]; // All batch items should have same position
}
/**
* @brief Get max position embeddings (used for initial capacity)
*/
size_t max_kv_cache_length() const { return cache_config_.max_kv_cache_length; }
/**
* @brief Reset cache for all layers to a specific position
* This should be called when starting a new generation sequence or resetting to a specific position
* @param pos Position to reset to (defaults to 0)
*/
void reset(size_t pos = 0) {
for (auto &layer : layers_) {
std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos);
// Note: We don't reset initialized flag or clear the cache tensors
// to avoid reallocation. The cache will be overwritten on next update.
}
}
/**
* @brief Access a specific layer's cache (for advanced usage)
* @return (full_k, full_v)
* full_k: [batch, num_rank_k_heads, cache_pos + seq_len, k_dim]
* full_v: [batch, num_rank_v_heads, cache_pos + seq_len, v_dim]
*/
KVCacheLayer &layer(size_t layer_idx) {
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
return layers_[layer_idx];
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
const infinicore::Tensor &cache_positions);
const KVCacheLayer &layer(size_t layer_idx) const {
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
return layers_[layer_idx];
}
~StaticKVCache() override = default;
private:
CacheConfig cache_config_;
std::vector<KVCacheLayer> layers_;
infinicore::Size k_dim_;
infinicore::Size v_dim_;
infinicore::Size num_rank_k_heads_;
infinicore::Size num_rank_v_heads_;
infinicore::Size rank_batch_size_;
infinicore::Size cache_len_;
infinicore::Size rank_num_layers_;
infinicore::DataType dtype_;
// [num_layers, max_batch, num_rank_k_heads, max_cache_len, k_dim]
infinicore::Tensor k_caches_;
// [num_layers, max_batch, num_rank_v_heads, max_cache_len, v_dim]
infinicore::Tensor v_caches_;
};
} // namespace infinilm::cache
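// --- Illustrative sketch (not part of this commit) ---
// Driving the pre-unification DynamicCache across a model's layers; k_new and
// v_new stand for hypothetical [batch, n_kv_head, seq_len, head_dim] tensors:
// infinilm::cache::DynamicCache cache(/*num_layers=*/32, /*max_position_embeddings=*/4096);
// for (size_t i = 0; i < cache.num_layers(); ++i) {
//     auto [k_total, v_total] = cache.update(i, k_new[i], v_new[i]);
// }
// cache.reset(); // rewind positions before starting a new generation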
@@ -38,7 +38,7 @@ int CommunicationGroup::get_world_size() const {
CommunicationGroup::~CommunicationGroup() {
if (communicators_.size() > 1) {
for (auto &comm : communicators_) {
RUN_INFINI(infinicclCommDestroy(comm));
infinicclCommDestroy(comm);
}
}
}
......
@@ -10,32 +10,13 @@ InferEngine::InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig &cache_config) // Changed parameter
const cache::CacheConfig *cache_config) // Changed parameter
: communication_group_(distributed_config, device_type),
model_config_(config),
cache_config_(cache_config) {
model_config_(config) {
spdlog::info("Launch InferEngine with {}", std::string(distributed_config));
spdlog::info("Cache configuration: type={}, layers={}, max_kv_cache_length={}",
static_cast<int>(cache_config_.type),
cache_config_.num_layers,
cache_config_.max_kv_cache_length);
// Try to extract model configuration to override default cache parameters if needed
try {
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
cache_config_.num_layers = llama_config.num_hidden_layers;
cache_config_.max_kv_cache_length = llama_config.max_position_embeddings;
spdlog::info("Updated cache config from model: layers={}, max_kv_cache_length={}",
cache_config_.num_layers, cache_config_.max_kv_cache_length);
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
} catch (...) {
spdlog::warn("Could not extract model config, using provided CacheConfig");
}
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
workers_.reserve(world_size);
@@ -43,7 +24,7 @@ InferEngine::InferEngine(
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_));
cache_config_.get()));
}
}
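// --- Illustrative usage (not part of this commit) ---
// The engine now takes an optional CacheConfig*: nullptr means no
// preconfigured cache, otherwise the engine stores its own unique_copy().
// model_config / dist_config are assumed to be defined elsewhere:
// cache::StaticKVCacheConfig cache_cfg(/*max_batch_size=*/1, /*max_cache_len=*/4096);
// InferEngine engine(model_config, dist_config, device_type, &cache_cfg);
// InferEngine engine_without_cache(model_config); // cache_config defaults to nullptr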
@@ -75,12 +56,14 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------
// forward
//------------------------------------------------------
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
const auto &[input_ids, position_ids] = input;
infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
return {input_ids, position_ids, cache_positions};
}
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
// Trigger each worker to run inference
for (auto &worker : workers_) {
worker->run({input_ids, position_ids});
worker->run(input.to_model_input());
}
// Wait for all workers
for (auto &worker : workers_) {
@@ -104,25 +87,12 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
return communication_group_.get_dist_config();
}
//------------------------------------------------------
// reset_cache
//------------------------------------------------------
void InferEngine::reset_cache(size_t pos) {
for (auto &worker : workers_) {
worker->reset_cache(pos);
}
for (auto &worker : workers_) {
worker->wait();
}
}
//------------------------------------------------------
// reset_cache (overloaded with CacheConfig)
//------------------------------------------------------
void InferEngine::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
cache_config_ = new_config;
void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for (auto &worker : workers_) {
worker->reset_cache(new_config, pos);
worker->reset_cache(new_config);
}
for (auto &worker : workers_) {
worker->wait();
......
@@ -4,8 +4,8 @@
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"
#include "../models/infinilm_model.hpp"
#include <any>
#include <vector>
namespace infinilm::engine {
@@ -16,6 +16,10 @@ public:
infinicore::Tensor input_ids;
infinicore::Tensor position_ids;
infinicore::Tensor cache_positions;
infinilm::InfinilmModel::Input to_model_input() const;
};
struct Output {
@@ -27,7 +31,7 @@ public:
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig &cache_config = cache::CacheConfig());
const cache::CacheConfig *cache_config = nullptr);
// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
@@ -38,24 +42,20 @@ public:
// Run a single forward pass on all workers and return the outputs from all ranks
Output forward(const Input &input);
// Reset the internal cache pos in all workers (clears state between generations)
void reset_cache(size_t pos = 0);
// Overload: reset cache with new KV configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
void reset_cache(const cache::CacheConfig *new_config);
~InferEngine();
const distributed::DistConfig &get_dist_config() const;
// Get current KV configuration
const cache::CacheConfig &get_cache_config() const { return cache_config_; }
const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); }
protected:
std::vector<std::unique_ptr<RankWorker>> workers_;
distributed::CommunicationGroup communication_group_;
const InfinilmModel::Config &model_config_;
cache::CacheConfig cache_config_;
std::unique_ptr<cache::CacheConfig> cache_config_;
};
} // namespace infinilm::engine
@@ -10,15 +10,17 @@ namespace infinilm::engine {
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config)
const cache::CacheConfig *cache_config)
: model_config_(model_config),
rank_info_(rank_info),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false),
pending_cache_config_(cache_config) {
init_done_(false) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);
@@ -80,7 +82,14 @@ void RankWorker::load_param(
// state_dict --
//------------------------------------------------------
std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() {
return this->model_->state_dict();
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_ || should_exit_; });
if (!model_) {
throw std::runtime_error("state_dict called before model initialization");
}
return model_->state_dict();
}
//------------------------------------------------------
@@ -113,32 +122,15 @@ void RankWorker::wait() {
}
}
//------------------------------------------------------
// reset_cache -- synchronous by default, async optional (unstable)
//------------------------------------------------------
void RankWorker::reset_cache(size_t pos) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache");
}
pending_reset_pos_ = pos;
job_cmd_ = Command::RESET_CACHE;
has_job_ = true;
job_done_ = false;
cv_.notify_all();
}
void RankWorker::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache");
}
// Store both the position and the new config
pending_reset_pos_ = pos;
pending_cache_config_ = new_config;
job_cmd_ = Command::RESET_CACHE_WITH_CONFIG;
pending_cache_config_ = new_config != nullptr ? new_config->unique_copy() : nullptr;
job_cmd_ = Command::RESET_CACHE;
has_job_ = true;
job_done_ = false;
cv_.notify_all();
@@ -174,17 +166,17 @@ InfinilmModel::Output RankWorker::get_output() {
//------------------------------------------------------
void RankWorker::thread_loop() {
try {
{
std::lock_guard<std::mutex> lk(mutex_);
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore::context::setDevice(rank_info_.device);
cache_ptr_ = std::make_shared<cache::DynamicCache>(pending_cache_config_);
// Create model using factory (may be expensive)
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, cache_ptr_);
// Signal that initialization is done
{
std::lock_guard<std::mutex> lk(mutex_);
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_.get());
if (!model_) {
throw std::runtime_error("Failed to create model");
}
init_done_ = true;
}
cv_.notify_all();
@@ -195,8 +187,7 @@ void RankWorker::thread_loop() {
std::string local_param_name;
infinicore::Tensor local_param;
InfinilmModel::Input local_args;
size_t local_reset_pos = 0;
cache::CacheConfig local_reset_config;
std::unique_ptr<cache::CacheConfig> local_cache_config;
// Wait for a job or exit
{
@@ -215,12 +206,10 @@
} else if (local_cmd == Command::RUN) {
local_args = pending_args_;
} else if (local_cmd == Command::RESET_CACHE) {
local_reset_pos = pending_reset_pos_;
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
local_reset_pos = pending_reset_pos_;
local_reset_config = pending_cache_config_;
if (pending_cache_config_ != nullptr) {
local_cache_config = pending_cache_config_->unique_copy();
}
}
// mark job as being processed
has_job_ = false;
job_done_ = false;
@@ -270,14 +259,7 @@
}
} else if (local_cmd == Command::RESET_CACHE) {
try {
// Option 1: Use model's reset_cache if it handles cache
model_->reset_cache(local_reset_pos);
// Option 2: Reset cache directly if we have access
// if (cache_ptr_ != nullptr) {
// auto* dynamic_cache = static_cast<cache::DynamicCache*>(cache_ptr_);
// dynamic_cache->reset(local_reset_pos);
// }
model_->reset_cache(local_cache_config.get());
{
std::lock_guard<std::mutex> lk(mutex_);
@@ -293,25 +275,6 @@
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
break;
}
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
try {
// Use model's reset_cache with new configuration
model_->reset_cache(local_reset_config, local_reset_pos);
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache with config: {}\n", info(), e.what());
break;
}
} else {
// Shouldn't reach here (no-op)
}
......
@@ -19,14 +19,13 @@ class RankWorker {
LOAD,
RUN,
RESET_CACHE,
RESET_CACHE_WITH_CONFIG,
STOP
};
public:
RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config);
const cache::CacheConfig *cache_config);
// Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name,
@@ -38,11 +37,8 @@ public:
// Submit a run (forward) job.
void run(const InfinilmModel::Input &args);
// Reset the internal cache in the model (clears state between generations)
void reset_cache(size_t pos = 0);
// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
void reset_cache(const cache::CacheConfig *new_config);
// Wait until run job completes. The result can be retrieved with get_output().
void wait();
@@ -63,7 +59,7 @@ private:
const InfinilmModel::Config &model_config_;
distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::DynamicCache> cache_ptr_;
std::shared_ptr<cache::Cache> cache_;
// Command for the pending job (protected by mutex_)
Command job_cmd_;
@@ -78,8 +74,7 @@ private:
std::string pending_param_name_;
infinicore::Tensor pending_param_;
InfinilmModel::Input pending_args_;
size_t pending_reset_pos_ = 0;
cache::CacheConfig pending_cache_config_;
std::unique_ptr<cache::CacheConfig> pending_cache_config_;
// Output (protected by mutex)
InfinilmModel::Output output_;
......
@@ -18,12 +18,10 @@ public:
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
infinicore::Tensor input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
infinicore::Tensor position_ids;
/// Optional model-level KV cache for incremental decoding. Defaults to `nullptr`.
void *kv_cache = nullptr;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
infinicore::Tensor cache_positions;
};
struct Output {
@@ -33,8 +31,7 @@ public:
virtual ~InfinilmModel() = default;
virtual Output forward(const Input &input) const = 0;
// Optional: reset cache; default no-op for models without cache
virtual void reset_cache(size_t pos = 0) {}
virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0;
virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
};
} // namespace infinilm
@@ -51,7 +51,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
void *kv_cache) const {
std::shared_ptr<cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const {
if (!rotary_emb_) {
throw std::runtime_error("LlamaAttention: rotary_emb not configured");
}
@@ -97,16 +98,15 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinilm::cache::DynamicCache *external_cache = static_cast<infinilm::cache::DynamicCache *>(kv_cache);
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
if (external_cache != nullptr) {
auto [k_total_tmp, v_total_tmp] = external_cache->update(layer_idx_, k_permuted, v_permuted);
if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_positions);
k_total = k_total_tmp;
v_total = v_total_tmp;
} else {
// No external cache - this shouldn't happen in normal operation, but handle gracefully
throw std::runtime_error("LlamaAttention: kv_cache is required but nullptr provided");
throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
}
auto total_seq_len = k_total->shape()[2];
......
@@ -50,7 +50,8 @@ public:
*/
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const;
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const;
/**
* @brief Get the layer index
......
@@ -23,7 +23,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
void *kv_cache) const {
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const {
// Save residual for attention
auto residual = hidden_states;
@@ -31,7 +32,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache);
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_positions);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
......
@@ -48,7 +48,8 @@ public:
*/
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const;
std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const;
/**
* @brief Get the layer index
......
@@ -26,11 +26,11 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
}
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
const auto &[input_ids, position_ids, kv_cache] = input;
const auto &[input_ids, position_ids, cache_position] = input;
// 1. Forward through base model to get hidden states
auto position_ids_device = position_ids->to(device_);
auto hidden_states = model_->forward(input_ids, position_ids_device, kv_cache);
auto hidden_states = model_->forward(input_ids, position_ids_device, cache_position);
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
@@ -38,12 +38,8 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
return {logits};
}
void LlamaForCausalLM::reset_cache(size_t pos) {
model_->reset_cache(pos);
}
void LlamaForCausalLM::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
model_->reset_cache(new_config, pos);
void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
model_->reset_cache(cache_config);
}
} // namespace infinilm::models::llama
@@ -40,9 +40,7 @@ public:
*/
Output forward(const Input &input) const;
// Reset internal cache position
void reset_cache(size_t pos = 0) override;
void reset_cache(const cache::CacheConfig &new_config, size_t pos) override;
void reset_cache(const cache::CacheConfig *cache_config) override;
// Module information
const LlamaConfig &config() const { return model_->config(); }
......
@@ -10,9 +10,8 @@ namespace infinilm::models::llama {
LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: config_(config) {
: config_(config), rank_info_(rank_info) {
const auto &dtype{config.dtype};
// Initialize token embeddings
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
std::nullopt, dtype, device);
@@ -46,72 +45,46 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
void *kv_cache) const {
// Use persistent internal cache if no external cache is provided
// This matches Python backend behavior: if use_cache and past_key_values is None, create DynamicCache
// The cache persists across forward calls to enable incremental decoding
void *cache_to_use = kv_cache;
if (cache_to_use == nullptr) {
// Create or reuse persistent internal cache at model level
// This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...)
if (external_cache_ != nullptr) {
cache_to_use = external_cache_;
} else {
// Fall back to internal cache
if (!internal_cache_) {
internal_cache_ = std::make_unique<infinilm::cache::DynamicCache>(
config_.num_hidden_layers,
config_.max_position_embeddings);
}
cache_to_use = internal_cache_.get();
}
}
const infinicore::Tensor &cache_positions) const {
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
auto hidden_states = embed_tokens_->forward(input_ids);
// 2. Process through all decoder layers
size_t num_layers = layers_.size();
for (size_t i = 0; i < num_layers; ++i) {
// Pass model-level cache (layer index is now a property of the layer)
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, cache_to_use);
// DEBUG: Disabled previous final layer logging
// Logging moved to decoder layer for post-attention normalization
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_positions);
}
// 3. Apply final layer normalization to last token only (aligns with transformers)
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
auto shape = hidden_states->shape();
size_t seq_len = shape[1];
auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
// DEBUG: Disabled previous final layer normalization logging
// Normalize only the last token (matches Python backend)
auto normalized_last_token = norm_->forward(last_token);
return normalized_last_token;
}
void LlamaModel::reset_cache(size_t pos) const {
if (internal_cache_) {
internal_cache_->reset(pos);
void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
if (cache_config == nullptr) {
kv_cache_ = nullptr;
return;
}
if (external_cache_) {
external_cache_->reset(pos);
}
}
if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::StaticKVCache>(
config_.head_dim,
config_.head_dim,
config_.num_key_value_heads,
config_.num_key_value_heads,
config_.num_hidden_layers,
config_.max_position_embeddings,
config_.dtype,
*kv_cache_config,
rank_info_);
void LlamaModel::reset_cache(const cache::CacheConfig &new_config, size_t pos) const {
if (internal_cache_) {
internal_cache_->update_config(new_config);
internal_cache_->reset(pos);
}
if (external_cache_) {
external_cache_->update_config(new_config);
external_cache_->reset(pos);
} else {
throw std::runtime_error("Unsupported cache type");
}
}
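// --- Illustrative usage (not part of this commit) ---
// Only StaticKVCacheConfig is dispatched above; any other concrete config
// throws "Unsupported cache type":
// cache::StaticKVCacheConfig cfg(/*max_batch_size=*/1, /*max_cache_len=*/2048);
// model.reset_cache(&cfg);    // rebuilds the StaticKVCache from the model config
// model.reset_cache(nullptr); // drops the cache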
......
@@ -47,41 +47,19 @@ public:
*
* @param input_ids Token IDs tensor of shape [batch, seq_len]
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional model-level KV cache for incremental decoding
* @param cache_positions Cache positions tensor of shape [n_req]
* @return Output tensor of shape [batch, seq_len, hidden_size]
*/
infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const;
const infinicore::Tensor &cache_positions) const;
void reset_cache(const cache::CacheConfig *cache_config);
// Module information
const LlamaConfig &config() const { return config_; }
size_t num_layers() const { return config_.num_hidden_layers; }
/**
* @brief Reset the internal cache to a specific position
* This should be called when starting a new generation sequence to prevent state
* from persisting between different questions/prompts
* @param pos Position to reset to (defaults to 0)
*/
void reset_cache(size_t pos = 0) const;
/**
* @brief Reset the internal cache with a new configuration and position
* This should be called when changing cache parameters (e.g., initial capacity)
* @param new_config New cache configuration
* @param pos Position to reset to
*/
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) const;
/**
* @brief Set external cache for the model
* @param cache Pointer to external cache (managed by CacheManager)
*/
void set_external_cache(std::shared_ptr<cache::DynamicCache> cache) {
external_cache_ = cache.get();
}
protected:
// Token embeddings
INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens);
@@ -95,13 +73,12 @@ protected:
// Rotary Position Embeddings (shared across all layers)
INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb);
engine::distributed::RankInfo rank_info_;
std::shared_ptr<cache::Cache> kv_cache_;
private:
LlamaConfig config_;
// Persistent cache for when no external cache is provided
// Mutable because it's not part of the model's learned parameters,
// but needs to persist across forward calls for incremental decoding
mutable std::unique_ptr<infinilm::cache::DynamicCache> internal_cache_;
cache::DynamicCache *external_cache_ = nullptr;
};
} // namespace infinilm::models::llama
@@ -5,20 +5,21 @@ namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info,
std::shared_ptr<cache::DynamicCache> cache_ptr) {
const cache::CacheConfig *cache) {
std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
auto model = std::make_shared<models::llama::LlamaForCausalLM>(
model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, rank_info);
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
if (cache_ptr != nullptr) {
model->model().set_external_cache(cache_ptr);
if (cache) {
model->reset_cache(cache);
}
return model;
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
}
} // namespace infinilm