Unverified Commit 6498332e authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #124 from InfiniTensor/issue/121

Issue/121 - cache management
parent 295aacd1
#pragma once
#include "cache_config.hpp"
#include "kv_cache.hpp"
#pragma once
#include <cstddef>
#include <cstdint> // SIZE_MAX (used as the default for max_kv_cache_length)
#include <string>
namespace infinilm::cache {

/**
 * @enum CacheType
 * @brief Enumeration of supported cache types
 */
enum class CacheType {
    DYNAMIC, ///< Dynamic KV cache (grows as needed)
    PAGED,   ///< Paged KV cache (for paged attention)
};

/**
 * @enum CacheResetMode
 * @brief Controls how a cache reset is performed.
 */
enum class CacheResetMode {
    PRESERVE, ///< Keep cache memory, only reset positions
    RECREATE  ///< Recreate cache with new configuration
};

/**
 * @struct CacheConfig
 * @brief Configuration parameters for KV-cache creation, growth, and reset.
 */
struct CacheConfig {
    CacheType type = CacheType::DYNAMIC;   ///< Cache implementation to use
    size_t num_layers = 0;                 ///< Number of model layers (0 = unset)
    size_t max_kv_cache_length = SIZE_MAX; ///< Upper bound on KV cache length, in tokens
    size_t initial_capacity = 1024;        ///< Initial cache capacity in tokens
    size_t initial_batch_size = 1;         ///< Initial batch size for cache allocation
    float growth_factor = 2.0f;            ///< Cache growth factor when resizing
    bool allow_expand = true;              ///< Whether to allow cache expansion
    CacheResetMode reset_mode = CacheResetMode::PRESERVE; ///< Reset strategy

    CacheConfig() = default;

    /**
     * @brief Convenience constructor for the most commonly set fields.
     * @param type Cache implementation to use
     * @param num_layers Number of model layers
     * @param max_kv_cache_length Upper bound on KV cache length, in tokens
     */
    CacheConfig(CacheType type, size_t num_layers = 32, size_t max_kv_cache_length = 4096)
        : type(type), num_layers(num_layers), max_kv_cache_length(max_kv_cache_length) {}

    /**
     * @brief Field-wise equality over ALL configuration fields.
     *
     * Fix: the previous implementation omitted `allow_expand` and
     * `reset_mode`, so two configs that differ only in those
     * behavior-affecting fields compared equal (and `operator!=`
     * inherited the bug). Every field is now compared.
     */
    bool operator==(const CacheConfig &other) const {
        return type == other.type
            && num_layers == other.num_layers
            && max_kv_cache_length == other.max_kv_cache_length
            && initial_capacity == other.initial_capacity
            && initial_batch_size == other.initial_batch_size
            && growth_factor == other.growth_factor
            && allow_expand == other.allow_expand
            && reset_mode == other.reset_mode;
    }

    /// @brief Negation of field-wise equality.
    bool operator!=(const CacheConfig &other) const {
        return !(*this == other);
    }
};

} // namespace infinilm::cache
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#include "infinicore/device.hpp" #include "infinicore/device.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include "infinicore/context/context.hpp" #include "cache_config.hpp"
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
...@@ -15,7 +16,6 @@ ...@@ -15,7 +16,6 @@
namespace infinilm::cache { namespace infinilm::cache {
/** /**
* @brief Single layer's KV cache for incremental decoding * @brief Single layer's KV cache for incremental decoding
* *
...@@ -29,22 +29,27 @@ struct KVCacheLayer { ...@@ -29,22 +29,27 @@ struct KVCacheLayer {
infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim] infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim]
std::vector<size_t> cache_positions; // Current position in cache std::vector<size_t> cache_positions; // Current position in cache
size_t max_capacity; // Maximum capacity of cache size_t max_capacity; // Maximum capacity of cache
size_t initial_capacity; // Initial capacity from config
size_t initial_batch_size; // Initial batch size from config
float growth_factor; // Growth factor for dynamic resizing
bool initialized; // Whether cache has been initialized bool initialized; // Whether cache has been initialized
KVCacheLayer() : max_capacity(0), initialized(false) {} KVCacheLayer() : max_capacity(0), initial_capacity(4096), initial_batch_size(1),
growth_factor(2.0f), initialized(false) {}
/** /**
* @brief Initialize or update cache capacity * @brief Initialize or update cache capacity with config parameters
* @param batch_size Current batch size
* @param num_kv_heads Number of key-value heads * @param num_kv_heads Number of key-value heads
* @param head_dim Head dimension * @param head_dim Head dimension
* @param seq_len Sequence length of new tokens * @param seq_len Sequence length of new tokens
* @param dtype Data type * @param dtype Data type
* @param device Device * @param device Device
* @param max_position_embeddings Maximum position embeddings (for initial capacity) * @param cache_config Cache configuration parameters
*/ */
void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len, void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
infinicore::DataType dtype, const infinicore::Device &device, infinicore::DataType dtype, const infinicore::Device &device,
size_t max_position_embeddings = 4096) { const CacheConfig &cache_config) {
size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); }); size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); });
// VALIDATION: Verify input parameters // VALIDATION: Verify input parameters
...@@ -54,28 +59,59 @@ struct KVCacheLayer { ...@@ -54,28 +59,59 @@ struct KVCacheLayer {
throw std::runtime_error("KV cache ensure_capacity: invalid parameters"); throw std::runtime_error("KV cache ensure_capacity: invalid parameters");
} }
// Store config parameters on first initialization
if (!initialized) {
initial_capacity = cache_config.initial_capacity;
initial_batch_size = cache_config.initial_batch_size;
growth_factor = cache_config.growth_factor;
}
// Lazy initialization // Lazy initialization
if (!initialized) { if (!initialized) {
max_capacity = std::max(required_capacity, size_t(4096)); // Start with at least 4096 // Use max of required capacity and initial capacity from config
k_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim}, max_capacity = std::max(required_capacity, initial_capacity);
// Use max of current batch size and initial batch size from config
size_t alloc_batch_size = std::max(batch_size, initial_batch_size);
k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device); dtype, device);
v_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim}, v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device); dtype, device);
cache_positions = std::vector<size_t>(batch_size, 0); cache_positions = std::vector<size_t>(alloc_batch_size, 0);
initialized = true; initialized = true;
spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})",
alloc_batch_size, max_capacity, initial_batch_size, initial_capacity);
// VALIDATION: Verify cache was created correctly // VALIDATION: Verify cache was created correctly
// Shape is [batch_size, num_kv_heads, max_capacity, head_dim] if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
if (k_cache->shape()[0] != batch_size || k_cache->shape()[1] != num_kv_heads || SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization");
k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization - expected: [{}, {}, {}, {}], got: {}",
batch_size, num_kv_heads, max_capacity, head_dim, k_cache->info());
throw std::runtime_error("KV cache initialization: shape mismatch"); throw std::runtime_error("KV cache initialization: shape mismatch");
} }
} }
// Grow cache if needed (similar to DynamicLayer in Python) // Grow cache if needed using growth factor from config
else if (required_capacity > max_capacity) { else if (required_capacity > max_capacity) {
size_t new_capacity = std::max(max_capacity * 2, required_capacity + max_capacity); if (!cache_config.allow_expand) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config");
throw std::runtime_error("KV cache expansion not allowed");
}
// Calculate new capacity using growth factor
size_t new_capacity = static_cast<size_t>(
std::max(static_cast<float>(max_capacity) * growth_factor,
static_cast<float>(required_capacity + max_capacity)));
// Ensure we don't exceed max_position_embeddings if specified
if (cache_config.max_kv_cache_length != 0) {
new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length);
}
// Ensure we grow by at least some minimum amount
size_t min_growth = 256;
if (new_capacity - max_capacity < min_growth) {
new_capacity = max_capacity + min_growth;
}
size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]); size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) { if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache."); throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
...@@ -83,11 +119,15 @@ struct KVCacheLayer { ...@@ -83,11 +119,15 @@ struct KVCacheLayer {
if (new_batch_size > cache_positions.size()) { if (new_batch_size > cache_positions.size()) {
cache_positions.resize(new_batch_size, 0); cache_positions.resize(new_batch_size, 0);
} }
auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim}, auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device); dtype, device);
auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim}, auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device); dtype, device);
spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})",
max_capacity, new_capacity, growth_factor);
// Copy existing cache data // Copy existing cache data
for (size_t b = 0; b < new_batch_size; ++b) { for (size_t b = 0; b < new_batch_size; ++b) {
size_t cache_position = cache_positions[b]; size_t cache_position = cache_positions[b];
...@@ -104,51 +144,41 @@ struct KVCacheLayer { ...@@ -104,51 +144,41 @@ struct KVCacheLayer {
max_capacity = new_capacity; max_capacity = new_capacity;
// VALIDATION: Verify cache was grown correctly // VALIDATION: Verify cache was grown correctly
// Shape is [batch_size, num_kv_heads, max_capacity, head_dim]
if (k_cache->shape()[2] != new_capacity) { if (k_cache->shape()[2] != new_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch - expected: {}, got: {}", SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch");
new_capacity, k_cache->shape()[2]);
throw std::runtime_error("KV cache growth: capacity mismatch"); throw std::runtime_error("KV cache growth: capacity mismatch");
} }
} }
// VALIDATION: Final check that capacity is sufficient // VALIDATION: Final check that capacity is sufficient
if (required_capacity > max_capacity) { if (required_capacity > max_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth - required: {}, max_capacity: {}", SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth");
required_capacity, max_capacity);
throw std::runtime_error("KV cache ensure_capacity: capacity insufficient"); throw std::runtime_error("KV cache ensure_capacity: capacity insufficient");
} }
} }
KVCacheLayer(size_t max_batch_size, size_t n_kv_head, size_t head_dim, infinicore::DataType dtype, size_t max_seqlen = 4096, infinicore::Device device = infinicore::context::getDevice())
: max_capacity(max_seqlen), initialized(false) {
cache_positions = std::vector<size_t>(max_batch_size, 0);
ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
}
/** /**
* @brief Update cache with new key and value states * @brief Update cache with new key and value states
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @param cache_config Cache configuration for capacity management
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*
* Note: This method writes to the cache. If using with attention op, the attention op
* also writes to the cache, so this should be called AFTER attention, not before.
*/ */
std::pair<infinicore::Tensor, infinicore::Tensor> update( std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new, const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) { const infinicore::Tensor &v_new,
const CacheConfig &cache_config) {
if (k_new->ndim() != 4 || v_new->ndim() != 4) { if (k_new->ndim() != 4 || v_new->ndim() != 4) {
throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form."); throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors");
} }
size_t batch_size = k_new->shape()[0]; size_t batch_size = k_new->shape()[0];
size_t num_kv_heads = k_new->shape()[1]; size_t num_kv_heads = k_new->shape()[1];
size_t seq_len = k_new->shape()[2]; size_t seq_len = k_new->shape()[2];
size_t head_dim = k_new->shape()[3]; size_t head_dim = k_new->shape()[3];
// Ensure capacity // Ensure capacity with cache config
ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len, ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
k_new->dtype(), k_new->device()); k_new->dtype(), k_new->device(), cache_config);
// Copy new k/v into cache at current position // Copy new k/v into cache at current position
bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin()); bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
...@@ -185,6 +215,17 @@ struct KVCacheLayer { ...@@ -185,6 +215,17 @@ struct KVCacheLayer {
*/ */
class DynamicCache { class DynamicCache {
public: public:
/**
* @brief Construct DynamicCache with cache configuration
* @param cache_config Cache configuration parameters
*/
DynamicCache(const CacheConfig &cache_config)
: cache_config_(cache_config), layers_(cache_config.num_layers) {
if (cache_config.num_layers == -1) {
throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig");
}
}
/** /**
* @brief Construct DynamicCache with specified number of layers * @brief Construct DynamicCache with specified number of layers
* *
...@@ -192,18 +233,10 @@ public: ...@@ -192,18 +233,10 @@ public:
* @param max_position_embeddings Maximum position embeddings (used for initial capacity) * @param max_position_embeddings Maximum position embeddings (used for initial capacity)
*/ */
DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096) DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096)
: layers_(num_layers), max_position_embeddings_(max_position_embeddings) {} : cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), layers_(num_layers) {}
/** /**
* @brief Update cache with new key and value states for a specific layer * @brief Update cache with new key and value states for a specific layer
*
* @param layer_idx Layer index (0-based)
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*
* This method updates the cache for the specified layer and returns the
* accumulated cache up to the current position.
*/ */
std::pair<infinicore::Tensor, infinicore::Tensor> update( std::pair<infinicore::Tensor, infinicore::Tensor> update(
size_t layer_idx, size_t layer_idx,
...@@ -215,8 +248,8 @@ public: ...@@ -215,8 +248,8 @@ public:
throw std::runtime_error("DynamicCache: layer_idx out of range"); throw std::runtime_error("DynamicCache: layer_idx out of range");
} }
// Update the cache for this layer // Update the cache for this layer with cache config
return layers_[layer_idx].update(k_new, v_new); return layers_[layer_idx].update(k_new, v_new, cache_config_);
} }
/** /**
...@@ -235,6 +268,46 @@ public: ...@@ -235,6 +268,46 @@ public:
return update(0, k_new, v_new); return update(0, k_new, v_new);
} }
/**
* @brief Get cache configuration
*/
const CacheConfig &get_config() const { return cache_config_; }
/**
* @brief Update cache configuration (for dynamic reconfiguration)
*/
void update_config(const CacheConfig &new_config) {
// Check if we need to rebuild
bool need_rebuild = false;
// Rebuild if number of layers changed
if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) {
need_rebuild = true;
layers_.resize(new_config.num_layers);
}
// Rebuild if reset mode is RECREATE
if (new_config.reset_mode == CacheResetMode::RECREATE) {
need_rebuild = true;
}
// Update configuration
cache_config_ = new_config;
if (need_rebuild) {
// Clear all layers to force reinitialization on next use
for (auto &layer : layers_) {
layer.initialized = false;
layer.max_capacity = 0;
// Tensors will be recreated when ensure_capacity is called
}
spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use");
} else {
spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}",
new_config.num_layers, new_config.initial_capacity, new_config.growth_factor);
}
}
/** /**
* @brief Get the number of layers in this cache * @brief Get the number of layers in this cache
*/ */
...@@ -256,7 +329,7 @@ public: ...@@ -256,7 +329,7 @@ public:
/** /**
* @brief Get max position embeddings (used for initial capacity) * @brief Get max position embeddings (used for initial capacity)
*/ */
size_t max_position_embeddings() const { return max_position_embeddings_; } size_t max_kv_cache_length() const { return cache_config_.max_kv_cache_length; }
/** /**
* @brief Reset cache for all layers to a specific position * @brief Reset cache for all layers to a specific position
...@@ -264,7 +337,7 @@ public: ...@@ -264,7 +337,7 @@ public:
* @param pos Position to reset to (defaults to 0) * @param pos Position to reset to (defaults to 0)
*/ */
void reset(size_t pos = 0) { void reset(size_t pos = 0) {
for (auto& layer : layers_) { for (auto &layer : layers_) {
std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos); std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos);
// Note: We don't reset initialized flag or clear the cache tensors // Note: We don't reset initialized flag or clear the cache tensors
// to avoid reallocation. The cache will be overwritten on next update. // to avoid reallocation. The cache will be overwritten on next update.
...@@ -274,14 +347,14 @@ public: ...@@ -274,14 +347,14 @@ public:
/** /**
* @brief Access a specific layer's cache (for advanced usage) * @brief Access a specific layer's cache (for advanced usage)
*/ */
KVCacheLayer& layer(size_t layer_idx) { KVCacheLayer &layer(size_t layer_idx) {
if (layer_idx >= layers_.size()) { if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range"); throw std::runtime_error("DynamicCache: layer_idx out of range");
} }
return layers_[layer_idx]; return layers_[layer_idx];
} }
const KVCacheLayer& layer(size_t layer_idx) const { const KVCacheLayer &layer(size_t layer_idx) const {
if (layer_idx >= layers_.size()) { if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range"); throw std::runtime_error("DynamicCache: layer_idx out of range");
} }
...@@ -289,8 +362,8 @@ public: ...@@ -289,8 +362,8 @@ public:
} }
private: private:
CacheConfig cache_config_;
std::vector<KVCacheLayer> layers_; std::vector<KVCacheLayer> layers_;
size_t max_position_embeddings_;
}; };
} // namespace infinilm::cache } // namespace infinilm::cache
#include "infer_engine.hpp" #include "infer_engine.hpp"
#include "../models/llama/llama_config.hpp"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
namespace infinilm::engine { namespace infinilm::engine {
...@@ -9,15 +10,41 @@ namespace infinilm::engine { ...@@ -9,15 +10,41 @@ namespace infinilm::engine {
InferEngine::InferEngine( InferEngine::InferEngine(
const std::any &config, const std::any &config,
const distributed::DistConfig &distributed_config, const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type) infinicore::Device::Type device_type,
const cache::CacheConfig &cache_config) // Changed parameter
: communication_group_(distributed_config, device_type), : communication_group_(distributed_config, device_type),
model_config_(config) { model_config_(config),
cache_config_(cache_config) {
spdlog::info("Launch InferEngine with {}", std::string(distributed_config)); spdlog::info("Launch InferEngine with {}", std::string(distributed_config));
spdlog::info("Cache configuration: type={}, layers={}, max_kv_cache_length={}",
static_cast<int>(cache_config_.type),
cache_config_.num_layers,
cache_config_.max_kv_cache_length);
// Try to extract model configuration to override default cache parameters if needed
try {
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
cache_config_.num_layers = llama_config.num_hidden_layers;
cache_config_.max_kv_cache_length = llama_config.max_position_embeddings;
spdlog::info("Updated cache config from model: layers={}, max_kv_cache_length={}",
cache_config_.num_layers, cache_config_.max_kv_cache_length);
}
} catch (...) {
spdlog::warn("Could not extract model config, using provided CacheConfig");
}
// Create one RankWorker per rank // Create one RankWorker per rank
int world_size = communication_group_.get_world_size(); int world_size = communication_group_.get_world_size();
workers_.reserve(world_size); workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) { for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(model_config_, communication_group_.get_rank_info(r))); workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_));
} }
} }
...@@ -30,11 +57,11 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor & ...@@ -30,11 +57,11 @@ void InferEngine::load_param(const std::string &name, const infinicore::Tensor &
worker->load_param(name, param); worker->load_param(name, param);
} }
} }
//------------------------------------------------------ //------------------------------------------------------
// state_dict // state_dict
//------------------------------------------------------ //------------------------------------------------------
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEngine::state_dict() { std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEngine::state_dict() {
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> results; std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> results;
if (0 == workers_.size()) { if (0 == workers_.size()) {
throw std::runtime_error(" Model object not found. "); throw std::runtime_error(" Model object not found. ");
...@@ -80,10 +107,25 @@ const distributed::DistConfig &InferEngine::get_dist_config() const { ...@@ -80,10 +107,25 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
//------------------------------------------------------ //------------------------------------------------------
// reset_cache // reset_cache
//------------------------------------------------------ //------------------------------------------------------
void InferEngine::reset_cache(size_t pos, bool async) { void InferEngine::reset_cache(size_t pos) {
// Reset cache on all workers for (auto &worker : workers_) {
worker->reset_cache(pos);
}
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->reset_cache(pos, async); worker->wait();
}
}
//------------------------------------------------------
// reset_cache (overloaded with CacheConfig)
//------------------------------------------------------
void InferEngine::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
cache_config_ = new_config;
for (auto &worker : workers_) {
worker->reset_cache(new_config, pos);
}
for (auto &worker : workers_) {
worker->wait();
} }
} }
......
...@@ -11,10 +11,12 @@ namespace infinilm::engine { ...@@ -11,10 +11,12 @@ namespace infinilm::engine {
class InferEngine { class InferEngine {
public: public:
// Updated constructor: accept CacheConfig instead of CacheType
InferEngine( InferEngine(
const std::any &config, const std::any &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(), const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType()); infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig &cache_config = cache::CacheConfig());
// Load a parameter to all workers (each can extract its shard inside RankWorker) // Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param); void load_param(const std::string &name, const infinicore::Tensor &param);
...@@ -26,19 +28,24 @@ public: ...@@ -26,19 +28,24 @@ public:
infinicore::Tensor generate(const infinicore::Tensor &input_ids, infinicore::Tensor generate(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids); const infinicore::Tensor &position_ids);
// Reset the internal cache in all workers (clears state between generations) // Reset the internal cache pos in all workers (clears state between generations)
// By default, this is synchronous (blocks until reset completes). void reset_cache(size_t pos = 0);
// If async=true, this becomes asynchronous (unstable - use with caution).
void reset_cache(size_t pos = 0, bool async = false); // Overload: reset cache with new KV configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
~InferEngine(); ~InferEngine();
const distributed::DistConfig &get_dist_config() const; const distributed::DistConfig &get_dist_config() const;
// Get current KV configuration
const cache::CacheConfig &get_cache_config() const { return cache_config_; }
protected: protected:
std::vector<std::unique_ptr<RankWorker>> workers_; std::vector<std::unique_ptr<RankWorker>> workers_;
distributed::CommunicationGroup communication_group_; distributed::CommunicationGroup communication_group_;
std::any model_config_; std::any model_config_;
cache::CacheConfig cache_config_;
}; };
} // namespace infinilm::engine } // namespace infinilm::engine
...@@ -9,14 +9,16 @@ ...@@ -9,14 +9,16 @@
namespace infinilm::engine { namespace infinilm::engine {
RankWorker::RankWorker(const std::any &model_config, RankWorker::RankWorker(const std::any &model_config,
const distributed::RankInfo &rank_info) const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config)
: model_config_(model_config), : model_config_(model_config),
rank_info_(rank_info), rank_info_(rank_info),
job_cmd_(Command::INIT), job_cmd_(Command::INIT),
has_job_(false), has_job_(false),
job_done_(false), job_done_(false),
should_exit_(false), should_exit_(false),
init_done_(false) { init_done_(false),
pending_cache_config_(cache_config) {
// start the thread // start the thread
thread_ = std::thread(&RankWorker::thread_loop, this); thread_ = std::thread(&RankWorker::thread_loop, this);
...@@ -114,8 +116,7 @@ void RankWorker::wait() { ...@@ -114,8 +116,7 @@ void RankWorker::wait() {
//------------------------------------------------------ //------------------------------------------------------
// reset_cache -- synchronous by default, async optional (unstable) // reset_cache -- synchronous by default, async optional (unstable)
//------------------------------------------------------ //------------------------------------------------------
void RankWorker::reset_cache(size_t pos, bool async) { void RankWorker::reset_cache(size_t pos) {
{
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) { if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache"); throw std::runtime_error("RankWorker is closing; cannot reset_cache");
...@@ -125,19 +126,22 @@ void RankWorker::reset_cache(size_t pos, bool async) { ...@@ -125,19 +126,22 @@ void RankWorker::reset_cache(size_t pos, bool async) {
job_cmd_ = Command::RESET_CACHE; job_cmd_ = Command::RESET_CACHE;
has_job_ = true; has_job_ = true;
job_done_ = false; job_done_ = false;
}
cv_.notify_all(); cv_.notify_all();
}
// By default, wait for job completion (synchronous) void RankWorker::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
// If async=true, return immediately (unstable - use with caution) std::lock_guard<std::mutex> lock(mutex_);
if (!async) {
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; });
if (should_exit_) { if (should_exit_) {
throw std::runtime_error("RankWorker stopped while resetting cache"); throw std::runtime_error("RankWorker is closing; cannot reset_cache");
}
} }
// Store both the position and the new config
pending_reset_pos_ = pos;
pending_cache_config_ = new_config;
job_cmd_ = Command::RESET_CACHE_WITH_CONFIG;
has_job_ = true;
job_done_ = false;
cv_.notify_all();
} }
//------------------------------------------------------ //------------------------------------------------------
...@@ -173,8 +177,10 @@ void RankWorker::thread_loop() { ...@@ -173,8 +177,10 @@ void RankWorker::thread_loop() {
// Initialize device & model outside of holding the main mutex to avoid blocking callers. // Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore::context::setDevice(rank_info_.device); infinicore::context::setDevice(rank_info_.device);
cache_ptr_ = std::make_shared<cache::DynamicCache>(pending_cache_config_);
// Create model using factory (may be expensive) // Create model using factory (may be expensive)
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_); model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, cache_ptr_);
// Signal that initialization is done // Signal that initialization is done
{ {
...@@ -190,6 +196,7 @@ void RankWorker::thread_loop() { ...@@ -190,6 +196,7 @@ void RankWorker::thread_loop() {
infinicore::Tensor local_param; infinicore::Tensor local_param;
std::vector<std::any> local_args; std::vector<std::any> local_args;
size_t local_reset_pos = 0; size_t local_reset_pos = 0;
cache::CacheConfig local_reset_config;
// Wait for a job or exit // Wait for a job or exit
{ {
...@@ -209,6 +216,9 @@ void RankWorker::thread_loop() { ...@@ -209,6 +216,9 @@ void RankWorker::thread_loop() {
local_args = pending_args_; local_args = pending_args_;
} else if (local_cmd == Command::RESET_CACHE) { } else if (local_cmd == Command::RESET_CACHE) {
local_reset_pos = pending_reset_pos_; local_reset_pos = pending_reset_pos_;
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
local_reset_pos = pending_reset_pos_;
local_reset_config = pending_cache_config_;
} }
// mark job as being processed // mark job as being processed
...@@ -259,9 +269,15 @@ void RankWorker::thread_loop() { ...@@ -259,9 +269,15 @@ void RankWorker::thread_loop() {
} }
} else if (local_cmd == Command::RESET_CACHE) { } else if (local_cmd == Command::RESET_CACHE) {
try { try {
// Generic reset_cache on the model interface // Option 1: Use model's reset_cache if it handles cache
model_->reset_cache(local_reset_pos); model_->reset_cache(local_reset_pos);
// Option 2: Reset cache directly if we have access
// if (cache_ptr_ != nullptr) {
// auto* dynamic_cache = static_cast<cache::DynamicCache*>(cache_ptr_);
// dynamic_cache->reset(local_reset_pos);
// }
{ {
std::lock_guard<std::mutex> lk(mutex_); std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true; job_done_ = true;
...@@ -276,6 +292,25 @@ void RankWorker::thread_loop() { ...@@ -276,6 +292,25 @@ void RankWorker::thread_loop() {
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what()); spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
break; break;
} }
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
try {
// Use model's reset_cache with new configuration
model_->reset_cache(local_reset_config, local_reset_pos);
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache with config: {}\n", info(), e.what());
break;
}
} else { } else {
// Shouldn't reach here (no-op) // Shouldn't reach here (no-op)
} }
......
#pragma once #pragma once
#include "../cache/cache.hpp"
#include "../models/model_factory.hpp" #include "../models/model_factory.hpp"
#include "distributed/distributed.hpp" #include "distributed/distributed.hpp"
...@@ -18,12 +19,14 @@ class RankWorker { ...@@ -18,12 +19,14 @@ class RankWorker {
LOAD, LOAD,
RUN, RUN,
RESET_CACHE, RESET_CACHE,
RESET_CACHE_WITH_CONFIG,
STOP STOP
}; };
public: public:
RankWorker(const std::any &model_config, RankWorker(const std::any &model_config,
const distributed::RankInfo &rank_info); const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config);
// Submit a parameter load job and wait until the load completes on the worker thread. // Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name, void load_param(const std::string &name,
...@@ -36,9 +39,10 @@ public: ...@@ -36,9 +39,10 @@ public:
void run(const std::vector<std::any> &args); void run(const std::vector<std::any> &args);
// Reset the internal cache in the model (clears state between generations) // Reset the internal cache in the model (clears state between generations)
// By default, this is synchronous (blocks until reset completes). void reset_cache(size_t pos = 0);
// If async=true, this becomes asynchronous (unstable - use with caution).
void reset_cache(size_t pos = 0, bool async = false); // Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
// Wait until run job completes. The result can be retrieved with get_output(). // Wait until run job completes. The result can be retrieved with get_output().
void wait(); void wait();
...@@ -59,6 +63,7 @@ private: ...@@ -59,6 +63,7 @@ private:
std::any model_config_; std::any model_config_;
distributed::RankInfo rank_info_; distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_; std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::DynamicCache> cache_ptr_;
// Command for the pending job (protected by mutex_) // Command for the pending job (protected by mutex_)
Command job_cmd_; Command job_cmd_;
...@@ -74,6 +79,7 @@ private: ...@@ -74,6 +79,7 @@ private:
infinicore::Tensor pending_param_; infinicore::Tensor pending_param_;
std::vector<std::any> pending_args_; std::vector<std::any> pending_args_;
size_t pending_reset_pos_ = 0; size_t pending_reset_pos_ = 0;
cache::CacheConfig pending_cache_config_;
// Output (protected by mutex) // Output (protected by mutex)
infinicore::Tensor output_; infinicore::Tensor output_;
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
#include "../cache/cache.hpp"
#include <any> #include <any>
namespace infinilm { namespace infinilm {
...@@ -11,5 +13,6 @@ public: ...@@ -11,5 +13,6 @@ public:
virtual infinicore::Tensor forward(std::vector<std::any>) const = 0; virtual infinicore::Tensor forward(std::vector<std::any>) const = 0;
// Optional: reset cache; default no-op for models without cache // Optional: reset cache; default no-op for models without cache
virtual void reset_cache(size_t pos = 0) {} virtual void reset_cache(size_t pos = 0) {}
virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0;
}; };
} // namespace infinilm } // namespace infinilm
...@@ -66,4 +66,8 @@ void LlamaForCausalLM::reset_cache(size_t pos) { ...@@ -66,4 +66,8 @@ void LlamaForCausalLM::reset_cache(size_t pos) {
model_->reset_cache(pos); model_->reset_cache(pos);
} }
void LlamaForCausalLM::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
model_->reset_cache(new_config, pos);
}
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -50,6 +50,7 @@ public: ...@@ -50,6 +50,7 @@ public:
// Reset internal cache position // Reset internal cache position
void reset_cache(size_t pos = 0) override; void reset_cache(size_t pos = 0) override;
void reset_cache(const cache::CacheConfig &new_config, size_t pos) override;
// Module information // Module information
const LlamaConfig &config() const { return model_->config(); } const LlamaConfig &config() const { return model_->config(); }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "infinicore/nn/rmsnorm.hpp" #include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp" #include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include <iostream>
namespace infinilm::models::llama { namespace infinilm::models::llama {
...@@ -50,18 +51,20 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -50,18 +51,20 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// The cache persists across forward calls to enable incremental decoding // The cache persists across forward calls to enable incremental decoding
void *cache_to_use = kv_cache; void *cache_to_use = kv_cache;
if (kv_cache == nullptr) { if (cache_to_use == nullptr) {
// Create or reuse persistent internal cache at model level // Create or reuse persistent internal cache at model level
// This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...) // This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...)
size_t seq_len = input_ids->shape()[1]; if (external_cache_ != nullptr) {
cache_to_use = external_cache_;
if (!cache_) { } else {
// First time: create cache // Fall back to internal cache
cache_ = std::make_unique<infinilm::cache::DynamicCache>( if (!internal_cache_) {
internal_cache_ = std::make_unique<infinilm::cache::DynamicCache>(
config_.num_hidden_layers, config_.num_hidden_layers,
config_.max_position_embeddings); config_.max_position_embeddings);
} }
cache_to_use = cache_.get(); cache_to_use = internal_cache_.get();
}
} }
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
...@@ -92,8 +95,22 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -92,8 +95,22 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
} }
void LlamaModel::reset_cache(size_t pos) const { void LlamaModel::reset_cache(size_t pos) const {
if (cache_) { if (internal_cache_) {
cache_->reset(pos); internal_cache_->reset(pos);
}
if (external_cache_) {
external_cache_->reset(pos);
}
}
void LlamaModel::reset_cache(const cache::CacheConfig &new_config, size_t pos) const {
if (internal_cache_) {
internal_cache_->update_config(new_config);
internal_cache_->reset(pos);
}
if (external_cache_) {
external_cache_->update_config(new_config);
external_cache_->reset(pos);
} }
} }
......
#pragma once #pragma once
#include "../../cache/kv_cache.hpp"
#include "llama_config.hpp" #include "llama_config.hpp"
#include "llama_decoder_layer.hpp" #include "llama_decoder_layer.hpp"
#include "../../cache/kv_cache.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/embedding.hpp" #include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp" #include "infinicore/nn/rmsnorm.hpp"
...@@ -68,6 +67,22 @@ public: ...@@ -68,6 +67,22 @@ public:
*/ */
void reset_cache(size_t pos = 0) const; void reset_cache(size_t pos = 0) const;
/**
* @brief Reset the internal cache with a new configuration and position
* This should be called when changing cache parameters (e.g., initial capacity)
* @param new_config New cache configuration
* @param pos Position to reset to
*/
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) const;
/**
* @brief Set external cache for the model
* @param cache Pointer to external cache (managed by CacheManager)
*/
void set_external_cache(std::shared_ptr<cache::DynamicCache> cache) {
external_cache_ = cache.get();
}
protected: protected:
// Token embeddings // Token embeddings
INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens);
...@@ -86,7 +101,8 @@ private: ...@@ -86,7 +101,8 @@ private:
// Persistent cache for when no external cache is provided // Persistent cache for when no external cache is provided
// Mutable because it's not part of the model's learned parameters, // Mutable because it's not part of the model's learned parameters,
// but needs to persist across forward calls for incremental decoding // but needs to persist across forward calls for incremental decoding
mutable std::unique_ptr<infinilm::cache::DynamicCache> cache_; mutable std::unique_ptr<infinilm::cache::DynamicCache> internal_cache_;
cache::DynamicCache *external_cache_ = nullptr;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -2,11 +2,21 @@ ...@@ -2,11 +2,21 @@
#include "llama/llama.hpp" #include "llama/llama.hpp"
namespace infinilm { namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(const std::any &config, engine::distributed::RankInfo rank_info) { std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const std::any &config,
engine::distributed::RankInfo rank_info,
std::shared_ptr<cache::DynamicCache> cache_ptr) {
if (config.type() == typeid(models::llama::LlamaConfig)) { if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config); const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device, infinicore::DataType::BF16, rank_info); auto model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
if (cache_ptr != nullptr) {
model->model().set_external_cache(cache_ptr);
}
return model;
} else { } else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
} }
......
...@@ -7,6 +7,6 @@ ...@@ -7,6 +7,6 @@
namespace infinilm { namespace infinilm {
class InfinilmModelFactory { class InfinilmModelFactory {
public: public:
static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
}; };
} // namespace infinilm } // namespace infinilm
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "models/llama.hpp" #include "models/llama.hpp"
#include "engine.hpp" #include "engine.hpp"
namespace py = pybind11; namespace py = pybind11;
...@@ -9,6 +8,8 @@ namespace py = pybind11; ...@@ -9,6 +8,8 @@ namespace py = pybind11;
PYBIND11_MODULE(_infinilm, m) { PYBIND11_MODULE(_infinilm, m) {
m.doc() = "InfiniLM Llama model Python bindings"; m.doc() = "InfiniLM Llama model Python bindings";
infinilm::cache::bind_cache_config(m);
infinilm::models::llama::bind_llama(m); infinilm::models::llama::bind_llama(m);
infinilm::engine::distributed::bind_dist_config(m); infinilm::engine::distributed::bind_dist_config(m);
infinilm::engine::bind_infer_engine(m); infinilm::engine::bind_infer_engine(m);
......
#include "../cache/cache_config.hpp"
#include "../engine/infer_engine.hpp" #include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
...@@ -5,6 +6,60 @@ ...@@ -5,6 +6,60 @@
namespace py = pybind11; namespace py = pybind11;
namespace infinilm::cache {
inline void bind_cache_config(py::module &m) {
// First bind the CacheType enum
py::enum_<CacheType>(m, "CacheType")
.value("DYNAMIC", CacheType::DYNAMIC)
.value("PAGED", CacheType::PAGED)
.export_values();
// Then bind the CacheResetMode enum
py::enum_<CacheResetMode>(m, "CacheResetMode")
.value("PRESERVE", CacheResetMode::PRESERVE)
.value("RECREATE", CacheResetMode::RECREATE)
.export_values();
// Finally bind the CacheConfig struct
py::class_<CacheConfig>(m, "CacheConfig")
.def(py::init<>(), "Default constructor")
.def(py::init<CacheType, size_t, size_t>(),
py::arg("type") = CacheType::DYNAMIC,
py::arg("num_layers") = 32,
py::arg("max_kv_cache_length") = 4096,
"Constructor with parameters")
.def_readwrite("type", &CacheConfig::type, "Cache type")
.def_readwrite("num_layers", &CacheConfig::num_layers, "Number of layers")
.def_readwrite("max_kv_cache_length", &CacheConfig::max_kv_cache_length,
"Maximum KV cache length")
.def_readwrite("initial_capacity", &CacheConfig::initial_capacity,
"Initial cache capacity in tokens")
.def_readwrite("initial_batch_size", &CacheConfig::initial_batch_size,
"Initial batch size for cache allocation")
.def_readwrite("growth_factor", &CacheConfig::growth_factor,
"Cache growth factor when resizing (e.g., 2.0 for doubling)")
.def_readwrite("allow_expand", &CacheConfig::allow_expand,
"Whether to allow cache expansion")
.def_readwrite("reset_mode", &CacheConfig::reset_mode,
"Cache reset mode")
.def("__eq__", &CacheConfig::operator==, py::is_operator(),
"Check if two CacheConfig objects are equal")
.def("__ne__", &CacheConfig::operator!=, py::is_operator(),
"Check if two CacheConfig objects are not equal")
.def("__repr__", [](const CacheConfig &cfg) {
return fmt::format("CacheConfig(type={}, num_layers={}, max_kv_cache_length={}, "
"initial_capacity={}, initial_batch_size={}, growth_factor={}, "
"allow_expand={}, reset_mode={})",
static_cast<int>(cfg.type), cfg.num_layers,
cfg.max_kv_cache_length, cfg.initial_capacity,
cfg.initial_batch_size, cfg.growth_factor,
cfg.allow_expand, static_cast<int>(cfg.reset_mode));
});
}
} // namespace infinilm::cache
namespace infinilm::engine::distributed { namespace infinilm::engine::distributed {
inline void bind_dist_config(py::module &m) { inline void bind_dist_config(py::module &m) {
...@@ -29,21 +84,21 @@ inline void bind_dist_config(py::module &m) { ...@@ -29,21 +84,21 @@ inline void bind_dist_config(py::module &m) {
namespace infinilm::engine { namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) { inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine") py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
.def(py::init([](const infinilm::models::llama::LlamaConfig &cfg, .def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
const infinilm::engine::distributed::DistConfig &dist, const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev) { infinicore::Device::Type dev,
return new InferEngine(std::any(cfg), dist, dev); const infinilm::cache::CacheConfig &cache_config) {
return new InferEngine(std::any(cfg), dist, dev, cache_config);
}), }),
py::arg("config"), py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(), py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType()) py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = infinilm::cache::CacheConfig())
.def("load_param", &InferEngine::load_param, .def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"), py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)") "Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) { .def("state_dict", [](InferEngine &self) {
// Return a dictionary containing references to the whole state of the module.
py::list state_dict_tp_all; py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) { for (const auto &state_dict_tp : self.state_dict()) {
py::dict result; py::dict result;
...@@ -52,15 +107,17 @@ inline void bind_infer_engine(py::module &m) { ...@@ -52,15 +107,17 @@ inline void bind_infer_engine(py::module &m) {
} }
state_dict_tp_all.append(result); state_dict_tp_all.append(result);
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def("generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor { return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>()); }, "Run inference on all ranks with arbitrary arguments") .def(
.def("reset_cache", &InferEngine::reset_cache, py::arg("pos") = 0, py::arg("async") = false, "Reset the internal cache in all workers to a specific position (clears state between generations). " "generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
"By default, this is synchronous. If async=True, this becomes asynchronous (unstable - use with caution)."); return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>());
},
// Optionally, you can add __repr__ for debugging "Run inference on all ranks with arbitrary arguments")
m.attr("InferEngine").attr("__repr__") = py::cpp_function([](const InferEngine &self) { .def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position")
.def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration")
.def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration")
.def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment