Unverified Commit 6498332e authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #124 from InfiniTensor/issue/121

Issue/121 - cache managements
parent 295aacd1
#pragma once
#include "cache_config.hpp"
#include "kv_cache.hpp"
#pragma once
#include <cstddef>
#include <string>
namespace infinilm::cache {
/**
* @enum CacheType
* @brief Enumeration of supported cache types
*/
enum class CacheType {
DYNAMIC, ///< Dynamic KV cache (grows as needed)
PAGED, ///< Paged KV cache (for paged attention)
};
enum class CacheResetMode {
PRESERVE, // Keep cache memory, only reset positions
RECREATE // Recreate cache with new configuration
};
struct CacheConfig {
CacheType type = CacheType::DYNAMIC;
size_t num_layers = 0;
size_t max_kv_cache_length = SIZE_MAX;
size_t initial_capacity = 1024; // Initial cache capacity in tokens
size_t initial_batch_size = 1; // Initial batch size for cache allocation
float growth_factor = 2.0f; // Cache growth factor when resizing
bool allow_expand = true; // Whether to allow cache expansion
CacheResetMode reset_mode = CacheResetMode::PRESERVE;
// Constructor
CacheConfig() = default;
CacheConfig(CacheType type, size_t num_layers = 32, size_t max_kv_cache_length = 4096)
: type(type), num_layers(num_layers), max_kv_cache_length(max_kv_cache_length) {}
bool operator==(const CacheConfig &other) const {
return type == other.type && num_layers == other.num_layers && max_kv_cache_length == other.max_kv_cache_length && initial_capacity == other.initial_capacity && initial_batch_size == other.initial_batch_size && growth_factor == other.growth_factor;
}
bool operator!=(const CacheConfig &other) const {
return !(*this == other);
}
};
} // namespace infinilm::cache
......@@ -4,7 +4,8 @@
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/context/context.hpp"
#include "cache_config.hpp"
#include <algorithm>
#include <memory>
#include <numeric>
......@@ -15,7 +16,6 @@
namespace infinilm::cache {
/**
* @brief Single layer's KV cache for incremental decoding
*
......@@ -29,22 +29,27 @@ struct KVCacheLayer {
infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim]
std::vector<size_t> cache_positions; // Current position in cache
size_t max_capacity; // Maximum capacity of cache
size_t initial_capacity; // Initial capacity from config
size_t initial_batch_size; // Initial batch size from config
float growth_factor; // Growth factor for dynamic resizing
bool initialized; // Whether cache has been initialized
KVCacheLayer() : max_capacity(0), initialized(false) {}
KVCacheLayer() : max_capacity(0), initial_capacity(4096), initial_batch_size(1),
growth_factor(2.0f), initialized(false) {}
/**
* @brief Initialize or update cache capacity
* @brief Initialize or update cache capacity with config parameters
* @param batch_size Current batch size
* @param num_kv_heads Number of key-value heads
* @param head_dim Head dimension
* @param seq_len Sequence length of new tokens
* @param dtype Data type
* @param device Device
* @param max_position_embeddings Maximum position embeddings (for initial capacity)
* @param cache_config Cache configuration parameters
*/
void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
infinicore::DataType dtype, const infinicore::Device &device,
size_t max_position_embeddings = 4096) {
const CacheConfig &cache_config) {
size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); });
// VALIDATION: Verify input parameters
......@@ -54,28 +59,59 @@ struct KVCacheLayer {
throw std::runtime_error("KV cache ensure_capacity: invalid parameters");
}
// Store config parameters on first initialization
if (!initialized) {
initial_capacity = cache_config.initial_capacity;
initial_batch_size = cache_config.initial_batch_size;
growth_factor = cache_config.growth_factor;
}
// Lazy initialization
if (!initialized) {
max_capacity = std::max(required_capacity, size_t(4096)); // Start with at least 4096
k_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
// Use max of required capacity and initial capacity from config
max_capacity = std::max(required_capacity, initial_capacity);
// Use max of current batch size and initial batch size from config
size_t alloc_batch_size = std::max(batch_size, initial_batch_size);
k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
v_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
cache_positions = std::vector<size_t>(batch_size, 0);
cache_positions = std::vector<size_t>(alloc_batch_size, 0);
initialized = true;
spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})",
alloc_batch_size, max_capacity, initial_batch_size, initial_capacity);
// VALIDATION: Verify cache was created correctly
// Shape is [batch_size, num_kv_heads, max_capacity, head_dim]
if (k_cache->shape()[0] != batch_size || k_cache->shape()[1] != num_kv_heads ||
k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization - expected: [{}, {}, {}, {}], got: {}",
batch_size, num_kv_heads, max_capacity, head_dim, k_cache->info());
if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization");
throw std::runtime_error("KV cache initialization: shape mismatch");
}
}
// Grow cache if needed (similar to DynamicLayer in Python)
// Grow cache if needed using growth factor from config
else if (required_capacity > max_capacity) {
size_t new_capacity = std::max(max_capacity * 2, required_capacity + max_capacity);
if (!cache_config.allow_expand) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config");
throw std::runtime_error("KV cache expansion not allowed");
}
// Calculate new capacity using growth factor
size_t new_capacity = static_cast<size_t>(
std::max(static_cast<float>(max_capacity) * growth_factor,
static_cast<float>(required_capacity + max_capacity)));
// Ensure we don't exceed max_position_embeddings if specified
if (cache_config.max_kv_cache_length != 0) {
new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length);
}
// Ensure we grow by at least some minimum amount
size_t min_growth = 256;
if (new_capacity - max_capacity < min_growth) {
new_capacity = max_capacity + min_growth;
}
size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
......@@ -83,11 +119,15 @@ struct KVCacheLayer {
if (new_batch_size > cache_positions.size()) {
cache_positions.resize(new_batch_size, 0);
}
auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})",
max_capacity, new_capacity, growth_factor);
// Copy existing cache data
for (size_t b = 0; b < new_batch_size; ++b) {
size_t cache_position = cache_positions[b];
......@@ -104,51 +144,41 @@ struct KVCacheLayer {
max_capacity = new_capacity;
// VALIDATION: Verify cache was grown correctly
// Shape is [batch_size, num_kv_heads, max_capacity, head_dim]
if (k_cache->shape()[2] != new_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch - expected: {}, got: {}",
new_capacity, k_cache->shape()[2]);
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch");
throw std::runtime_error("KV cache growth: capacity mismatch");
}
}
// VALIDATION: Final check that capacity is sufficient
if (required_capacity > max_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth - required: {}, max_capacity: {}",
required_capacity, max_capacity);
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth");
throw std::runtime_error("KV cache ensure_capacity: capacity insufficient");
}
}
KVCacheLayer(size_t max_batch_size, size_t n_kv_head, size_t head_dim, infinicore::DataType dtype, size_t max_seqlen = 4096, infinicore::Device device = infinicore::context::getDevice())
: max_capacity(max_seqlen), initialized(false) {
cache_positions = std::vector<size_t>(max_batch_size, 0);
ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
}
/**
* @brief Update cache with new key and value states
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @param cache_config Cache configuration for capacity management
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*
* Note: This method writes to the cache. If using with attention op, the attention op
* also writes to the cache, so this should be called AFTER attention, not before.
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) {
const infinicore::Tensor &v_new,
const CacheConfig &cache_config) {
if (k_new->ndim() != 4 || v_new->ndim() != 4) {
throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form.");
throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors");
}
size_t batch_size = k_new->shape()[0];
size_t num_kv_heads = k_new->shape()[1];
size_t seq_len = k_new->shape()[2];
size_t head_dim = k_new->shape()[3];
// Ensure capacity
// Ensure capacity with cache config
ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
k_new->dtype(), k_new->device());
k_new->dtype(), k_new->device(), cache_config);
// Copy new k/v into cache at current position
bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
......@@ -185,6 +215,17 @@ struct KVCacheLayer {
*/
class DynamicCache {
public:
/**
* @brief Construct DynamicCache with cache configuration
* @param cache_config Cache configuration parameters
*/
DynamicCache(const CacheConfig &cache_config)
: cache_config_(cache_config), layers_(cache_config.num_layers) {
if (cache_config.num_layers == -1) {
throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig");
}
}
/**
* @brief Construct DynamicCache with specified number of layers
*
......@@ -192,18 +233,10 @@ public:
* @param max_position_embeddings Maximum position embeddings (used for initial capacity)
*/
DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096)
: layers_(num_layers), max_position_embeddings_(max_position_embeddings) {}
: cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), layers_(num_layers) {}
/**
* @brief Update cache with new key and value states for a specific layer
*
* @param layer_idx Layer index (0-based)
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*
* This method updates the cache for the specified layer and returns the
* accumulated cache up to the current position.
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
size_t layer_idx,
......@@ -215,8 +248,8 @@ public:
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
// Update the cache for this layer
return layers_[layer_idx].update(k_new, v_new);
// Update the cache for this layer with cache config
return layers_[layer_idx].update(k_new, v_new, cache_config_);
}
/**
......@@ -235,6 +268,46 @@ public:
return update(0, k_new, v_new);
}
/**
* @brief Get cache configuration
*/
const CacheConfig &get_config() const { return cache_config_; }
/**
* @brief Update cache configuration (for dynamic reconfiguration)
*/
void update_config(const CacheConfig &new_config) {
// Check if we need to rebuild
bool need_rebuild = false;
// Rebuild if number of layers changed
if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) {
need_rebuild = true;
layers_.resize(new_config.num_layers);
}
// Rebuild if reset mode is RECREATE
if (new_config.reset_mode == CacheResetMode::RECREATE) {
need_rebuild = true;
}
// Update configuration
cache_config_ = new_config;
if (need_rebuild) {
// Clear all layers to force reinitialization on next use
for (auto &layer : layers_) {
layer.initialized = false;
layer.max_capacity = 0;
// Tensors will be recreated when ensure_capacity is called
}
spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use");
} else {
spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}",
new_config.num_layers, new_config.initial_capacity, new_config.growth_factor);
}
}
/**
* @brief Get the number of layers in this cache
*/
......@@ -256,7 +329,7 @@ public:
/**
* @brief Get max position embeddings (used for initial capacity)
*/
size_t max_position_embeddings() const { return max_position_embeddings_; }
size_t max_kv_cache_length() const { return cache_config_.max_kv_cache_length; }
/**
* @brief Reset cache for all layers to a specific position
......@@ -264,7 +337,7 @@ public:
* @param pos Position to reset to (defaults to 0)
*/
void reset(size_t pos = 0) {
for (auto& layer : layers_) {
for (auto &layer : layers_) {
std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos);
// Note: We don't reset initialized flag or clear the cache tensors
// to avoid reallocation. The cache will be overwritten on next update.
......@@ -274,14 +347,14 @@ public:
/**
* @brief Access a specific layer's cache (for advanced usage)
*/
KVCacheLayer& layer(size_t layer_idx) {
KVCacheLayer &layer(size_t layer_idx) {
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
return layers_[layer_idx];
}
const KVCacheLayer& layer(size_t layer_idx) const {
const KVCacheLayer &layer(size_t layer_idx) const {
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
......@@ -289,8 +362,8 @@ public:
}
private:
CacheConfig cache_config_;
std::vector<KVCacheLayer> layers_;
size_t max_position_embeddings_;
};
} // namespace infinilm::cache
#include "infer_engine.hpp"
#include "../models/llama/llama_config.hpp"
#include "spdlog/spdlog.h"
namespace infinilm::engine {
......@@ -9,15 +10,41 @@ namespace infinilm::engine {
InferEngine::InferEngine(
const std::any &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type)
infinicore::Device::Type device_type,
const cache::CacheConfig &cache_config) // Changed parameter
: communication_group_(distributed_config, device_type),
model_config_(config) {
model_config_(config),
cache_config_(cache_config) {
spdlog::info("Launch InferEngine with {}", std::string(distributed_config));
spdlog::info("Cache configuration: type={}, layers={}, max_kv_cache_length={}",
static_cast<int>(cache_config_.type),
cache_config_.num_layers,
cache_config_.max_kv_cache_length);
// Try to extract model configuration to override default cache parameters if needed
try {
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
cache_config_.num_layers = llama_config.num_hidden_layers;
cache_config_.max_kv_cache_length = llama_config.max_position_embeddings;
spdlog::info("Updated cache config from model: layers={}, max_kv_cache_length={}",
cache_config_.num_layers, cache_config_.max_kv_cache_length);
}
} catch (...) {
spdlog::warn("Could not extract model config, using provided CacheConfig");
}
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(model_config_, communication_group_.get_rank_info(r)));
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_));
}
}
......@@ -30,11 +57,11 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor &
worker->load_param(name, param);
}
}
//------------------------------------------------------
// state_dict
//------------------------------------------------------
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEngine::state_dict() {
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> results;
if (0 == workers_.size()) {
throw std::runtime_error(" Model object not found. ");
......@@ -80,10 +107,25 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
//------------------------------------------------------
// reset_cache
//------------------------------------------------------
void InferEngine::reset_cache(size_t pos, bool async) {
// Reset cache on all workers
void InferEngine::reset_cache(size_t pos) {
for (auto &worker : workers_) {
worker->reset_cache(pos);
}
for (auto &worker : workers_) {
worker->reset_cache(pos, async);
worker->wait();
}
}
//------------------------------------------------------
// reset_cache (overloaded with CacheConfig)
//------------------------------------------------------
void InferEngine::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
cache_config_ = new_config;
for (auto &worker : workers_) {
worker->reset_cache(new_config, pos);
}
for (auto &worker : workers_) {
worker->wait();
}
}
......
......@@ -11,10 +11,12 @@ namespace infinilm::engine {
class InferEngine {
public:
// Updated constructor: accept CacheConfig instead of CacheType
InferEngine(
const std::any &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType());
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig &cache_config = cache::CacheConfig());
// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
......@@ -26,19 +28,24 @@ public:
infinicore::Tensor generate(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids);
// Reset the internal cache in all workers (clears state between generations)
// By default, this is synchronous (blocks until reset completes).
// If async=true, this becomes asynchronous (unstable - use with caution).
void reset_cache(size_t pos = 0, bool async = false);
// Reset the internal cache pos in all workers (clears state between generations)
void reset_cache(size_t pos = 0);
// Overload: reset cache with new KV configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
~InferEngine();
const distributed::DistConfig &get_dist_config() const;
// Get current KV configuration
const cache::CacheConfig &get_cache_config() const { return cache_config_; }
protected:
std::vector<std::unique_ptr<RankWorker>> workers_;
distributed::CommunicationGroup communication_group_;
std::any model_config_;
cache::CacheConfig cache_config_;
};
} // namespace infinilm::engine
......@@ -9,14 +9,16 @@
namespace infinilm::engine {
RankWorker::RankWorker(const std::any &model_config,
const distributed::RankInfo &rank_info)
const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config)
: model_config_(model_config),
rank_info_(rank_info),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false) {
init_done_(false),
pending_cache_config_(cache_config) {
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);
......@@ -114,8 +116,7 @@ void RankWorker::wait() {
//------------------------------------------------------
// reset_cache -- synchronous by default, async optional (unstable)
//------------------------------------------------------
void RankWorker::reset_cache(size_t pos, bool async) {
{
void RankWorker::reset_cache(size_t pos) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache");
......@@ -125,19 +126,22 @@ void RankWorker::reset_cache(size_t pos, bool async) {
job_cmd_ = Command::RESET_CACHE;
has_job_ = true;
job_done_ = false;
}
cv_.notify_all();
}
// By default, wait for job completion (synchronous)
// If async=true, return immediately (unstable - use with caution)
if (!async) {
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; });
void RankWorker::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker stopped while resetting cache");
}
throw std::runtime_error("RankWorker is closing; cannot reset_cache");
}
// Store both the position and the new config
pending_reset_pos_ = pos;
pending_cache_config_ = new_config;
job_cmd_ = Command::RESET_CACHE_WITH_CONFIG;
has_job_ = true;
job_done_ = false;
cv_.notify_all();
}
//------------------------------------------------------
......@@ -173,8 +177,10 @@ void RankWorker::thread_loop() {
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore::context::setDevice(rank_info_.device);
cache_ptr_ = std::make_shared<cache::DynamicCache>(pending_cache_config_);
// Create model using factory (may be expensive)
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_);
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, cache_ptr_);
// Signal that initialization is done
{
......@@ -190,6 +196,7 @@ void RankWorker::thread_loop() {
infinicore::Tensor local_param;
std::vector<std::any> local_args;
size_t local_reset_pos = 0;
cache::CacheConfig local_reset_config;
// Wait for a job or exit
{
......@@ -209,6 +216,9 @@ void RankWorker::thread_loop() {
local_args = pending_args_;
} else if (local_cmd == Command::RESET_CACHE) {
local_reset_pos = pending_reset_pos_;
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
local_reset_pos = pending_reset_pos_;
local_reset_config = pending_cache_config_;
}
// mark job as being processed
......@@ -259,9 +269,15 @@ void RankWorker::thread_loop() {
}
} else if (local_cmd == Command::RESET_CACHE) {
try {
// Generic reset_cache on the model interface
// Option 1: Use model's reset_cache if it handles cache
model_->reset_cache(local_reset_pos);
// Option 2: Reset cache directly if we have access
// if (cache_ptr_ != nullptr) {
// auto* dynamic_cache = static_cast<cache::DynamicCache*>(cache_ptr_);
// dynamic_cache->reset(local_reset_pos);
// }
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
......@@ -276,6 +292,25 @@ void RankWorker::thread_loop() {
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
break;
}
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
try {
// Use model's reset_cache with new configuration
model_->reset_cache(local_reset_config, local_reset_pos);
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache with config: {}\n", info(), e.what());
break;
}
} else {
// Shouldn't reach here (no-op)
}
......
#pragma once
#include "../cache/cache.hpp"
#include "../models/model_factory.hpp"
#include "distributed/distributed.hpp"
......@@ -18,12 +19,14 @@ class RankWorker {
LOAD,
RUN,
RESET_CACHE,
RESET_CACHE_WITH_CONFIG,
STOP
};
public:
RankWorker(const std::any &model_config,
const distributed::RankInfo &rank_info);
const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config);
// Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name,
......@@ -36,9 +39,10 @@ public:
void run(const std::vector<std::any> &args);
// Reset the internal cache in the model (clears state between generations)
// By default, this is synchronous (blocks until reset completes).
// If async=true, this becomes asynchronous (unstable - use with caution).
void reset_cache(size_t pos = 0, bool async = false);
void reset_cache(size_t pos = 0);
// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
// Wait until run job completes. The result can be retrieved with get_output().
void wait();
......@@ -59,6 +63,7 @@ private:
std::any model_config_;
distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::DynamicCache> cache_ptr_;
// Command for the pending job (protected by mutex_)
Command job_cmd_;
......@@ -74,6 +79,7 @@ private:
infinicore::Tensor pending_param_;
std::vector<std::any> pending_args_;
size_t pending_reset_pos_ = 0;
cache::CacheConfig pending_cache_config_;
// Output (protected by mutex)
infinicore::Tensor output_;
......
......@@ -2,6 +2,8 @@
#include "infinicore/nn/module.hpp"
#include "../cache/cache.hpp"
#include <any>
namespace infinilm {
......@@ -11,5 +13,6 @@ public:
virtual infinicore::Tensor forward(std::vector<std::any>) const = 0;
// Optional: reset cache; default no-op for models without cache
virtual void reset_cache(size_t pos = 0) {}
virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0;
};
} // namespace infinilm
......@@ -66,4 +66,8 @@ void LlamaForCausalLM::reset_cache(size_t pos) {
model_->reset_cache(pos);
}
void LlamaForCausalLM::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
model_->reset_cache(new_config, pos);
}
} // namespace infinilm::models::llama
......@@ -50,6 +50,7 @@ public:
// Reset internal cache position
void reset_cache(size_t pos = 0) override;
void reset_cache(const cache::CacheConfig &new_config, size_t pos) override;
// Module information
const LlamaConfig &config() const { return model_->config(); }
......
......@@ -3,6 +3,7 @@
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include <iostream>
namespace infinilm::models::llama {
......@@ -50,18 +51,20 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// The cache persists across forward calls to enable incremental decoding
void *cache_to_use = kv_cache;
if (kv_cache == nullptr) {
if (cache_to_use == nullptr) {
// Create or reuse persistent internal cache at model level
// This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...)
size_t seq_len = input_ids->shape()[1];
if (!cache_) {
// First time: create cache
cache_ = std::make_unique<infinilm::cache::DynamicCache>(
if (external_cache_ != nullptr) {
cache_to_use = external_cache_;
} else {
// Fall back to internal cache
if (!internal_cache_) {
internal_cache_ = std::make_unique<infinilm::cache::DynamicCache>(
config_.num_hidden_layers,
config_.max_position_embeddings);
}
cache_to_use = cache_.get();
cache_to_use = internal_cache_.get();
}
}
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
......@@ -92,8 +95,22 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
}
void LlamaModel::reset_cache(size_t pos) const {
if (cache_) {
cache_->reset(pos);
if (internal_cache_) {
internal_cache_->reset(pos);
}
if (external_cache_) {
external_cache_->reset(pos);
}
}
void LlamaModel::reset_cache(const cache::CacheConfig &new_config, size_t pos) const {
if (internal_cache_) {
internal_cache_->update_config(new_config);
internal_cache_->reset(pos);
}
if (external_cache_) {
external_cache_->update_config(new_config);
external_cache_->reset(pos);
}
}
......
#pragma once
#include "../../cache/kv_cache.hpp"
#include "llama_config.hpp"
#include "llama_decoder_layer.hpp"
#include "../../cache/kv_cache.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
......@@ -68,6 +67,22 @@ public:
*/
void reset_cache(size_t pos = 0) const;
/**
* @brief Reset the internal cache with a new configuration and position
* This should be called when changing cache parameters (e.g., initial capacity)
* @param new_config New cache configuration
* @param pos Position to reset to
*/
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) const;
/**
* @brief Set external cache for the model
* @param cache Pointer to external cache (managed by CacheManager)
*/
void set_external_cache(std::shared_ptr<cache::DynamicCache> cache) {
external_cache_ = cache.get();
}
protected:
// Token embeddings
INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens);
......@@ -86,7 +101,8 @@ private:
// Persistent cache for when no external cache is provided
// Mutable because it's not part of the model's learned parameters,
// but needs to persist across forward calls for incremental decoding
mutable std::unique_ptr<infinilm::cache::DynamicCache> cache_;
mutable std::unique_ptr<infinilm::cache::DynamicCache> internal_cache_;
cache::DynamicCache *external_cache_ = nullptr;
};
} // namespace infinilm::models::llama
......@@ -2,11 +2,21 @@
#include "llama/llama.hpp"
namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(const std::any &config, engine::distributed::RankInfo rank_info) {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const std::any &config,
engine::distributed::RankInfo rank_info,
std::shared_ptr<cache::DynamicCache> cache_ptr) {
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
auto model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
if (cache_ptr != nullptr) {
model->model().set_external_cache(cache_ptr);
}
return model;
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
......
......@@ -7,6 +7,6 @@
namespace infinilm {
class InfinilmModelFactory {
public:
static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
};
} // namespace infinilm
#include <pybind11/pybind11.h>
#include "models/llama.hpp"
#include "engine.hpp"
namespace py = pybind11;
......@@ -9,6 +8,8 @@ namespace py = pybind11;
PYBIND11_MODULE(_infinilm, m) {
m.doc() = "InfiniLM Llama model Python bindings";
infinilm::cache::bind_cache_config(m);
infinilm::models::llama::bind_llama(m);
infinilm::engine::distributed::bind_dist_config(m);
infinilm::engine::bind_infer_engine(m);
......
#include "../cache/cache_config.hpp"
#include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
......@@ -5,6 +6,60 @@
namespace py = pybind11;
namespace infinilm::cache {
inline void bind_cache_config(py::module &m) {
// First bind the CacheType enum
py::enum_<CacheType>(m, "CacheType")
.value("DYNAMIC", CacheType::DYNAMIC)
.value("PAGED", CacheType::PAGED)
.export_values();
// Then bind the CacheResetMode enum
py::enum_<CacheResetMode>(m, "CacheResetMode")
.value("PRESERVE", CacheResetMode::PRESERVE)
.value("RECREATE", CacheResetMode::RECREATE)
.export_values();
// Finally bind the CacheConfig struct
py::class_<CacheConfig>(m, "CacheConfig")
.def(py::init<>(), "Default constructor")
.def(py::init<CacheType, size_t, size_t>(),
py::arg("type") = CacheType::DYNAMIC,
py::arg("num_layers") = 32,
py::arg("max_kv_cache_length") = 4096,
"Constructor with parameters")
.def_readwrite("type", &CacheConfig::type, "Cache type")
.def_readwrite("num_layers", &CacheConfig::num_layers, "Number of layers")
.def_readwrite("max_kv_cache_length", &CacheConfig::max_kv_cache_length,
"Maximum KV cache length")
.def_readwrite("initial_capacity", &CacheConfig::initial_capacity,
"Initial cache capacity in tokens")
.def_readwrite("initial_batch_size", &CacheConfig::initial_batch_size,
"Initial batch size for cache allocation")
.def_readwrite("growth_factor", &CacheConfig::growth_factor,
"Cache growth factor when resizing (e.g., 2.0 for doubling)")
.def_readwrite("allow_expand", &CacheConfig::allow_expand,
"Whether to allow cache expansion")
.def_readwrite("reset_mode", &CacheConfig::reset_mode,
"Cache reset mode")
.def("__eq__", &CacheConfig::operator==, py::is_operator(),
"Check if two CacheConfig objects are equal")
.def("__ne__", &CacheConfig::operator!=, py::is_operator(),
"Check if two CacheConfig objects are not equal")
.def("__repr__", [](const CacheConfig &cfg) {
return fmt::format("CacheConfig(type={}, num_layers={}, max_kv_cache_length={}, "
"initial_capacity={}, initial_batch_size={}, growth_factor={}, "
"allow_expand={}, reset_mode={})",
static_cast<int>(cfg.type), cfg.num_layers,
cfg.max_kv_cache_length, cfg.initial_capacity,
cfg.initial_batch_size, cfg.growth_factor,
cfg.allow_expand, static_cast<int>(cfg.reset_mode));
});
}
} // namespace infinilm::cache
namespace infinilm::engine::distributed {
inline void bind_dist_config(py::module &m) {
......@@ -29,21 +84,21 @@ inline void bind_dist_config(py::module &m) {
namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
.def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev) {
return new InferEngine(std::any(cfg), dist, dev);
infinicore::Device::Type dev,
const infinilm::cache::CacheConfig &cache_config) {
return new InferEngine(std::any(cfg), dist, dev, cache_config);
}),
py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType())
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = infinilm::cache::CacheConfig())
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
// Return a dictionary containing references to the whole state of the module.
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
py::dict result;
......@@ -52,15 +107,17 @@ inline void bind_infer_engine(py::module &m) {
}
state_dict_tp_all.append(result);
}
return state_dict_tp_all;
})
.def("generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor { return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>()); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", &InferEngine::reset_cache, py::arg("pos") = 0, py::arg("async") = false, "Reset the internal cache in all workers to a specific position (clears state between generations). "
"By default, this is synchronous. If async=True, this becomes asynchronous (unstable - use with caution).");
// Optionally, you can add __repr__ for debugging
m.attr("InferEngine").attr("__repr__") = py::cpp_function([](const InferEngine &self) {
.def(
"generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>());
},
"Run inference on all ranks with arbitrary arguments")
.def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position")
.def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration")
.def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration")
.def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment