Commit f147eb02 authored by PanZezhong's avatar PanZezhong

issue/125 Add Paged KV Cache interface

parent 11007392
......@@ -5,6 +5,31 @@
#include <stdexcept>
namespace infinilm::cache {
// ==========================
// StaticKVCacheConfig
// ==========================
StaticKVCacheConfig::StaticKVCacheConfig(
infinicore::Size _max_batch_size,
infinicore::Size _max_cache_len)
: max_batch_size_(_max_batch_size),
max_cache_len_(_max_cache_len) {
}
std::unique_ptr<CacheConfig>
StaticKVCacheConfig::unique_copy() const {
return std::make_unique<StaticKVCacheConfig>(*this);
}
infinicore::Size
StaticKVCacheConfig::max_batch_size() const {
return max_batch_size_;
}
infinicore::Size
StaticKVCacheConfig::max_cache_len() const {
return max_cache_len_;
}
// ==========================
// StaticKVCache
......@@ -26,7 +51,7 @@ StaticKVCache::StaticKVCache(
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_batch_size_(config.max_batch_size()),
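// A max_cache_len of 0 or SIZE_MAX is treated as "no explicit limit": fall back to the model's max positional embedding.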
cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
rank_num_layers_(num_layers),
dtype_(dtype) {
......@@ -49,20 +74,18 @@ StaticKVCache::StaticKVCache(
v_dim_},
dtype_,
rank_info.device);
spdlog::info("Created Static KV Cache: K[{}] V[{}]", k_caches_->info(), v_caches_->info());
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
const infinicore::Tensor &cache_lengths) {
ASSERT(layer_idx < rank_num_layers_);
auto batch_size = k->size(0);
auto update_len = k->size(2);
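// Only the first request's cached length is read below, so this path assumes every sequence in the batch shares the same cache position.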
size_t cache_pos = reinterpret_cast<int64_t *>(cache_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
......@@ -85,29 +108,90 @@ StaticKVCache::update(size_t layer_idx,
}
// ==========================
// PagedKVCacheConfig
// ==========================
PagedKVCacheConfig::PagedKVCacheConfig(
size_t max_kv_memory_bytes,
size_t block_size)
: max_kv_memory_bytes_(max_kv_memory_bytes),
block_size_(block_size) {
}
std::unique_ptr<CacheConfig>
PagedKVCacheConfig::unique_copy() const {
return std::make_unique<PagedKVCacheConfig>(*this);
}
size_t
PagedKVCacheConfig::max_kv_memory_bytes() const {
return max_kv_memory_bytes_;
}
size_t
PagedKVCacheConfig::block_size() const {
return block_size_;
}
// ==========================
// PagedKVCache
// ==========================
PagedKVCache::PagedKVCache(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::DataType dtype,
const PagedKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
k_dim_(k_dim),
v_dim_(v_dim),
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_num_layers_(num_layers),
dtype_(dtype),
block_size_(config.block_size()) {
num_blocks_per_layer_ = config.max_kv_memory_bytes()
/ (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
/ block_size_
/ infinicore::dsize(dtype_);
if (num_blocks_per_layer_ == 0) {
throw std::runtime_error("Not enough memory for KV cache");
}
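// Worked example of the division chain above (hypothetical values): with a
// 1 GiB budget, k_dim = v_dim = 128, 8 rank K heads and 8 rank V heads,
// block_size = 16, and a 2-byte dtype:
//   per-token elements: 128*8 + 128*8 = 2048
//   1073741824 / 2048                 = 524288
//   524288 / 16 (block_size)          = 32768
//   32768 / 2 (bytes per element)     = 16384 blocks per layer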
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
num_blocks_per_layer_,
num_rank_k_heads_,
block_size_,
k_dim_},
dtype_,
rank_info.device);
// [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
v_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
num_blocks_per_layer_,
num_rank_v_heads_,
block_size_,
v_dim_},
dtype_,
rank_info.device);
}
std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping) {
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
/// @todo: implement paged cache update here
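/// A minimal sketch of the intended scatter, assuming slot_mapping holds
/// flat slot indices (block_id * block_size_ + in-block offset); a real
/// kernel would run on-device rather than per token:
///   for t in [0, seq_len):
///     slot  = slot_mapping[t]
///     block = slot / block_size_, off = slot % block_size_
///     k_cache_layer[block, :, off, :] = k[:, t, :]
///     v_cache_layer[block, :, off, :] = v[:, t, :]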
return {k_cache_layer, v_cache_layer};
}
} // namespace infinilm::cache
......@@ -61,7 +61,7 @@ public:
update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
const infinicore::Tensor &cache_lengths);
~StaticKVCache() override = default;
......@@ -82,4 +82,68 @@ private:
infinicore::Tensor v_caches_;
};
class PagedKVCacheConfig final : public CacheConfig {
public:
PagedKVCacheConfig(
size_t max_kv_memory_bytes,
size_t block_size = 16);
std::unique_ptr<CacheConfig> unique_copy() const override;
size_t max_kv_memory_bytes() const;
size_t block_size() const;
private:
size_t max_kv_memory_bytes_;
size_t block_size_;
};
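// Example (hypothetical values): budget 2 GiB of KV memory with the default
// 16-token blocks:
//   cache::PagedKVCacheConfig cfg(/*max_kv_memory_bytes=*/2ull << 30);
//   auto copy = cfg.unique_copy();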
class PagedKVCache final : public Cache {
public:
PagedKVCache(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::DataType dtype,
const PagedKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);
/**
* @brief Update Paged KV cache at a given layer given slot info for each token.
*
* @param layer_idx Which transformer layer
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param slot_mapping [seq_len]
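*        (assumed convention: each entry is a flat slot index, block_id * block_size + in-block offset)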
*
* @return (full_k, full_v)
* full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
* full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
*/
std::tuple<infinicore::Tensor, infinicore::Tensor>
update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping);
~PagedKVCache() override = default;
private:
infinicore::Size k_dim_;
infinicore::Size v_dim_;
infinicore::Size num_rank_k_heads_;
infinicore::Size num_rank_v_heads_;
infinicore::Size rank_num_layers_;
infinicore::DataType dtype_;
infinicore::Size block_size_;
infinicore::Size num_blocks_per_layer_;
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
infinicore::Tensor k_caches_;
// [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
infinicore::Tensor v_caches_;
};
} // namespace infinilm::cache
......@@ -57,7 +57,7 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
// forward
//------------------------------------------------------
infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping};
}
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
......
#pragma once
#include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"
#include "../models/infinilm_model.hpp"
#include <optional>
#include <vector>
namespace infinilm::engine {
......@@ -13,11 +14,20 @@ namespace infinilm::engine {
class InferEngine {
public:
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
std::optional<infinicore::Tensor> input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids;
/// Past lengths of the cached sequence for each request, of shape `[num_requests]`.
std::optional<infinicore::Tensor> cache_lengths;
/// Input lengths of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_offsets;
/// Block IDs for each request, of shape `[batch, max_block_table_length]`. Used for the paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot IDs for each token, of shape `[seq_len]`. Used for the paged cache.
std::optional<infinicore::Tensor> slot_mapping;
infinilm::InfinilmModel::Input to_model_input() const;
};
......
......@@ -6,6 +6,8 @@
#include <any>
#include <optional>
namespace infinilm {
class InfinilmModel : public infinicore::nn::Module {
public:
......@@ -17,11 +19,19 @@ public:
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
std::optional<infinicore::Tensor> input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids;
/// Past lengths of the cached sequence for each request, of shape `[num_requests]`.
std::optional<infinicore::Tensor> cache_lengths;
/// Input lengths of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> input_offsets;
/// Block IDs for each request, of shape `[batch, max_block_table_length]`. Used for the paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot IDs for each token, of shape `[seq_len]`. Used for the paged cache.
std::optional<infinicore::Tensor> slot_mapping;
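// Which fields are consumed depends on the active cache: the static KV
// cache path reads cache_lengths, while the paged path reads slot_mapping
// (block_tables is plumbed through for future paged-attention support).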
};
struct Output {
......
#include "llama_attention.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mul.hpp"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <optional>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <vector>
......@@ -52,7 +55,11 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) {
throw std::runtime_error("LlamaAttention: rotary_emb not configured");
}
......@@ -100,12 +107,21 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
if (kv_cache == nullptr) {
k_total = k_permuted;
v_total = v_permuted;
} else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
k_total = k_total_tmp;
v_total = v_total_tmp;
} else if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value());
k_total = k_total_tmp;
v_total = v_total_tmp;
} else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
}
auto total_seq_len = k_total->shape()[2];
......
......@@ -51,7 +51,11 @@ public:
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
/**
* @brief Get the layer index
......
......@@ -2,6 +2,8 @@
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp"
#include <optional>
namespace infinilm::models::llama {
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
......@@ -24,7 +26,11 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// Save residual for attention
auto residual = hidden_states;
......@@ -32,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
......
......@@ -49,7 +49,11 @@ public:
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
/**
* @brief Get the layer index
......
......@@ -26,11 +26,17 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
}
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto input_ids = input.input_ids.value();
auto position_ids = input.position_ids.value();
auto cache_lengths = input.cache_lengths;
auto input_lengths = input.input_lengths;
auto input_offsets = input.input_offsets;
auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states
auto position_ids_device = position_ids->to(device_);
auto hidden_states = model_->forward(input_ids, position_ids_device, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
......
......@@ -45,14 +45,18 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
auto hidden_states = embed_tokens_->forward(input_ids);
// 2. Process through all decoder layers
size_t num_layers = layers_.size();
for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
}
// 3. Apply final layer normalization to last token only (aligns with transformers)
......@@ -83,6 +87,16 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
*kv_cache_config,
rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::PagedKVCache>(
config_.head_dim,
config_.head_dim,
config_.num_key_value_heads,
config_.num_key_value_heads,
config_.num_hidden_layers,
config_.dtype,
*paged_kv_cache_config,
rank_info_);
} else {
throw std::runtime_error("Unsupported cache type");
}
......
......@@ -45,14 +45,21 @@ public:
/**
* @brief Forward pass: process input through the model
*
* @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batching is used,
*        and tokens from all requests are concatenated along the seq_len dimension.
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param cache_lengths Past length of the cached sequence for each request, of shape [n_req]
* @param input_lengths Input length of each request in a continuous batch, of shape [n_req]
* @param input_offsets Input offset (starting position) of each request in a continuous batch, of shape [n_req]
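* @param block_tables Block IDs for each request, of shape [batch, max_block_table_length]; used by the paged cache
* @param slot_mapping Flat slot index for each token, of shape [seq_len]; used by the paged cache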
* @return Output tensor of shape [batch, seq_len, hidden_size]
*/
infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
void reset_cache(const cache::CacheConfig *cache_config);
......
......@@ -46,9 +46,9 @@ inline void bind_infer_engine(py::module &m) {
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none())
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
......@@ -76,10 +76,36 @@ inline void bind_infer_engine(py::module &m) {
});
py::class_<InferEngine::Input>(infer_engine, "Input")
.def(
py::init([](
std::optional<infinicore::Tensor> input_ids,
std::optional<infinicore::Tensor> position_ids,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) {
return InferEngine::Input{
std::move(input_ids),
std::move(position_ids),
std::move(cache_lengths),
std::move(input_lengths),
std::move(input_offsets),
std::move(block_tables),
std::move(slot_mapping)};
}),
py::arg("input_ids") = std::nullopt,
py::arg("position_ids") = std::nullopt,
py::arg("cache_lengths") = std::nullopt,
py::arg("input_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt,
py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids)
.def_readwrite("position_ids", &InferEngine::Input::position_ids)
.def_readwrite("cache_lengths", &InferEngine::Input::cache_lengths)
.def_readwrite("input_lengths", &InferEngine::Input::input_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping);
py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
......
......@@ -172,6 +172,11 @@ def get_args():
default=20,
help="output tokens",
)
parser.add_argument(
"--skip-load",
action="store_true",
help="skip loading model weights",
)
return parser.parse_args()
......@@ -194,6 +199,7 @@ class TestModel:
model_path,
infini_device=infinicore.device("cpu", 0),
tp=1,
skip_load=False,
) -> None:
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
......@@ -209,7 +215,8 @@ class TestModel:
# ---------------------------------------------------------------------------- #
# Load weights
# ---------------------------------------------------------------------------- #
if not skip_load:
load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)
# ---------------------------------------------------------------------------- #
# Create tokenizer
......@@ -289,6 +296,8 @@ if __name__ == "__main__":
tp = args.tensor_parallel_size
skip_load = args.skip_load
batch_size = args.batch_size
input_len = args.input_len
output_len = args.output_len
......@@ -312,6 +321,7 @@ if __name__ == "__main__":
model_path,
infini_device=infini_device,
tp=tp,
skip_load=skip_load,
)
for idx, case in tqdm(cases_dict.items(), desc="Processing cases"):
......@@ -322,10 +332,8 @@ if __name__ == "__main__":
output_len = case["output_len"]
# reset cache for each case
initial_capacity = input_len + output_len
test.model.reset_cache(batch_size=batch_size, initial_capacity=initial_capacity)
# run test one case
test.run(
......
......@@ -9,5 +9,5 @@ class CacheConfig(_infinilm.CacheConfig):
class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0):
_infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len)
......@@ -10,6 +10,7 @@ import infinilm
from infinilm.models.llama import AutoLlamaModel
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.cache import StaticKVCacheConfig
from abc import ABC, abstractmethod
......@@ -118,6 +119,7 @@ class InfiniLMBenchmark(BaseBenchmark):
device=self.device,
backend=backend,
distributed_config=DistConfig(ndev),
cache_config=StaticKVCacheConfig(),
)
# Enable KV cache for generation
......