Commit f246c4f1 authored by PanZezhong

issue/168 InfiniLM: integrate paged attention

parent 96e53dbb
#include "kv_cache.hpp" #include "kv_cache.hpp"
#include "../utils.hpp" #include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <stdexcept> #include <stdexcept>
namespace infinilm::cache { namespace infinilm::cache {
...@@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache( ...@@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache(
num_blocks_per_layer_ = config.max_kv_memory_bytes() num_blocks_per_layer_ = config.max_kv_memory_bytes()
/ (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_) / (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
/ block_size_ / block_size_
/ rank_num_layers_
/ infinicore::dsize(dtype_); / infinicore::dsize(dtype_);
if (num_blocks_per_layer_ == 0) { if (num_blocks_per_layer_ == 0) {
throw std::runtime_error("Not enough memory for KV cache"); throw std::runtime_error("Not enough memory for KV cache");
...@@ -190,8 +191,11 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update( ...@@ -190,8 +191,11 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0); auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0); auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
/// @todo: implement paged cache update here infinicore::op::paged_caching_(k,
v,
k_cache_layer,
v_cache_layer,
slot_mapping);
return {k_cache_layer, v_cache_layer}; return {k_cache_layer, v_cache_layer};
} }
} // namespace infinilm::cache } // namespace infinilm::cache
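The new `/ rank_num_layers_` divisor makes `max_kv_memory_bytes` a budget for the whole rank rather than per layer: the bytes are split across layers before being carved into blocks. A quick sanity check of the arithmetic in Python (head counts, head dims, layer count, and dtype size below are illustrative assumptions, not values from this repo):

```python
# Mirrors the num_blocks_per_layer_ computation above with made-up model sizes.
max_kv_memory_bytes = 8 * 1024**3        # the test script's --max-kvcache-size default
k_dim = v_dim = 128                      # per-head K/V width (assumed)
num_rank_k_heads = num_rank_v_heads = 8  # KV heads on this rank (assumed)
block_size = 16                          # tokens per block, as in PagedKVCacheConfig(block_size=16)
rank_num_layers = 32                     # layers on this rank (assumed)
dtype_size = 2                           # bytes per element, e.g. F16

num_blocks_per_layer = (
    max_kv_memory_bytes
    // (k_dim * num_rank_k_heads + v_dim * num_rank_v_heads)
    // block_size
    // rank_num_layers
    // dtype_size
)
print(num_blocks_per_layer)  # 4096 blocks/layer -> 4096 * 16 = 65536 cacheable tokens per layer
```

Without the new divisor the same budget would be spent once per layer, overcommitting memory by a factor of `rank_num_layers_`.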
@@ -56,8 +56,50 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
 //------------------------------------------------------
 // forward
 //------------------------------------------------------
-infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
-    return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping};
+infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const {
+    std::optional<infinicore::Tensor> position_ids_on_device;
+    if (position_ids.has_value()) {
+        position_ids_on_device = position_ids.value()->to(device);
+    }
+    std::optional<infinicore::Tensor> cache_lengths_on_device;
+    if (cache_lengths.has_value()) {
+        if (block_tables.has_value()) {
+            cache_lengths_on_device = cache_lengths.value()->to(device);
+        } else { // @todo: only the paged KV cache supports device tensors so far
+            cache_lengths_on_device = cache_lengths.value();
+        }
+    }
+    std::optional<infinicore::Tensor> input_lengths_on_device;
+    if (input_lengths.has_value()) {
+        input_lengths_on_device = input_lengths.value()->to(device);
+    }
+    std::optional<infinicore::Tensor> input_offsets_on_device;
+    if (input_offsets.has_value()) {
+        input_offsets_on_device = input_offsets.value()->to(device);
+    }
+    std::optional<infinicore::Tensor> block_tables_on_device;
+    if (block_tables.has_value()) {
+        block_tables_on_device = block_tables.value()->to(device);
+    }
+    std::optional<infinicore::Tensor> slot_mapping_on_device;
+    if (slot_mapping.has_value()) {
+        slot_mapping_on_device = slot_mapping.value()->to(device);
+    }
+    return {
+        input_ids, // @todo: on device in the future
+        position_ids_on_device,
+        cache_lengths_on_device,
+        input_lengths_on_device,
+        input_offsets_on_device,
+        block_tables_on_device,
+        slot_mapping_on_device};
 }

 InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
...
@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
             local_param_name = pending_param_name_;
             local_param = pending_param_;
         } else if (local_cmd == Command::RUN) {
-            local_args = pending_args_.to_model_input();
+            local_args = pending_args_.to_model_input(rank_info_.device);
         } else if (local_cmd == Command::RESET_CACHE) {
             if (pending_cache_config_ != nullptr) {
                 local_cache_config = pending_cache_config_->unique_copy();
@@ -254,13 +254,18 @@ void RankWorker::thread_loop() {
             auto random_val{pending_args_.random_val};

             const auto &logits_shape{logits->shape()};
-            const auto &batch_size{logits_shape[0]};
             const auto &vocab_size{logits_shape[2]};
+            const auto &total_len{logits_shape[1]};
+            const auto &batch_size{logits_shape[0]};
+            auto n_req = pending_args_.input_offsets.value()->size(0);
+            int64_t *input_lengths = (int64_t *)pending_args_.input_lengths.value()->data();
+            int64_t *input_offsets = (int64_t *)pending_args_.input_offsets.value()->data();

-            auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)};
+            auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};

-            for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) {
-                auto score{logits->narrow({{0, i, 1}})->view({vocab_size})};
+            for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
+                auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i] + input_lengths[i] - 1), 1}})->view({vocab_size})};
                 auto out{output_ids->narrow({{0, i, 1}})->view({})};
                 infinicore::op::random_sample_(
                     out, score, random_val, top_p, top_k, temperature);
...
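Since logits now cover every packed position (`[batch, total_len, vocab]`), the sampler picks each request's final input token by offset rather than assuming one logit row per batch entry. A hedged sketch of that indexing with made-up lengths:

```python
# Two requests packed along the sequence axis: lengths [5, 3], offsets [0, 5].
# After viewing logits as [batch_size * total_len, vocab_size], request i
# samples from the row of its last input token: input_offsets[i] + input_lengths[i] - 1.
input_lengths = [5, 3]
input_offsets = [0, 5]
last_rows = [off + ln - 1 for off, ln in zip(input_offsets, input_lengths)]
assert last_rows == [4, 7]
```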
@@ -47,7 +47,7 @@ public:
         float random_val{0.1};

-        infinilm::InfinilmModel::Input to_model_input() const;
+        infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const;
     };

     struct Output {
...
#include "llama_attention.hpp" #include "llama_attention.hpp"
#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp" #include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
...@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
} else { } else {
throw std::runtime_error("num_attention_heads / tp_size error."); throw std::runtime_error("num_attention_heads / tp_size error.");
} }
scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));
// Initialize projection layers // Initialize projection layers
INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_, INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
...@@ -52,17 +54,10 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -52,17 +54,10 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
dtype, device, tp_rank, tp_size, rank_info.comm); dtype, device, tp_rank, tp_size, rank_info.comm);
} }
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> cache_lengths) const {
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) {
throw std::runtime_error("LlamaAttention: rotary_emb not configured");
}
// Input shape: [batch, seq_len, hidden_size] // Input shape: [batch, seq_len, hidden_size]
auto hidden_states_mutable = hidden_states; auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape(); auto shape = hidden_states->shape();
...@@ -73,7 +68,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -73,7 +68,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
// 2. Reshape for multi-head attention // 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension // Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape) // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions // The view operation requires the tensor to be contiguous in the required dimensions
...@@ -114,13 +108,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -114,13 +108,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value()); auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
k_total = k_total_tmp; k_total = k_total_tmp;
v_total = v_total_tmp; v_total = v_total_tmp;
} else if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value());
k_total = k_total_tmp;
v_total = v_total_tmp;
/// @todo Implement paged attention here.
throw std::runtime_error("LlamaAttention: Paged attention not implemented");
} else { } else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type"); throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
} }
...@@ -134,8 +121,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -134,8 +121,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len] auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_)); auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len]
auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len}); auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax); infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
...@@ -152,6 +138,119 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -152,6 +138,119 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
return output; return output;
} }
infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
ASSERT(block_tables.has_value());
ASSERT(input_lengths.has_value());
ASSERT(slot_mapping.has_value());
// Input shape: [batch, seq_len, hidden_size]
auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape();
size_t batch_size = shape[0];
size_t seq_len = shape[1];
// Only support batchsize==1, all requests should be flattened along seqlen dimension
ASSERT_EQ(batch_size, 1);
// Decode only if total_len == num_requests
bool is_prefill = (seq_len != input_lengths.value()->shape()[0]);
// 1. Project Q, K, V
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
// 3. Prepare position_ids for RoPE - align with Python pattern
auto pos_shape = position_ids->shape();
infinicore::Tensor pos_ids_for_rope = position_ids;
if (pos_shape.size() == 2) {
auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
} else if (pos_shape.size() == 1) {
pos_ids_for_rope = position_ids;
} else {
throw std::runtime_error("Unexpected position_ids shape");
}
// 4. Apply RoPE to Q and K
rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_q_head, head_dim]
rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim]
// 5. Prepare KV caches
// Ensure contiguous after permute for F16 compatibility with cache operations
auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
k_reshaped,
v_reshaped,
slot_mapping.value());
// 6. Compute attention
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) {
infinicore::op::paged_attention_prefill_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
cache_lengths.value(),
input_lengths.value(),
input_offsets.value(),
std::nullopt,
scaling_);
} else {
infinicore::op::paged_attention_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
cache_lengths.value(),
std::nullopt,
scaling_);
}
// 7. Project output
attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
return o_proj_->forward(attn_output);
}
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths,
std::optional<infinicore::Tensor> input_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) {
throw std::runtime_error("LlamaAttention: rotary_emb not configured");
}
infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
} else {
output = forward_(hidden_states, position_ids, kv_cache, cache_lengths);
}
return output;
}
void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) { void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
rotary_emb_ = rotary_emb; rotary_emb_ = rotary_emb;
} }
......
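`forward_paged_` tells prefill from decode purely by shape: requests arrive flattened as `[1, total_len]`, and in a decode step every request contributes exactly one token, so `total_len` equals the number of requests. A minimal sketch of that heuristic (assumed semantics, mirroring the `is_prefill` line above):

```python
def is_prefill(total_len: int, num_requests: int) -> bool:
    # Decode iff each request contributed exactly one new token.
    return total_len != num_requests

assert is_prefill(12, 2)       # prefill: e.g. prompts of 5 and 7 tokens, packed
assert not is_prefill(2, 2)    # decode: one new token per request
```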
@@ -55,7 +55,7 @@ public:
                              std::optional<infinicore::Tensor> input_lengths,
                              std::optional<infinicore::Tensor> input_offsets,
                              std::optional<infinicore::Tensor> block_tables,
-                             std::optional<infinicore::Tensor> slot_mappin) const;
+                             std::optional<infinicore::Tensor> slot_mapping) const;

     /**
      * @brief Get the layer index
@@ -73,6 +73,21 @@ public:
     size_t head_dim() const { return head_dim_; }
     size_t hidden_size() const { return hidden_size_; }

+private:
+    infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
+                                const infinicore::Tensor &position_ids,
+                                std::shared_ptr<infinilm::cache::Cache> kv_cache,
+                                std::optional<infinicore::Tensor> cache_lengths) const;
+
+    infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
+                                      const infinicore::Tensor &position_ids,
+                                      std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
+                                      std::optional<infinicore::Tensor> cache_lengths,
+                                      std::optional<infinicore::Tensor> input_lengths,
+                                      std::optional<infinicore::Tensor> input_offsets,
+                                      std::optional<infinicore::Tensor> block_tables,
+                                      std::optional<infinicore::Tensor> slot_mapping) const;
+
 protected:
     // Projection layers
     INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
@@ -93,6 +108,8 @@ private:
     bool use_bias_;        // Bias for Q/K/V projections
     bool use_output_bias_; // Bias for output projection (o_proj)
     size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
+
+    float scaling_; // Precomputed attention scale: 1 / sqrt(head_dim)
 };

 } // namespace infinilm::models::llama
@@ -35,8 +35,7 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
     auto slot_mapping = input.slot_mapping;

     // 1. Forward through base model to get hidden states
-    auto position_ids_device = position_ids->to(device_);
-    auto hidden_states = model_->forward(input_ids, position_ids_device, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+    auto hidden_states = model_->forward(input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);

     // 2. Apply language modeling head to get logits
     auto logits = lm_head_->forward(hidden_states);
...
@@ -59,15 +59,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
         hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
     }

-    // 3. Apply final layer normalization to last token only (aligns with transformers)
-    // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
-    auto shape = hidden_states->shape();
-    size_t seq_len = shape[1];
-    auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
-    auto normalized_last_token = norm_->forward(last_token);
-    return normalized_last_token;
+    // 3. Apply final layer normalization to all positions; the sampler now
+    //    selects each request's last token from the packed logits
+    return norm_->forward(hidden_states);
 }

 void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
...
@@ -90,6 +90,8 @@ inline void bind_infer_engine(py::module &m) {
                     std::move(input_ids),
                     std::move(position_ids),
                     std::move(cache_lengths),
+                    std::move(input_lengths),
+                    std::move(input_offsets),
                     std::move(block_tables),
                     std::move(slot_mapping)}};
...
@@ -9,6 +9,7 @@ import sys
 import time
 import os
 import numpy as np
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig

 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
@@ -82,6 +83,18 @@ def get_args():
         default=1,
         help="total rank for tensor parallel",
     )
+    parser.add_argument(
+        "--enable-paged-attn",
+        action="store_true",
+        help="use paged cache",
+    )
+    parser.add_argument(
+        "--max-kvcache-size",
+        type=int,
+        default=8 * 1024 * 1024 * 1024,
+        help="max size (in bytes) allocated to the paged KV cache",
+    )

     return parser.parse_args()
@@ -92,6 +105,7 @@ def test(
     max_new_tokens=100,
     infini_device=infinicore.device("cpu", 0),
     tp=1,
+    enable_paged_attn=False,
 ):
     model_path = os.path.expanduser(model_path)

     # ---------------------------------------------------------------------------- #
@@ -150,11 +164,21 @@ def test(
         "input_ids"
     ]  # List: [[1, 1128, 526, 366, 29892]]

-    # Create the KV cache from the input length and the maximum output length
-    model.reset_cache(
-        1 if prompts is str else len(prompts),
-        max_new_tokens + len(input_ids_list[0]),
-    )
+    # ---------------------------------------------------------------------------- #
+    # Create the KV cache
+    # ---------------------------------------------------------------------------- #
+    if enable_paged_attn:
+        cache_config = PagedKVCacheConfig(
+            max_kv_memory_bytes=args.max_kvcache_size, block_size=16
+        )
+    else:
+        batch_size = 1 if isinstance(prompts, str) else len(prompts)
+        initial_capacity = max_new_tokens + len(input_ids_list[0])
+        cache_config = StaticKVCacheConfig(
+            max_batch_size=batch_size, max_cache_len=initial_capacity
+        )
+    model.reset_cache(cache_config)

     # ---------------------------------------------------------------------------- #
     # Autoregressive generation
@@ -211,7 +235,7 @@ if __name__ == "__main__":
     max_new_tokens = args.max_new_tokens
     backend = args.backend
     tp = args.tp
+    enable_paged_attn = args.enable_paged_attn

     if backend != "cpp":
         raise ValueError(f"Unsupported backend: {backend}.")
@@ -223,4 +247,5 @@ if __name__ == "__main__":
         max_new_tokens,
         infini_device=infini_device,
         tp=tp,
+        enable_paged_attn=enable_paged_attn,
     )
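With `reset_cache` now taking a config object, switching cache strategies is a one-line change. A usage sketch (constructor keyword names follow the calls in this script; treat them as such, not as a documented API):

```python
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig

# Paged: one global byte budget, carved into fixed-size blocks.
paged = PagedKVCacheConfig(max_kv_memory_bytes=8 * 1024**3, block_size=16)

# Static: per-request slots preallocated up front.
static = StaticKVCacheConfig(max_batch_size=4, max_cache_len=2048)

# The engine infers the attention path from the config type:
# model.reset_cache(paged)   -> paged attention path
# model.reset_cache(static)  -> dense/static path
```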
@@ -21,5 +21,7 @@ class AutoConfig:
         if config_dict["model_type"] == "llama":
             return LlamaConfig(**config_dict)
+        elif config_dict["model_type"] == "qwen2":
+            return LlamaConfig(**config_dict)

         raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
-from .cache import CacheConfig, StaticKVCacheConfig
+from .cache import CacheConfig, StaticKVCacheConfig, PagedKVCacheConfig

-__all__ = ["CacheConfig", "StaticKVCacheConfig"]
+__all__ = ["CacheConfig", "StaticKVCacheConfig", "PagedKVCacheConfig"]
@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
         return ctypes.c_int32
     elif infini_dtype == infinicore.float32:
         return ctypes.c_float
+    elif infini_dtype == infinicore.int64:
+        return ctypes.c_int64
     else:
         raise ValueError(f"Unsupported py_dtype: {infini_dtype}")
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import infinicore

 from infinilm.auto_config import AutoConfig
-from infinilm.cache import StaticKVCacheConfig
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
 from infinilm.distributed import DistConfig
 from infinilm.lib import _infinilm
@@ -18,6 +18,7 @@ class GenerationConfig:
     top_p: float = 1.0
     eos_token_id: list[int] | None = None
+    stop_on_eos: bool = True

 class InferEngine(_infinilm.InferEngine):
@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine):
         self.use_cache = False

+        self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)

     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
@@ -93,15 +96,11 @@ class InferEngine(_infinilm.InferEngine):
         else:
             eos_token_id = generation_config.eos_token_id

-        # TODO: Remove the `to_numpy` calls and simplify the corresponding code.
-        batch_size, seq_len = input_ids.shape[:2]
-        position_ids = infinicore.from_list(
-            [list(range(0, seq_len)) for _ in range(batch_size)], dtype=infinicore.int64
-        )
-        cache_lengths = infinicore.from_list([0], dtype=infinicore.int64)
+        past_seq_len = 0
         output_ids = []
+        initial_batch_size, initial_seqlen = input_ids.shape[:2]
+        seq_len = initial_seqlen
+        batch_size = initial_batch_size

         if batch_size != 1 and generation_config.max_new_tokens is None:
             raise ValueError(
@@ -111,14 +110,76 @@ class InferEngine(_infinilm.InferEngine):
         if _measure_and_log_time:
             time_measurements = []

         for _ in range(0, generation_config.max_new_tokens):
             if _measure_and_log_time:
                 start_time = time.perf_counter()

+            batch_size, seq_len = input_ids.shape[:2]
+            if self.enable_paged_attn:
+                # Flatten all requests into one packed sequence: [1, batch_size * seq_len]
+                input_ids = input_ids.view([1, batch_size * seq_len])
+                position_ids = infinicore.from_list(
+                    list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
+                    dtype=infinicore.int64,
+                )
+                cache_lengths = infinicore.from_list(
+                    [past_seq_len] * batch_size, dtype=infinicore.int64
+                )
+                input_lengths = infinicore.from_list(
+                    [seq_len] * batch_size, dtype=infinicore.int64
+                )
+                input_offsets = infinicore.from_list(
+                    [seq_len * i for i in range(batch_size)], dtype=infinicore.int64
+                )
+                # Logical block j of request b lives in physical block j * batch_size + b
+                block_tables = infinicore.from_list(
+                    [
+                        [
+                            i * batch_size + b
+                            for i in range((past_seq_len + seq_len + 15) // 16)
+                        ]
+                        for b in range(batch_size)
+                    ],
+                    dtype=infinicore.int64,
+                )
+                # Token at position p of request b goes to slot
+                # (p // 16 * batch_size + b) * 16 + p % 16, batch-major to match
+                # the flattened input_ids
+                slot_mapping = infinicore.from_list(
+                    [
+                        ((past_seq_len + i) // 16 * batch_size + b) * 16
+                        + (past_seq_len + i) % 16
+                        for b in range(batch_size)
+                        for i in range(seq_len)
+                    ],
+                    dtype=infinicore.int64,
+                )
+            else:
+                position_ids = infinicore.from_list(
+                    [
+                        list(range(past_seq_len, past_seq_len + seq_len))
+                        for _ in range(batch_size)
+                    ],
+                    dtype=infinicore.int64,
+                )
+                cache_lengths = infinicore.from_list(
+                    [past_seq_len], dtype=infinicore.int64
+                )
+                input_lengths = infinicore.from_list(
+                    [seq_len] * batch_size, dtype=infinicore.int64
+                )
+                input_offsets = infinicore.from_list(
+                    [seq_len * i for i in range(batch_size)], dtype=infinicore.int64
+                )
+                block_tables = None
+                slot_mapping = None

             output_id = self(
-                input_ids,
+                input_ids=input_ids,
                 position_ids=position_ids,
                 cache_lengths=cache_lengths,
+                input_lengths=input_lengths,
+                input_offsets=input_offsets,
+                block_tables=block_tables,
+                slot_mapping=slot_mapping,
                 temperature=generation_config.temperature,
                 top_k=generation_config.top_k,
                 top_p=generation_config.top_p,
@@ -127,24 +188,16 @@ class InferEngine(_infinilm.InferEngine):
             output_ids.append(output_id)

             if (
-                generation_config.max_new_tokens is not None
+                generation_config.stop_on_eos
+                and generation_config.max_new_tokens is not None
                 and output_id.to_numpy()[0] in eos_token_id
             ):
                 break

-            seq_len = position_ids.shape[-1]
             input_ids = infinicore.from_list(
                 [[output_id] for output_id in output_id.to_numpy().tolist()]
             )
-            position_ids = infinicore.from_list(
-                [1 for _ in range(batch_size)],
-                dtype=position_ids.dtype,
-                device=position_ids.device,
-            ).view((batch_size, 1)) + position_ids.narrow(1, seq_len - 1, 1)
-            cache_lengths += infinicore.from_list(
-                [seq_len], dtype=cache_lengths.dtype, device=cache_lengths.device
-            )
+            past_seq_len = past_seq_len + seq_len

             if _measure_and_log_time:
                 end_time = time.perf_counter()
@@ -156,23 +209,21 @@ class InferEngine(_infinilm.InferEngine):
                 f"\n\n\n Generation completed in {round(sum(time_measurements) * 1000, 2)} ms"
             )
             print(
-                f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_measurements)}\n"
+                f" Batchsize={initial_batch_size} Per_Batch_Input_Len={initial_seqlen} Per_Batch_New_Tokens={len(time_measurements)}\n"
             )
             print(
-                f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((batch_size * seq_len) / time_measurements[0], 2)}tok/s\n",
+                f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)}tok/s\n",
             )
             if len(time_measurements) > 1:
                 print(
-                    f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
+                    f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
                 )

         return output_ids

-    def reset_cache(self, batch_size: int, initial_capacity: int = 1024):
+    def reset_cache(self, cache_config):
         infinicore.sync_device()
+        self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
-        cache_config = StaticKVCacheConfig(batch_size, initial_capacity)
         super().reset_cache(cache_config)

     def state_dict_keyname(self):
...
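A tiny worked example of the block-table and slot-mapping layout built in the loop above, for `batch_size=2`, `block_size=16`, and a 4-token prefill. Logical block `j` of request `b` maps to physical block `j * batch_size + b`, and the token at position `p` lands in slot `physical_block * 16 + p % 16`:

```python
batch_size, past_seq_len, seq_len = 2, 0, 4

block_tables = [
    [j * batch_size + b for j in range((past_seq_len + seq_len + 15) // 16)]
    for b in range(batch_size)
]
slot_mapping = [
    ((past_seq_len + i) // 16 * batch_size + b) * 16 + (past_seq_len + i) % 16
    for b in range(batch_size)
    for i in range(seq_len)
]

assert block_tables == [[0], [1]]                    # one block per request so far
assert slot_mapping == [0, 1, 2, 3, 16, 17, 18, 19]  # request 0 -> block 0, request 1 -> block 1
```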