Commit ff00b5c8 authored by PanZezhong's avatar PanZezhong
Browse files

issue/125 统一Cache接口

parent 13a4154a
#pragma once
#include "../engine/distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include <memory>
namespace infinilm::cache {
/**
 * @brief Abstract polymorphic root for all cache implementations.
 *
 * Engine code owns concrete caches (e.g. StaticKVCache) through a
 * Cache pointer; the virtual destructor guarantees derived caches are
 * destroyed correctly when deleted through that pointer.
 */
class Cache {
public:
    Cache() = default;
    // `= default` instead of `{}`: keeps the destructor trivial-ish,
    // noexcept, and clearly intentional.
    virtual ~Cache() = default;
};
/**
 * @brief Abstract base class for cache configurations.
 *
 * Concrete configurations (e.g. StaticKVCacheConfig) implement
 * unique_copy() so engine code can clone a configuration
 * polymorphically without knowing its dynamic type.
 */
class CacheConfig {
public:
    CacheConfig() = default;
    virtual ~CacheConfig() = default;

    /// @brief Polymorphic deep copy of this configuration.
    /// @return An owning pointer to a copy with the same dynamic type.
    virtual std::unique_ptr<CacheConfig> unique_copy() const = 0;
};
} // namespace infinilm::cache
#pragma once #pragma once
#include "cache_config.hpp" #include "base_cache.hpp"
#include "kv_cache.hpp" #include "kv_cache.hpp"
#pragma once
#include <cstddef>
#include <string>
#include <cstdint>
namespace infinilm::cache {
/**
* @enum CacheType
* @brief Enumeration of supported cache types
*/
enum class CacheType {
    DYNAMIC, ///< Dynamic KV cache (grows as needed)
    PAGED,   ///< Paged KV cache (for paged attention)
};

/**
 * @enum CacheResetMode
 * @brief Behavior when the cache is reset with a new configuration.
 */
enum class CacheResetMode {
    PRESERVE, ///< Keep cache memory, only reset positions
    RECREATE  ///< Recreate cache with new configuration
};

/**
 * @brief Plain-value configuration describing how a KV cache is built.
 *
 * Two configurations compare equal only when every field matches, so
 * equality can safely be used to decide whether a rebuild is needed.
 */
struct CacheConfig {
    CacheType type = CacheType::DYNAMIC; ///< Cache implementation to use
    size_t num_layers = 0;               ///< Number of model layers
    size_t max_kv_cache_length = SIZE_MAX; ///< Hard cap on cache length (tokens)
    size_t initial_capacity = 1024;  // Initial cache capacity in tokens
    size_t initial_batch_size = 1;   // Initial batch size for cache allocation
    float growth_factor = 2.0f;      // Cache growth factor when resizing
    bool allow_expand = true;        // Whether to allow cache expansion
    CacheResetMode reset_mode = CacheResetMode::PRESERVE; // Reset behavior

    // Constructors
    CacheConfig() = default;
    CacheConfig(CacheType type, size_t num_layers = 32, size_t max_kv_cache_length = 4096)
        : type(type), num_layers(num_layers), max_kv_cache_length(max_kv_cache_length) {}

    /// Field-wise equality.
    /// Fix: previously ignored allow_expand and reset_mode, so configs
    /// differing only in those fields compared equal and a requested
    /// reconfiguration could be silently skipped.
    bool operator==(const CacheConfig &other) const {
        return type == other.type
            && num_layers == other.num_layers
            && max_kv_cache_length == other.max_kv_cache_length
            && initial_capacity == other.initial_capacity
            && initial_batch_size == other.initial_batch_size
            && growth_factor == other.growth_factor // exact compare is fine for an assigned config value
            && allow_expand == other.allow_expand
            && reset_mode == other.reset_mode;
    }
    bool operator!=(const CacheConfig &other) const {
        return !(*this == other);
    }
};
} // namespace infinilm::cache
#include "kv_cache.hpp"
#include "../utils.hpp"
#include <stdexcept>
namespace infinilm::cache {
// ==========================
// StaticKVCache
// ==========================
// Construct a static (fully pre-allocated) KV cache.
//
// Both caches are allocated once, up front, with layout
// [num_layers, batch, heads_per_rank, cache_len, head_dim]; the cache
// never grows after construction, so all sequence data must fit within
// cache_len_.
//
// NOTE(review): per-rank head counts are derived by integer division by
// rank_info.tp_size -- assumes tp_size evenly divides num_k_heads and
// num_v_heads; confirm callers validate this, otherwise heads are
// silently dropped.
StaticKVCache::StaticKVCache(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
k_dim_(k_dim),
v_dim_(v_dim),
// Shard heads across tensor-parallel ranks.
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_batch_size_(config.max_batch_size()),
// Usable cache length is capped by the model's positional range.
cache_len_(std::min(config.max_cache_len(), max_positional_embedding)),
rank_num_layers_(num_layers),
dtype_(dtype) {
// Allocate K cache: [layers, batch, k_heads/rank, cache_len, k_dim].
k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
rank_batch_size_,
num_rank_k_heads_,
cache_len_,
k_dim_},
dtype_,
rank_info.device);
// Allocate V cache: [layers, batch, v_heads/rank, cache_len, v_dim].
v_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
rank_batch_size_,
num_rank_v_heads_,
cache_len_,
v_dim_},
dtype_,
rank_info.device);
spdlog::info("Created Static KV Cache: K[{}] V[{}]", k_caches_->info(), v_caches_->info());
}
// Write new K/V states for one layer at the current cache position and
// return views over the filled prefix of the cache (old + new tokens).
//
// NOTE(review): only element 0 of cache_positions is read, i.e. positions
// are assumed uniform across the batch -- confirm with callers. The read
// also assumes an int64 element type and forces a device->host transfer.
std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx,
                      const infinicore::Tensor &k,
                      const infinicore::Tensor &v,
                      const infinicore::Tensor &cache_positions) {
    ASSERT(layer_idx < rank_num_layers_);

    const auto batch = k->size(0);
    const auto seq_len = k->size(2);

    // Write offset taken from the first batch entry (see note above).
    const size_t write_pos =
        reinterpret_cast<int64_t *>(cache_positions->to(infinicore::Device::cpu())->data())[0];
    const auto total_len = write_pos + seq_len;
    ASSERT(total_len <= cache_len_);
    ASSERT_EQ(batch, rank_batch_size_);

    // Views of this layer's K/V slabs: [batch, heads, cache_len, dim].
    auto layer_k = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
    auto layer_v = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);

    // Copy the new entries in place along the sequence axis.
    layer_k->narrow({{2, write_pos, seq_len}})->copy_from(k);
    layer_v->narrow({{2, write_pos, seq_len}})->copy_from(v);

    // Hand back the populated prefix, including the freshly written tokens.
    return {layer_k->narrow({{2, 0, total_len}}),
            layer_v->narrow({{2, 0, total_len}})};
}
// ==========================
// StaticKVCacheConfig
// ==========================
// Construct a static-cache configuration; both limits are stored as-is
// and later consumed verbatim by the StaticKVCache constructor.
StaticKVCacheConfig::StaticKVCacheConfig(
infinicore::Size _max_batch_size,
infinicore::Size _max_cache_len)
: max_batch_size_(_max_batch_size),
max_cache_len_(_max_cache_len) {
}
// Polymorphic deep copy (see CacheConfig::unique_copy). Relies on the
// implicitly generated copy constructor; all members are value types.
std::unique_ptr<CacheConfig>
StaticKVCacheConfig::unique_copy() const {
return std::make_unique<StaticKVCacheConfig>(*this);
}
// Maximum batch size the static cache will be allocated for.
infinicore::Size
StaticKVCacheConfig::max_batch_size() const {
return max_batch_size_;
}
// Requested maximum cache length in tokens (the cache may further cap
// this by the model's max positional embedding at construction time).
infinicore::Size
StaticKVCacheConfig::max_cache_len() const {
return max_cache_len_;
}
} // namespace infinilm::cache
#pragma once #pragma once
#include "base_cache.hpp"
#include "infinicore/context/context.hpp" #include "infinicore/context/context.hpp"
#include "infinicore/device.hpp" #include "infinicore/device.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include "cache_config.hpp"
#include <algorithm> #include <algorithm>
#include <limits>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <stdexcept> #include <stdexcept>
...@@ -15,355 +16,70 @@ ...@@ -15,355 +16,70 @@
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
namespace infinilm::cache { namespace infinilm::cache {
class StaticKVCacheConfig final : public CacheConfig {
public:
StaticKVCacheConfig(
infinicore::Size _max_batch_size = 1,
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
/** std::unique_ptr<CacheConfig> unique_copy() const override;
* @brief Single layer's KV cache for incremental decoding infinicore::Size max_batch_size() const;
* infinicore::Size max_cache_len() const;
* Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
* Similar to DynamicLayer in Python cache_utils.py
*
* This represents a single layer's cache within a model-level cache container.
*/
struct KVCacheLayer {
infinicore::Tensor k_cache; // [batch_size, n_kv_head, capacity, head_dim]
infinicore::Tensor v_cache; // [batch_size, n_kv_head, capacity, head_dim]
std::vector<size_t> cache_positions; // Current position in cache
size_t max_capacity; // Maximum capacity of cache
size_t initial_capacity; // Initial capacity from config
size_t initial_batch_size; // Initial batch size from config
float growth_factor; // Growth factor for dynamic resizing
bool initialized; // Whether cache has been initialized
KVCacheLayer() : max_capacity(0), initial_capacity(4096), initial_batch_size(1),
growth_factor(2.0f), initialized(false) {}
/**
* @brief Initialize or update cache capacity with config parameters
* @param batch_size Current batch size
* @param num_kv_heads Number of key-value heads
* @param head_dim Head dimension
* @param seq_len Sequence length of new tokens
* @param dtype Data type
* @param device Device
* @param cache_config Cache configuration parameters
*/
void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
infinicore::DataType dtype, const infinicore::Device &device,
const CacheConfig &cache_config) {
size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); });
// VALIDATION: Verify input parameters
if (num_kv_heads == 0 || head_dim == 0 || seq_len == 0) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Invalid parameters - num_kv_heads: {}, head_dim: {}, seq_len: {}",
num_kv_heads, head_dim, seq_len);
throw std::runtime_error("KV cache ensure_capacity: invalid parameters");
}
// Store config parameters on first initialization
if (!initialized) {
initial_capacity = cache_config.initial_capacity;
initial_batch_size = cache_config.initial_batch_size;
growth_factor = cache_config.growth_factor;
}
// Lazy initialization
if (!initialized) {
// Use max of required capacity and initial capacity from config
max_capacity = std::max(required_capacity, initial_capacity);
// Use max of current batch size and initial batch size from config
size_t alloc_batch_size = std::max(batch_size, initial_batch_size);
k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
dtype, device);
cache_positions = std::vector<size_t>(alloc_batch_size, 0);
initialized = true;
spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})",
alloc_batch_size, max_capacity, initial_batch_size, initial_capacity);
// VALIDATION: Verify cache was created correctly
if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization");
throw std::runtime_error("KV cache initialization: shape mismatch");
}
}
// Grow cache if needed using growth factor from config
else if (required_capacity > max_capacity) {
if (!cache_config.allow_expand) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config");
throw std::runtime_error("KV cache expansion not allowed");
}
// Calculate new capacity using growth factor
size_t new_capacity = static_cast<size_t>(
std::max(static_cast<float>(max_capacity) * growth_factor,
static_cast<float>(required_capacity + max_capacity)));
// Ensure we don't exceed max_position_embeddings if specified
if (cache_config.max_kv_cache_length != 0) {
new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length);
}
// Ensure we grow by at least some minimum amount
size_t min_growth = 256;
if (new_capacity - max_capacity < min_growth) {
new_capacity = max_capacity + min_growth;
}
size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
}
if (new_batch_size > cache_positions.size()) {
cache_positions.resize(new_batch_size, 0);
}
auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
dtype, device);
spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})",
max_capacity, new_capacity, growth_factor);
// Copy existing cache data
for (size_t b = 0; b < new_batch_size; ++b) {
size_t cache_position = cache_positions[b];
if (cache_position > 0) {
auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
}
}
k_cache = k_new;
v_cache = v_new;
max_capacity = new_capacity;
// VALIDATION: Verify cache was grown correctly
if (k_cache->shape()[2] != new_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch");
throw std::runtime_error("KV cache growth: capacity mismatch");
}
}
// VALIDATION: Final check that capacity is sufficient
if (required_capacity > max_capacity) {
SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth");
throw std::runtime_error("KV cache ensure_capacity: capacity insufficient");
}
}
/**
* @brief Update cache with new key and value states
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @param cache_config Cache configuration for capacity management
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new,
const CacheConfig &cache_config) {
if (k_new->ndim() != 4 || v_new->ndim() != 4) {
throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors");
}
size_t batch_size = k_new->shape()[0];
size_t num_kv_heads = k_new->shape()[1];
size_t seq_len = k_new->shape()[2];
size_t head_dim = k_new->shape()[3];
// Ensure capacity with cache config
ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
k_new->dtype(), k_new->device(), cache_config);
// Copy new k/v into cache at current position
bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
if (all_equal) {
auto cache_position = cache_positions[0];
auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
k_dst->copy_from(k_new);
v_dst->copy_from(v_new);
// Update position
cache_position += seq_len;
for (size_t b = 0; b < batch_size; ++b) {
cache_positions[b] = cache_position;
}
// Return the total cache up to current position
auto k_total = k_cache->narrow({{2, 0, cache_position}});
auto v_total = v_cache->narrow({{2, 0, cache_position}});
return std::make_pair(k_total, v_total); private:
} else { infinicore::Size max_batch_size_;
throw std::runtime_error("KVCache update: cache positions must be equal among a batch."); infinicore::Size max_cache_len_;
}
}
}; };
/** class StaticKVCache final : public Cache {
* @brief Model-level KV cache container (similar to DynamicCache in Python)
*
* Stores a list of KVCacheLayer objects, one per model layer.
* This aligns with Python backend's DynamicCache architecture.
*/
class DynamicCache {
public: public:
/** StaticKVCache(
* @brief Construct DynamicCache with cache configuration
* @param cache_config Cache configuration parameters infinicore::Size k_dim,
*/ infinicore::Size v_dim,
DynamicCache(const CacheConfig &cache_config) infinicore::Size num_k_heads,
: cache_config_(cache_config), layers_(cache_config.num_layers) { infinicore::Size num_v_heads,
if (cache_config.num_layers == -1) { infinicore::Size num_layers,
throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig"); infinicore::Size max_positional_embedding,
} infinicore::DataType dtype,
} const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);
/** /**
* @brief Construct DynamicCache with specified number of layers * @brief Update KV cache at a given layer and cache position.
* *
* @param num_layers Number of model layers (creates one cache layer per model layer) * @param layer_idx Which transformer layer
* @param max_position_embeddings Maximum position embeddings (used for initial capacity) * @param k [batch, num_rank_k_heads, seq_len, k_dim]
*/ * @param v [batch, num_rank_v_heads, seq_len, v_dim]
DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096) * @param cache_pos Sequence position to write
: cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), layers_(num_layers) {}
/**
* @brief Update cache with new key and value states for a specific layer
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
size_t layer_idx,
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) {
if (layer_idx >= layers_.size()) {
SPDLOG_ERROR("DynamicCache::update: layer_idx {} out of range (num_layers: {})",
layer_idx, layers_.size());
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
// Update the cache for this layer with cache config
return layers_[layer_idx].update(k_new, v_new, cache_config_);
}
/**
* @brief Update cache with new key and value states (convenience method without layer_idx)
* This is used when the cache is accessed directly without layer information
* *
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim] * @return (full_k, full_v)
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim] * full_k: [batch, num_rank_k_heads, cache_pos + seq_len, k_dim]
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim] * full_v: [batch, num_rank_v_heads, cache_pos + seq_len, v_dim]
*
* Note: This assumes layer_idx=0. For multi-layer models, use update(layer_idx, k_new, v_new) instead.
*/
std::pair<infinicore::Tensor, infinicore::Tensor> update(
const infinicore::Tensor &k_new,
const infinicore::Tensor &v_new) {
return update(0, k_new, v_new);
}
/**
* @brief Get cache configuration
*/
const CacheConfig &get_config() const { return cache_config_; }
/**
* @brief Update cache configuration (for dynamic reconfiguration)
*/
void update_config(const CacheConfig &new_config) {
// Check if we need to rebuild
bool need_rebuild = false;
// Rebuild if number of layers changed
if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) {
need_rebuild = true;
layers_.resize(new_config.num_layers);
}
// Rebuild if reset mode is RECREATE
if (new_config.reset_mode == CacheResetMode::RECREATE) {
need_rebuild = true;
}
// Update configuration
cache_config_ = new_config;
if (need_rebuild) {
// Clear all layers to force reinitialization on next use
for (auto &layer : layers_) {
layer.initialized = false;
layer.max_capacity = 0;
// Tensors will be recreated when ensure_capacity is called
}
spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use");
} else {
spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}",
new_config.num_layers, new_config.initial_capacity, new_config.growth_factor);
}
}
/**
* @brief Get the number of layers in this cache
*/
size_t num_layers() const { return layers_.size(); }
/**
* @brief Get cache position for a specific layer
*/
size_t cache_position(size_t layer_idx) const {
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
if (layers_[layer_idx].cache_positions.empty()) {
return 0;
}
return layers_[layer_idx].cache_positions[0]; // All batch items should have same position
}
/**
* @brief Get max position embeddings (used for initial capacity)
*/
size_t max_kv_cache_length() const { return cache_config_.max_kv_cache_length; }
/**
* @brief Reset cache for all layers to a specific position
* This should be called when starting a new generation sequence or resetting to a specific position
* @param pos Position to reset to (defaults to 0)
*/
void reset(size_t pos = 0) {
for (auto &layer : layers_) {
std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos);
// Note: We don't reset initialized flag or clear the cache tensors
// to avoid reallocation. The cache will be overwritten on next update.
}
}
/**
* @brief Access a specific layer's cache (for advanced usage)
*/ */
KVCacheLayer &layer(size_t layer_idx) { std::tuple<infinicore::Tensor, infinicore::Tensor>
if (layer_idx >= layers_.size()) { update(size_t layer_idx,
throw std::runtime_error("DynamicCache: layer_idx out of range"); const infinicore::Tensor &k,
} const infinicore::Tensor &v,
return layers_[layer_idx]; const infinicore::Tensor &cache_positions);
}
const KVCacheLayer &layer(size_t layer_idx) const { ~StaticKVCache() override = default;
if (layer_idx >= layers_.size()) {
throw std::runtime_error("DynamicCache: layer_idx out of range");
}
return layers_[layer_idx];
}
private: private:
CacheConfig cache_config_; infinicore::Size k_dim_;
std::vector<KVCacheLayer> layers_; infinicore::Size v_dim_;
infinicore::Size num_rank_k_heads_;
infinicore::Size num_rank_v_heads_;
infinicore::Size rank_batch_size_;
infinicore::Size cache_len_;
infinicore::Size rank_num_layers_;
infinicore::DataType dtype_;
// [num_layers, max_batch, num_rank_k_heads, max_cache_len, k_dim]
infinicore::Tensor k_caches_;
// [num_layers, max_batch, num_rank_v_heads, max_cache_len, v_dim]
infinicore::Tensor v_caches_;
}; };
} // namespace infinilm::cache } // namespace infinilm::cache
...@@ -38,7 +38,7 @@ int CommunicationGroup::get_world_size() const { ...@@ -38,7 +38,7 @@ int CommunicationGroup::get_world_size() const {
CommunicationGroup::~CommunicationGroup() { CommunicationGroup::~CommunicationGroup() {
if (communicators_.size() > 1) { if (communicators_.size() > 1) {
for (auto &comm : communicators_) { for (auto &comm : communicators_) {
RUN_INFINI(infinicclCommDestroy(comm)); infinicclCommDestroy(comm);
} }
} }
} }
......
...@@ -10,32 +10,13 @@ InferEngine::InferEngine( ...@@ -10,32 +10,13 @@ InferEngine::InferEngine(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config, const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type, infinicore::Device::Type device_type,
const cache::CacheConfig &cache_config) // Changed parameter const cache::CacheConfig *cache_config) // Changed parameter
: communication_group_(distributed_config, device_type), : communication_group_(distributed_config, device_type),
model_config_(config), model_config_(config) {
cache_config_(cache_config) {
spdlog::info("Launch InferEngine with {}", std::string(distributed_config)); if (cache_config != nullptr) {
spdlog::info("Cache configuration: type={}, layers={}, max_kv_cache_length={}", cache_config_ = cache_config->unique_copy();
static_cast<int>(cache_config_.type),
cache_config_.num_layers,
cache_config_.max_kv_cache_length);
// Try to extract model configuration to override default cache parameters if needed
try {
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
cache_config_.num_layers = llama_config.num_hidden_layers;
cache_config_.max_kv_cache_length = llama_config.max_position_embeddings;
spdlog::info("Updated cache config from model: layers={}, max_kv_cache_length={}",
cache_config_.num_layers, cache_config_.max_kv_cache_length);
}
} catch (...) {
spdlog::warn("Could not extract model config, using provided CacheConfig");
} }
// Create one RankWorker per rank // Create one RankWorker per rank
int world_size = communication_group_.get_world_size(); int world_size = communication_group_.get_world_size();
workers_.reserve(world_size); workers_.reserve(world_size);
...@@ -43,7 +24,7 @@ InferEngine::InferEngine( ...@@ -43,7 +24,7 @@ InferEngine::InferEngine(
workers_.emplace_back(std::make_unique<RankWorker>( workers_.emplace_back(std::make_unique<RankWorker>(
model_config_, model_config_,
communication_group_.get_rank_info(r), communication_group_.get_rank_info(r),
cache_config_)); cache_config_ != nullptr ? cache_config_.get() : nullptr));
} }
} }
...@@ -75,12 +56,14 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng ...@@ -75,12 +56,14 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------ //------------------------------------------------------
// forward // forward
//------------------------------------------------------ //------------------------------------------------------
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) { infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
const auto &[input_ids, position_ids] = input; return {input_ids, position_ids, cache_positions};
}
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
// Trigger each worker to run inference // Trigger each worker to run inference
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->run({input_ids, position_ids}); worker->run(input.to_model_input());
} }
// Wait for all workers // Wait for all workers
for (auto &worker : workers_) { for (auto &worker : workers_) {
...@@ -104,25 +87,12 @@ const distributed::DistConfig &InferEngine::get_dist_config() const { ...@@ -104,25 +87,12 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
return communication_group_.get_dist_config(); return communication_group_.get_dist_config();
} }
//------------------------------------------------------
// reset_cache
//------------------------------------------------------
void InferEngine::reset_cache(size_t pos) {
for (auto &worker : workers_) {
worker->reset_cache(pos);
}
for (auto &worker : workers_) {
worker->wait();
}
}
//------------------------------------------------------ //------------------------------------------------------
// reset_cache (overloaded with CacheConfig) // reset_cache (overloaded with CacheConfig)
//------------------------------------------------------ //------------------------------------------------------
void InferEngine::reset_cache(const cache::CacheConfig &new_config, size_t pos) { void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
cache_config_ = new_config;
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->reset_cache(new_config, pos); worker->reset_cache(new_config);
} }
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->wait(); worker->wait();
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#include "distributed/distributed.hpp" #include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include "rank_worker.hpp" #include "rank_worker.hpp"
#include "../models/infinilm_model.hpp"
#include <any>
#include <vector> #include <vector>
namespace infinilm::engine { namespace infinilm::engine {
...@@ -16,6 +16,10 @@ public: ...@@ -16,6 +16,10 @@ public:
infinicore::Tensor input_ids; infinicore::Tensor input_ids;
infinicore::Tensor position_ids; infinicore::Tensor position_ids;
infinicore::Tensor cache_positions;
infinilm::InfinilmModel::Input to_model_input() const;
}; };
struct Output { struct Output {
...@@ -27,7 +31,7 @@ public: ...@@ -27,7 +31,7 @@ public:
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(), const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig &cache_config = cache::CacheConfig()); const cache::CacheConfig *cache_config = nullptr);
// Load a parameter to all workers (each can extract its shard inside RankWorker) // Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param); void load_param(const std::string &name, const infinicore::Tensor &param);
...@@ -38,24 +42,20 @@ public: ...@@ -38,24 +42,20 @@ public:
// Run a single forward pass on all workers and return the outputs from all ranks // Run a single forward pass on all workers and return the outputs from all ranks
Output forward(const Input &input); Output forward(const Input &input);
// Reset the internal cache pos in all workers (clears state between generations) void reset_cache(const cache::CacheConfig *new_config);
void reset_cache(size_t pos = 0);
// Overload: reset cache with new KV configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0);
~InferEngine(); ~InferEngine();
const distributed::DistConfig &get_dist_config() const; const distributed::DistConfig &get_dist_config() const;
// Get current KV configuration // Get current KV configuration
const cache::CacheConfig &get_cache_config() const { return cache_config_; } const cache::CacheConfig *get_cache_config() const { return cache_config_.get(); }
protected: protected:
std::vector<std::unique_ptr<RankWorker>> workers_; std::vector<std::unique_ptr<RankWorker>> workers_;
distributed::CommunicationGroup communication_group_; distributed::CommunicationGroup communication_group_;
const InfinilmModel::Config &model_config_; const InfinilmModel::Config &model_config_;
cache::CacheConfig cache_config_; std::unique_ptr<cache::CacheConfig> cache_config_;
}; };
} // namespace infinilm::engine } // namespace infinilm::engine
...@@ -10,15 +10,17 @@ namespace infinilm::engine { ...@@ -10,15 +10,17 @@ namespace infinilm::engine {
RankWorker::RankWorker(const InfinilmModel::Config &model_config, RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config) const cache::CacheConfig *cache_config)
: model_config_(model_config), : model_config_(model_config),
rank_info_(rank_info), rank_info_(rank_info),
job_cmd_(Command::INIT), job_cmd_(Command::INIT),
has_job_(false), has_job_(false),
job_done_(false), job_done_(false),
should_exit_(false), should_exit_(false),
init_done_(false), init_done_(false) {
pending_cache_config_(cache_config) { if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
// start the thread // start the thread
thread_ = std::thread(&RankWorker::thread_loop, this); thread_ = std::thread(&RankWorker::thread_loop, this);
...@@ -80,7 +82,14 @@ void RankWorker::load_param(const std::string &name, ...@@ -80,7 +82,14 @@ void RankWorker::load_param(const std::string &name,
// state_dict -- // state_dict --
//------------------------------------------------------ //------------------------------------------------------
std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() { std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() {
return this->model_->state_dict(); std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_ || should_exit_; });
if (!model_) {
throw std::runtime_error("state_dict called before model initialization");
}
return model_->state_dict();
} }
//------------------------------------------------------ //------------------------------------------------------
...@@ -113,32 +122,15 @@ void RankWorker::wait() { ...@@ -113,32 +122,15 @@ void RankWorker::wait() {
} }
} }
//------------------------------------------------------ void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
// reset_cache -- synchronous by default, async optional (unstable)
//------------------------------------------------------
void RankWorker::reset_cache(size_t pos) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache");
}
pending_reset_pos_ = pos;
job_cmd_ = Command::RESET_CACHE;
has_job_ = true;
job_done_ = false;
cv_.notify_all();
}
void RankWorker::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) { if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot reset_cache"); throw std::runtime_error("RankWorker is closing; cannot reset_cache");
} }
// Store both the position and the new config // Store both the position and the new config
pending_reset_pos_ = pos; pending_cache_config_ = new_config->unique_copy();
pending_cache_config_ = new_config; job_cmd_ = Command::RESET_CACHE;
job_cmd_ = Command::RESET_CACHE_WITH_CONFIG;
has_job_ = true; has_job_ = true;
job_done_ = false; job_done_ = false;
cv_.notify_all(); cv_.notify_all();
...@@ -174,17 +166,17 @@ InfinilmModel::Output RankWorker::get_output() { ...@@ -174,17 +166,17 @@ InfinilmModel::Output RankWorker::get_output() {
//------------------------------------------------------ //------------------------------------------------------
void RankWorker::thread_loop() { void RankWorker::thread_loop() {
try { try {
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore::context::setDevice(rank_info_.device);
cache_ptr_ = std::make_shared<cache::DynamicCache>(pending_cache_config_);
// Create model using factory (may be expensive)
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, cache_ptr_);
// Signal that initialization is done
{ {
std::lock_guard<std::mutex> lk(mutex_); std::lock_guard<std::mutex> lk(mutex_);
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore::context::setDevice(rank_info_.device);
// Create model using factory (may be expensive)
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
if (!model_) {
throw std::runtime_error("Failed to create model");
}
init_done_ = true; init_done_ = true;
} }
cv_.notify_all(); cv_.notify_all();
...@@ -195,8 +187,7 @@ void RankWorker::thread_loop() { ...@@ -195,8 +187,7 @@ void RankWorker::thread_loop() {
std::string local_param_name; std::string local_param_name;
infinicore::Tensor local_param; infinicore::Tensor local_param;
InfinilmModel::Input local_args; InfinilmModel::Input local_args;
size_t local_reset_pos = 0; std::unique_ptr<cache::CacheConfig> local_cache_config;
cache::CacheConfig local_reset_config;
// Wait for a job or exit // Wait for a job or exit
{ {
...@@ -215,12 +206,10 @@ void RankWorker::thread_loop() { ...@@ -215,12 +206,10 @@ void RankWorker::thread_loop() {
} else if (local_cmd == Command::RUN) { } else if (local_cmd == Command::RUN) {
local_args = pending_args_; local_args = pending_args_;
} else if (local_cmd == Command::RESET_CACHE) { } else if (local_cmd == Command::RESET_CACHE) {
local_reset_pos = pending_reset_pos_; if (pending_cache_config_ != nullptr) {
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) { local_cache_config = pending_cache_config_->unique_copy();
local_reset_pos = pending_reset_pos_; }
local_reset_config = pending_cache_config_;
} }
// mark job as being processed // mark job as being processed
has_job_ = false; has_job_ = false;
job_done_ = false; job_done_ = false;
...@@ -270,14 +259,7 @@ void RankWorker::thread_loop() { ...@@ -270,14 +259,7 @@ void RankWorker::thread_loop() {
} }
} else if (local_cmd == Command::RESET_CACHE) { } else if (local_cmd == Command::RESET_CACHE) {
try { try {
// Option 1: Use model's reset_cache if it handles cache model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
model_->reset_cache(local_reset_pos);
// Option 2: Reset cache directly if we have access
// if (cache_ptr_ != nullptr) {
// auto* dynamic_cache = static_cast<cache::DynamicCache*>(cache_ptr_);
// dynamic_cache->reset(local_reset_pos);
// }
{ {
std::lock_guard<std::mutex> lk(mutex_); std::lock_guard<std::mutex> lk(mutex_);
...@@ -293,25 +275,6 @@ void RankWorker::thread_loop() { ...@@ -293,25 +275,6 @@ void RankWorker::thread_loop() {
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what()); spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
break; break;
} }
} else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) {
try {
// Use model's reset_cache with new configuration
model_->reset_cache(local_reset_config, local_reset_pos);
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache with config: {}\n", info(), e.what());
break;
}
} else { } else {
// Shouldn't reach here (no-op) // Shouldn't reach here (no-op)
} }
......
...@@ -19,14 +19,13 @@ class RankWorker { ...@@ -19,14 +19,13 @@ class RankWorker {
LOAD, LOAD,
RUN, RUN,
RESET_CACHE, RESET_CACHE,
RESET_CACHE_WITH_CONFIG,
STOP STOP
}; };
public: public:
RankWorker(const InfinilmModel::Config &model_config, RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config); const cache::CacheConfig *cache_config);
// Submit a parameter load job and wait until the load completes on the worker thread. // Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name, void load_param(const std::string &name,
...@@ -38,11 +37,8 @@ public: ...@@ -38,11 +37,8 @@ public:
// Submit a run (forward) job. // Submit a run (forward) job.
void run(const InfinilmModel::Input &args); void run(const InfinilmModel::Input &args);
// Reset the internal cache in the model (clears state between generations)
void reset_cache(size_t pos = 0);
// Reset the internal cache with a new configuration // Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0); void reset_cache(const cache::CacheConfig *new_config);
// Wait until run job completes. The result can be retrieved with get_output(). // Wait until run job completes. The result can be retrieved with get_output().
void wait(); void wait();
...@@ -63,7 +59,7 @@ private: ...@@ -63,7 +59,7 @@ private:
const InfinilmModel::Config &model_config_; const InfinilmModel::Config &model_config_;
distributed::RankInfo rank_info_; distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_; std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::DynamicCache> cache_ptr_; std::shared_ptr<cache::Cache> cache_;
// Command for the pending job (protected by mutex_) // Command for the pending job (protected by mutex_)
Command job_cmd_; Command job_cmd_;
...@@ -78,8 +74,7 @@ private: ...@@ -78,8 +74,7 @@ private:
std::string pending_param_name_; std::string pending_param_name_;
infinicore::Tensor pending_param_; infinicore::Tensor pending_param_;
InfinilmModel::Input pending_args_; InfinilmModel::Input pending_args_;
size_t pending_reset_pos_ = 0; std::unique_ptr<cache::CacheConfig> pending_cache_config_;
cache::CacheConfig pending_cache_config_;
// Output (protected by mutex) // Output (protected by mutex)
InfinilmModel::Output output_; InfinilmModel::Output output_;
......
...@@ -18,12 +18,10 @@ public: ...@@ -18,12 +18,10 @@ public:
struct Input { struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`. /// Token IDs tensor of shape `[batch, seq_len]`.
infinicore::Tensor input_ids; infinicore::Tensor input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
infinicore::Tensor position_ids; infinicore::Tensor position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
/// Optional model-level KV cache for incremental decoding. Defaults to `nullptr`. infinicore::Tensor cache_positions;
void *kv_cache = nullptr;
}; };
struct Output { struct Output {
...@@ -33,8 +31,7 @@ public: ...@@ -33,8 +31,7 @@ public:
virtual ~InfinilmModel() = default; virtual ~InfinilmModel() = default;
virtual Output forward(const Input &input) const = 0; virtual Output forward(const Input &input) const = 0;
// Optional: reset cache; default no-op for models without cache
virtual void reset_cache(size_t pos = 0) {} virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0;
}; };
} // namespace infinilm } // namespace infinilm
...@@ -51,7 +51,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -51,7 +51,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
void *kv_cache) const { std::shared_ptr<cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const {
if (!rotary_emb_) { if (!rotary_emb_) {
throw std::runtime_error("LlamaAttention: rotary_emb not configured"); throw std::runtime_error("LlamaAttention: rotary_emb not configured");
} }
...@@ -97,16 +98,15 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -97,16 +98,15 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinilm::cache::DynamicCache *external_cache = static_cast<infinilm::cache::DynamicCache *>(kv_cache); infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim] if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
if (external_cache != nullptr) { auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_positions);
auto [k_total_tmp, v_total_tmp] = external_cache->update(layer_idx_, k_permuted, v_permuted);
k_total = k_total_tmp; k_total = k_total_tmp;
v_total = v_total_tmp; v_total = v_total_tmp;
} else { } else {
// No external cache - this shouldn't happen in normal operation, but handle gracefully
throw std::runtime_error("LlamaAttention: kv_cache is required but nullptr provided"); throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
} }
auto total_seq_len = k_total->shape()[2]; auto total_seq_len = k_total->shape()[2];
......
...@@ -50,7 +50,8 @@ public: ...@@ -50,7 +50,8 @@ public:
*/ */
infinicore::Tensor forward(const infinicore::Tensor &hidden_states, infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const; std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const;
/** /**
* @brief Get the layer index * @brief Get the layer index
......
...@@ -23,7 +23,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, ...@@ -23,7 +23,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
void *kv_cache) const { std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const {
// Save residual for attention // Save residual for attention
auto residual = hidden_states; auto residual = hidden_states;
...@@ -31,7 +32,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s ...@@ -31,7 +32,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto normed_states = input_layernorm_->forward(hidden_states); auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection // 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache); auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_positions);
// Add residual: hidden_states = hidden_states + attn_output // Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output); auto output = infinicore::op::add(residual, attn_output);
......
...@@ -48,7 +48,8 @@ public: ...@@ -48,7 +48,8 @@ public:
*/ */
infinicore::Tensor forward(const infinicore::Tensor &hidden_states, infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const; std::shared_ptr<infinilm::cache::Cache> kv_cache,
const infinicore::Tensor &cache_positions) const;
/** /**
* @brief Get the layer index * @brief Get the layer index
......
...@@ -26,11 +26,11 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, ...@@ -26,11 +26,11 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
} }
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
const auto &[input_ids, position_ids, kv_cache] = input; const auto &[input_ids, position_ids, cache_position] = input;
// 1. Forward through base model to get hidden states // 1. Forward through base model to get hidden states
auto position_ids_device = position_ids->to(device_); auto position_ids_device = position_ids->to(device_);
auto hidden_states = model_->forward(input_ids, position_ids_device, kv_cache); auto hidden_states = model_->forward(input_ids, position_ids_device, cache_position);
// 2. Apply language modeling head to get logits // 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states); auto logits = lm_head_->forward(hidden_states);
...@@ -38,12 +38,8 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { ...@@ -38,12 +38,8 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
return {logits}; return {logits};
} }
void LlamaForCausalLM::reset_cache(size_t pos) { void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
model_->reset_cache(pos); model_->reset_cache(cache_config);
}
void LlamaForCausalLM::reset_cache(const cache::CacheConfig &new_config, size_t pos) {
model_->reset_cache(new_config, pos);
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -40,9 +40,7 @@ public: ...@@ -40,9 +40,7 @@ public:
*/ */
Output forward(const Input &input) const; Output forward(const Input &input) const;
// Reset internal cache position void reset_cache(const cache::CacheConfig *cache_config) override;
void reset_cache(size_t pos = 0) override;
void reset_cache(const cache::CacheConfig &new_config, size_t pos) override;
// Module information // Module information
const LlamaConfig &config() const { return model_->config(); } const LlamaConfig &config() const { return model_->config(); }
......
...@@ -10,9 +10,8 @@ namespace infinilm::models::llama { ...@@ -10,9 +10,8 @@ namespace infinilm::models::llama {
LlamaModel::LlamaModel(const LlamaConfig &config, LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
: config_(config) { : config_(config), rank_info_(rank_info) {
const auto &dtype{config.dtype}; const auto &dtype{config.dtype};
// Initialize token embeddings // Initialize token embeddings
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
std::nullopt, dtype, device); std::nullopt, dtype, device);
...@@ -46,72 +45,46 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -46,72 +45,46 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
void *kv_cache) const { const infinicore::Tensor &cache_positions) const {
// Use persistent internal cache if no external cache is provided
// This matches Python backend behavior: if use_cache and past_key_values is None, create DynamicCache
// The cache persists across forward calls to enable incremental decoding
void *cache_to_use = kv_cache;
if (cache_to_use == nullptr) {
// Create or reuse persistent internal cache at model level
// This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...)
if (external_cache_ != nullptr) {
cache_to_use = external_cache_;
} else {
// Fall back to internal cache
if (!internal_cache_) {
internal_cache_ = std::make_unique<infinilm::cache::DynamicCache>(
config_.num_hidden_layers,
config_.max_position_embeddings);
}
cache_to_use = internal_cache_.get();
}
}
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
auto hidden_states = embed_tokens_->forward(input_ids); auto hidden_states = embed_tokens_->forward(input_ids);
// 2. Process through all decoder layers // 2. Process through all decoder layers
size_t num_layers = layers_.size(); size_t num_layers = layers_.size();
for (size_t i = 0; i < num_layers; ++i) { for (size_t i = 0; i < num_layers; ++i) {
// Pass model-level cache (layer index is now a property of the layer) hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_positions);
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, cache_to_use);
// DEBUG: Disabled previous final layer logging
// Logging moved to decoder layer for post-attention normalization
} }
// 3. Apply final layer normalization to last token only (aligns with transformers) // 3. Apply final layer normalization to last token only (aligns with transformers)
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size] // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
auto shape = hidden_states->shape(); auto shape = hidden_states->shape();
size_t seq_len = shape[1]; size_t seq_len = shape[1];
auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}}); auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
// DEBUG: Disabled previous final layer normalization logging
// Normalize only the last token (matches Python backend)
auto normalized_last_token = norm_->forward(last_token); auto normalized_last_token = norm_->forward(last_token);
return normalized_last_token; return normalized_last_token;
} }
void LlamaModel::reset_cache(size_t pos) const { void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
if (internal_cache_) { if (cache_config == nullptr) {
internal_cache_->reset(pos); kv_cache_ = nullptr;
} return;
if (external_cache_) {
external_cache_->reset(pos);
}
}
void LlamaModel::reset_cache(const cache::CacheConfig &new_config, size_t pos) const {
if (internal_cache_) {
internal_cache_->update_config(new_config);
internal_cache_->reset(pos);
} }
if (external_cache_) { if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
external_cache_->update_config(new_config); kv_cache_ = std::make_shared<cache::StaticKVCache>(
external_cache_->reset(pos); config_.head_dim,
config_.head_dim,
config_.num_key_value_heads,
config_.num_key_value_heads,
config_.num_hidden_layers,
config_.max_position_embeddings,
config_.dtype,
*kv_cache_config,
rank_info_);
} else {
throw std::runtime_error("Unsupported cache type");
} }
} }
......
...@@ -47,41 +47,19 @@ public: ...@@ -47,41 +47,19 @@ public:
* *
* @param input_ids Token IDs tensor of shape [batch, seq_len] * @param input_ids Token IDs tensor of shape [batch, seq_len]
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional model-level KV cache for incremental decoding * @param cache_positions Cache positions tensor of shape [n_req]
* @return Output tensor of shape [batch, seq_len, hidden_size] * @return Output tensor of shape [batch, seq_len, hidden_size]
*/ */
infinicore::Tensor forward(const infinicore::Tensor &input_ids, infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const; const infinicore::Tensor &cache_positions) const;
void reset_cache(const cache::CacheConfig *cache_config);
// Module information // Module information
const LlamaConfig &config() const { return config_; } const LlamaConfig &config() const { return config_; }
size_t num_layers() const { return config_.num_hidden_layers; } size_t num_layers() const { return config_.num_hidden_layers; }
/**
* @brief Reset the internal cache to a specific position
* This should be called when starting a new generation sequence to prevent state
* from persisting between different questions/prompts
* @param pos Position to reset to (defaults to 0)
*/
void reset_cache(size_t pos = 0) const;
/**
* @brief Reset the internal cache with a new configuration and position
* This should be called when changing cache parameters (e.g., initial capacity)
* @param new_config New cache configuration
* @param pos Position to reset to
*/
void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) const;
/**
* @brief Set external cache for the model
* @param cache Pointer to external cache (managed by CacheManager)
*/
void set_external_cache(std::shared_ptr<cache::DynamicCache> cache) {
external_cache_ = cache.get();
}
protected: protected:
// Token embeddings // Token embeddings
INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens);
...@@ -95,13 +73,12 @@ protected: ...@@ -95,13 +73,12 @@ protected:
// Rotary Position Embeddings (shared across all layers) // Rotary Position Embeddings (shared across all layers)
INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb); INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb);
engine::distributed::RankInfo rank_info_;
std::shared_ptr<cache::Cache> kv_cache_;
private: private:
LlamaConfig config_; LlamaConfig config_;
// Persistent cache for when no external cache is provided
// Mutable because it's not part of the model's learned parameters,
// but needs to persist across forward calls for incremental decoding
mutable std::unique_ptr<infinilm::cache::DynamicCache> internal_cache_;
cache::DynamicCache *external_cache_ = nullptr;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -5,20 +5,21 @@ namespace infinilm { ...@@ -5,20 +5,21 @@ namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info, engine::distributed::RankInfo rank_info,
std::shared_ptr<cache::DynamicCache> cache_ptr) { const cache::CacheConfig *cache) {
std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) { if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr; const auto &llama_config = *llama_config_ptr;
auto model = std::make_shared<models::llama::LlamaForCausalLM>( model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, rank_info); llama_config, rank_info.device, rank_info);
if (cache_ptr != nullptr) {
model->model().set_external_cache(cache_ptr);
}
return model;
} else { } else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
} }
if (cache) {
model->reset_cache(cache);
}
return model;
} }
} // namespace infinilm } // namespace infinilm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment