Unverified Commit c1a3ab29 authored by Haojie Wang, committed by GitHub

Merge pull request #173 from InfiniTensor/issue/168

Issue/168 Integrate the paged attention interface into InfiniLM
parents 96e53dbb 09ab8fa4
#include "kv_cache.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <stdexcept>
namespace infinilm::cache {
......@@ -80,12 +80,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
-const infinicore::Tensor &cache_lengths) {
+const infinicore::Tensor &past_sequence_lengths) {
ASSERT(layer_idx < rank_num_layers_);
auto batch_size = k->size(0);
auto update_len = k->size(2);
-size_t cache_pos = reinterpret_cast<int64_t *>(cache_lengths->to(infinicore::Device::cpu())->data())[0];
+size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
......@@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx,
// PagedKVCacheConfig
// ==========================
PagedKVCacheConfig::PagedKVCacheConfig(
-size_t max_kv_memory_bytes,
+size_t num_blocks,
size_t block_size)
-: max_kv_memory_bytes_(max_kv_memory_bytes),
+: num_blocks_(num_blocks),
block_size_(block_size) {
}
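// Sizing sketch (assumption, mirroring the byte-budget formula this change
// removes from the PagedKVCache constructor): callers that previously passed
// max_kv_memory_bytes can derive a per-layer block count as
//   num_blocks ~ max_kv_memory_bytes
//                / (k_dim * num_rank_k_heads + v_dim * num_rank_v_heads)
//                / block_size / dsize(dtype)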
......@@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const {
}
size_t
-PagedKVCacheConfig::max_kv_memory_bytes() const {
-return max_kv_memory_bytes_;
+PagedKVCacheConfig::num_blocks() const {
+return num_blocks_;
}
size_t
......@@ -151,15 +151,8 @@ PagedKVCache::PagedKVCache(
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_num_layers_(num_layers),
dtype_(dtype),
+num_blocks_per_layer_(config.num_blocks()),
block_size_(config.block_size()) {
-num_blocks_per_layer_ = config.max_kv_memory_bytes()
-/ (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
-/ block_size_
-/ infinicore::dsize(dtype_);
-if (num_blocks_per_layer_ == 0) {
-throw std::runtime_error("Not enough memory for KV cache");
-}
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_,
......@@ -187,11 +180,79 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping) {
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
+infinicore::op::paged_caching_(
+k_cache_layer,
+v_cache_layer,
+k,
+v,
+slot_mapping);
return {k_cache_layer, v_cache_layer};
}
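// Layout sketch (assumption, consistent with the block-indexed gathering in
// get_contiguous_kv below): slot_mapping[t] is the flat slot id of token t,
// i.e. block_id * block_size_ + offset_in_block, so paged_caching_ can
// scatter each new k/v row into the pools without contiguity across tokens.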
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_paged_kv(size_t layer_idx) {
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
return {k_cache_layer, v_cache_layer};
}
-/// @todo: implement paged cache update here
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_contiguous_kv(
size_t layer_idx,
const infinicore::Tensor block_tables,
const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets,
size_t request_id) {
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);
-return {k_cache_layer, v_cache_layer};
auto nreq = block_tables->size(0);
auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
infinicore::context::syncDevice();
// [num_blocks, num_rank_v_heads, block_size, v_dim]
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
auto req = request_id;
auto cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
auto input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
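// Worked example (hypothetical values): cache_lens = {8} and
// input_offsets = {0, 3} give total_len = 8 + 3 = 11 for request 0.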
auto full_k = infinicore::Tensor::empty(
{num_rank_k_heads_, (size_t)total_len, k_dim_},
k_cache_layer->dtype(), k_cache_layer->device());
auto full_v = infinicore::Tensor::empty(
{num_rank_v_heads_, (size_t)total_len, v_dim_},
v_cache_layer->dtype(), v_cache_layer->device());
size_t nblocks = total_len / block_size_;
size_t r = total_len % block_size_;
for (size_t b = 0; b < nblocks; b++) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
full_k->narrow({{1, b * block_size_, block_size_}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
full_v->narrow({{1, b * block_size_, block_size_}})
->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
}
if (r > 0) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
full_k->narrow({{1, nblocks * block_size_, r}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
full_v->narrow({{1, nblocks * block_size_, r}})
->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
}
return {full_k, full_v};
}
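// Minimal usage sketch (hypothetical caller, not part of this change),
// assuming `cache` is a constructed PagedKVCache and the index tensors follow
// the shapes documented in the header:
//
//   auto [k, v] = cache.get_contiguous_kv(
//       /*layer_idx=*/0, block_tables, cache_lens, input_offsets,
//       /*request_id=*/0);
//   // k: [num_rank_k_heads, total_len, k_dim]
//   // v: [num_rank_v_heads, total_len, v_dim]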
} // namespace infinilm::cache
......@@ -61,7 +61,7 @@ public:
update(size_t layer_idx,
const infinicore::Tensor &k,
const infinicore::Tensor &v,
-const infinicore::Tensor &cache_lengths);
+const infinicore::Tensor &past_sequence_lengths);
~StaticKVCache() override = default;
......@@ -85,15 +85,15 @@ private:
class PagedKVCacheConfig final : public CacheConfig {
public:
PagedKVCacheConfig(
-size_t max_kv_memory_bytes,
+size_t num_blocks,
size_t block_size = 16);
std::unique_ptr<CacheConfig> unique_copy() const override;
-size_t max_kv_memory_bytes() const;
+size_t num_blocks() const;
size_t block_size() const;
private:
-size_t max_kv_memory_bytes_;
+size_t num_blocks_;
size_t block_size_;
};
......@@ -113,7 +113,7 @@ public:
/**
* @brief Update Paged KV cache at a given layer given slot info for each token.
*
-* @param layer_idx Which transformer layer
+* @param layer_idx Which paged attention layer
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param slot_mapping [seq_len]
......@@ -128,7 +128,41 @@ public:
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping);
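/// Example (hypothetical values): with block_size 16, a request whose 20
/// cached tokens occupy blocks {3, 7} writes its next token at offset 4 of
/// block 7, so slot_mapping = [7 * 16 + 4] = [116].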
-~PagedKVCache() override = default;
/**
* @brief Get Paged KV cache at a given layer.
*
* @param layer_idx Which paged attention layer
*
* @return (full_k, full_v)
* full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
* full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
*/
std::tuple<infinicore::Tensor, infinicore::Tensor>
get_paged_kv(size_t layer_idx);
/**
* @brief Get contiguous KV cache at a given layer, given the request info
* among a continuous request batch.
*
* @param layer_idx Which paged attention layer
* @param block_tables [num_requests, max_blocks_per_request]
* @param cache_lens [num_requests]
* @param input_offsets [num_requests + 1]
* @param request_id Which request among a continuous batch of requests
*
* @return (full_k, full_v)
* full_k: [num_rank_k_heads, total_len, k_dim]
* full_v: [num_rank_v_heads, total_len, v_dim]
*/
std::tuple<infinicore::Tensor, infinicore::Tensor>
get_contiguous_kv(size_t layer_idx,
const infinicore::Tensor block_tables,
const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets,
size_t request_id = 0);
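/// Example (hypothetical values, assuming rows of block_tables are padded):
/// two requests with block_size 16, where request 0 has 20 cached tokens in
/// blocks {0, 1} plus 3 new input tokens and request 1 has 5 cached tokens in
/// block {2} plus 1 new token:
///   block_tables  = [[0, 1], [2, <pad>]]
///   cache_lens    = [20, 5]
///   input_offsets = [0, 3, 4]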
~PagedKVCache() override = default;
private:
infinicore::Size k_dim_;
......
......@@ -56,8 +56,23 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------
// forward
//------------------------------------------------------
-infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
-return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping};
infinilm::InfinilmModel::Input
InferEngine::Input::to_model_input(infinicore::Device device) const {
auto to_device = [&](const std::optional<infinicore::Tensor> &t)
-> std::optional<infinicore::Tensor> {
return t.has_value() ? t.value()->to(device) : t;
};
return {
input_ids, // @todo: on device in the future
to_device(position_ids),
past_sequence_lengths, // @todo: on device in the future
to_device(total_sequence_lengths),
to_device(input_offsets),
to_device(block_tables),
to_device(slot_mapping),
};
}
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
......
......@@ -188,7 +188,7 @@ void RankWorker::thread_loop() {
Command local_cmd = Command::INIT;
std::string local_param_name;
infinicore::Tensor local_param;
-InfinilmModel::Input local_args;
+Input local_args;
std::unique_ptr<cache::CacheConfig> local_cache_config;
// Wait for a job or exit
......@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
local_param_name = pending_param_name_;
local_param = pending_param_;
} else if (local_cmd == Command::RUN) {
-local_args = pending_args_.to_model_input();
+local_args = pending_args_;
} else if (local_cmd == Command::RESET_CACHE) {
if (pending_cache_config_ != nullptr) {
local_cache_config = pending_cache_config_->unique_copy();
......@@ -244,23 +244,28 @@ void RankWorker::thread_loop() {
{
std::lock_guard<std::mutex> lk(mutex_);
-auto logits{model_->forward(local_args).logits};
+auto model_args = local_args.to_model_input(rank_info_.device);
+// Forward calculation
+auto logits{model_->forward(model_args).logits};
// Random sampling (rank 0 only)
if (rank_info_.tp_rank == 0) {
// Perform random sampling.
-auto temperature{pending_args_.temperature};
-auto top_p{pending_args_.top_p};
-auto top_k{pending_args_.top_k};
-auto random_val{pending_args_.random_val};
+auto temperature{local_args.temperature};
+auto top_p{local_args.top_p};
+auto top_k{local_args.top_k};
+auto random_val{local_args.random_val};
const auto &logits_shape{logits->shape()};
-const auto &batch_size{logits_shape[0]};
const auto &vocab_size{logits_shape[2]};
+const auto &total_len{logits_shape[1]};
+const auto &batch_size{logits_shape[0]};
-auto output_ids{infinicore::Tensor::empty({batch_size}, infinicore::DataType::I32, rank_info_.device)};
+auto n_req = local_args.input_offsets.value()->size(0) - 1;
+int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data();
+auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
-for (auto i{decltype(batch_size)(0)}; i < batch_size; ++i) {
-auto score{logits->narrow({{0, i, 1}})->view({vocab_size})};
+for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
+auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})};
auto out{output_ids->narrow({{0, i, 1}})->view({})};
infinicore::op::random_sample_(
out, score, random_val, top_p, top_k, temperature);
......
......@@ -29,9 +29,9 @@ public:
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
-std::optional<infinicore::Tensor> cache_lengths;
-/// Input Lengths of each request in a continous-batched sequence, of shape `[num_requests]`.
-std::optional<infinicore::Tensor> input_lengths;
+std::optional<infinicore::Tensor> past_sequence_lengths;
+/// Total lengths for each request sequence, of shape `[num_requests]`.
+std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
......@@ -47,7 +47,7 @@ public:
float random_val{0.1};
-infinilm::InfinilmModel::Input to_model_input() const;
+infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const;
};
struct Output {
......
......@@ -23,10 +23,10 @@ public:
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
-std::optional<infinicore::Tensor> cache_lengths;
-/// Input Lengths of each request in a continous-batched sequence, of shape `[num_requests]`.
-std::optional<infinicore::Tensor> input_lengths;
-/// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`.
+std::optional<infinicore::Tensor> past_sequence_lengths;
+/// Total lengths for each request sequence, of shape `[num_requests]`.
+std::optional<infinicore::Tensor> total_sequence_lengths;
+/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
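/// e.g. three requests with input lengths {4, 1, 1} give
/// input_offsets = {0, 4, 5, 6}.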
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
......
#include "llama_attention.hpp"
#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
......@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
} else {
throw std::runtime_error("num_attention_heads / tp_size error.");
}
scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));
// Initialize projection layers
INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
......@@ -52,17 +54,11 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
dtype, device, tp_rank, tp_size, rank_info.comm);
}
-infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
-const infinicore::Tensor &position_ids,
-std::shared_ptr<cache::Cache> kv_cache,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
-std::optional<infinicore::Tensor> input_offsets,
-std::optional<infinicore::Tensor> block_tables,
-std::optional<infinicore::Tensor> slot_mapping) const {
-if (!rotary_emb_) {
-throw std::runtime_error("LlamaAttention: rotary_emb not configured");
-}
+infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
+const infinicore::Tensor &position_ids,
+std::shared_ptr<infinilm::cache::Cache> kv_cache,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths) const {
// Input shape: [batch, seq_len, hidden_size]
auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape();
......@@ -73,7 +69,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
......@@ -111,16 +106,9 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
k_total = k_permuted;
v_total = v_permuted;
} else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
-auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
+auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, past_sequence_lengths.value());
k_total = k_total_tmp;
v_total = v_total_tmp;
-} else if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
-auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value());
-k_total = k_total_tmp;
-v_total = v_total_tmp;
-/// @todo Implement paged attention here.
-throw std::runtime_error("LlamaAttention: Paged attention not implemented");
} else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
}
......@@ -134,8 +122,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
-float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
-auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len]
+auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
......@@ -152,6 +139,116 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
return output;
}
infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
ASSERT(block_tables.has_value());
ASSERT(slot_mapping.has_value());
// Input shape: [batch, seq_len, hidden_size]
auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape();
size_t batch_size = shape[0];
size_t seq_len = shape[1];
// Only batch_size == 1 is supported; all requests are flattened along the seq_len dimension
ASSERT_EQ(batch_size, 1);
// Decode phase when seq_len == num_requests (one new token per request); otherwise prefill
bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
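// Example (hypothetical): 3 requests decoding one token each give
// seq_len == 3 == num_requests, so the decode kernel runs; if request 0
// instead prefills 10 tokens, seq_len == 12 != 3 and the prefill path is used.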
// 1. Project Q, K, V
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
// 3. Prepare position_ids for RoPE - align with Python pattern
auto pos_shape = position_ids->shape();
infinicore::Tensor pos_ids_for_rope = position_ids;
if (pos_shape.size() == 2) {
auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
} else if (pos_shape.size() == 1) {
pos_ids_for_rope = position_ids;
} else {
throw std::runtime_error("Unexpected position_ids shape");
}
// 4. Apply RoPE to Q and K
rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_q_head, head_dim]
rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim]
// 5. Prepare KV caches
// Ensure contiguous after permute for F16 compatibility with cache operations
auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
k_reshaped,
v_reshaped,
slot_mapping.value());
// 6. Compute attention
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) {
infinicore::op::paged_attention_prefill_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
input_offsets.value(),
std::nullopt,
scaling_);
} else {
infinicore::op::paged_attention_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
std::nullopt,
scaling_);
}
// 7. Project output
attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
return o_proj_->forward(attn_output);
}
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) {
throw std::runtime_error("LlamaAttention: rotary_emb not configured");
}
infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
} else {
output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
}
return output;
}
void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
rotary_emb_ = rotary_emb;
}
......
......@@ -51,11 +51,11 @@ public:
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
-std::optional<infinicore::Tensor> slot_mappin) const;
+std::optional<infinicore::Tensor> slot_mapping) const;
/**
* @brief Get the layer index
......@@ -73,6 +73,21 @@ public:
size_t head_dim() const { return head_dim_; }
size_t hidden_size() const { return hidden_size_; }
private:
infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths) const;
infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
protected:
// Projection layers
INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
......@@ -93,6 +108,8 @@ private:
bool use_bias_; // Bias for Q/K/V projections
bool use_output_bias_; // Bias for output projection (o_proj)
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
float scaling_;
};
} // namespace infinilm::models::llama
......@@ -26,8 +26,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
......@@ -38,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection
-auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
......
......@@ -49,8 +49,8 @@ public:
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......
......@@ -28,15 +28,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto input_ids = input.input_ids.value();
auto position_ids = input.position_ids.value();
-auto cache_lengths = input.cache_lengths;
-auto input_lengths = input.input_lengths;
+auto past_sequence_lengths = input.past_sequence_lengths;
+auto total_sequence_lengths = input.total_sequence_lengths;
auto input_offsets = input.input_offsets;
auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states
-auto position_ids_device = position_ids->to(device_);
-auto hidden_states = model_->forward(input_ids, position_ids_device, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+auto hidden_states = model_->forward(
+input_ids, position_ids, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
......
......@@ -45,8 +45,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
......@@ -56,18 +56,10 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers
size_t num_layers = layers_.size();
for (size_t i = 0; i < num_layers; ++i) {
-hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
}
-// 3. Apply final layer normalization to last token only (aligns with transformers)
-// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
-auto shape = hidden_states->shape();
-size_t seq_len = shape[1];
-auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}});
-auto normalized_last_token = norm_->forward(last_token);
-return normalized_last_token;
+return norm_->forward(hidden_states);
}
void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
......
......@@ -48,15 +48,15 @@ public:
* @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batch is used,
* and tokens from all requests are concatenated along seq_len dimension.
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
-* @param cache_lengths Cache positions tensor of shape [n_req]
-* @param input_lengths Input lengths tensor in a continuous batch of shape [n_req]
-* @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req]
+* @param past_sequence_lengths Cache positions tensor of shape [n_req]
+* @param total_sequence_lengths Total sequence lengths tensor of shape [n_req]
+* @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req + 1]
* @return Output tensor of shape [batch, seq_len, hidden_size]
*/
infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......
......@@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) {
std::shared_ptr<infinilm::cache::PagedKVCacheConfig>>(m, "PagedKVCacheConfig")
.def(
py::init<size_t, size_t>(),
py::arg("max_kv_memory_bytes"),
py::arg("num_blocks"),
py::arg("block_size") = 16)
.def(
"max_kv_memory_bytes",
&infinilm::cache::PagedKVCacheConfig::max_kv_memory_bytes)
"num_blocks",
&infinilm::cache::PagedKVCacheConfig::num_blocks)
.def(
"block_size",
&infinilm::cache::PagedKVCacheConfig::block_size)
......
......@@ -80,28 +80,48 @@ inline void bind_infer_engine(py::module &m) {
py::init([](
std::optional<infinicore::Tensor> input_ids,
std::optional<infinicore::Tensor> position_ids,
-std::optional<infinicore::Tensor> cache_lengths,
-std::optional<infinicore::Tensor> input_lengths,
+std::optional<infinicore::Tensor> past_sequence_lengths,
+std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) {
-auto input{InferEngine::Input{
+InferEngine::Input input{
std::move(input_ids),
std::move(position_ids),
-std::move(cache_lengths),
+std::move(past_sequence_lengths),
+std::move(total_sequence_lengths),
std::move(input_offsets),
std::move(block_tables),
-std::move(slot_mapping)}};
+std::move(slot_mapping),
+};
if (kwargs) {
-if (kwargs.contains("temperature")) {
-input.temperature = kwargs["temperature"].cast<float>();
-}
-if (kwargs.contains("top_k")) {
-input.top_k = kwargs["top_k"].cast<int>();
-}
-if (kwargs.contains("top_p")) {
-input.top_p = kwargs["top_p"].cast<float>();
-}
+// Explicit defaults
+input.temperature = 1.0f;
+input.top_p = 1.0f;
+input.top_k = 1;
+// Allowed keyword arguments
+static const std::unordered_set<std::string> allowed_kwargs = {
+"temperature",
+"top_p",
+"top_k",
+};
+for (auto &item : kwargs) {
+const std::string key = py::cast<std::string>(item.first);
+if (allowed_kwargs.find(key) == allowed_kwargs.end()) {
+throw py::value_error(
+"InferEngine.Input got an unexpected keyword argument '" + key + "'");
+}
+if (key == "temperature") {
+input.temperature = py::cast<float>(item.second);
+} else if (key == "top_p") {
+input.top_p = py::cast<float>(item.second);
+} else if (key == "top_k") {
+input.top_k = py::cast<int>(item.second);
+}
+}
}
......@@ -109,18 +129,21 @@ inline void bind_infer_engine(py::module &m) {
}),
py::arg("input_ids") = std::nullopt,
py::arg("position_ids") = std::nullopt,
py::arg("cache_lengths") = std::nullopt,
py::arg("input_lengths") = std::nullopt,
py::arg("past_sequence_lengths") = std::nullopt,
py::arg("total_sequence_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt,
py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids)
.def_readwrite("position_ids", &InferEngine::Input::position_ids)
.def_readwrite("cache_lengths", &InferEngine::Input::cache_lengths)
.def_readwrite("input_lengths", &InferEngine::Input::input_lengths)
.def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
.def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping);
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
.def_readwrite("temperature", &InferEngine::Input::temperature)
.def_readwrite("top_k", &InferEngine::Input::top_k)
.def_readwrite("top_p", &InferEngine::Input::top_p);
py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
......
......@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig
import argparse
import sys
import time
......@@ -260,6 +261,7 @@ class TestModel:
output_ids = self.model.generate(
input_ids_infini,
GenerationConfig(max_new_tokens=output_len, eos_token_id=[]),
_measure_and_log_time=True,
)
t2 = time.time()
......@@ -336,7 +338,11 @@ if __name__ == "__main__":
# reset cache for each case
initial_capacity = input_len + output_len
-test.model.reset_cache(batch_size=batch_size, initial_capacity=initial_capacity)
+test.model.reset_cache(
+StaticKVCacheConfig(
+max_batch_size=batch_size, max_cache_len=initial_capacity
+)
+)
# run test one case
test.run(
......
......@@ -9,6 +9,7 @@ import sys
import time
import os
import numpy as np
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
......@@ -82,6 +83,11 @@ def get_args():
default=1,
help="total rank for tensor parallel",
)
parser.add_argument(
"--enable-paged-attn",
action="store_true",
help="use paged cache",
)
return parser.parse_args()
......@@ -92,10 +98,11 @@ def test(
max_new_tokens=100,
infini_device=infinicore.device("cpu", 0),
tp=1,
enable_paged_attn=False,
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
-# 创建模型,
+# Create Model
# ---------------------------------------------------------------------------- #
model = InferEngine(
model_path,
......@@ -104,12 +111,12 @@ def test(
)
# ---------------------------------------------------------------------------- #
-# 加载权重
+# Load Weights
# ---------------------------------------------------------------------------- #
load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)
# ---------------------------------------------------------------------------- #
-# 创建 tokenizer
+# create tokenizer
# ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
......@@ -132,7 +139,7 @@ def test(
)
# ---------------------------------------------------------------------------- #
-# token编码
+# tokenize
# ---------------------------------------------------------------------------- #
# prompt = "山东最高的山是?"
if isinstance(prompts, str):
......@@ -150,14 +157,26 @@ def test(
"input_ids"
] # List: [[1, 1128, 526, 366, 29892]]
-# 根据输入长度和最长输出长度创建KVCache
-model.reset_cache(
-1 if prompts is str else len(prompts),
-max_new_tokens + len(input_ids_list[0]),
-)
# ---------------------------------------------------------------------------- #
# Create KVCache
# ---------------------------------------------------------------------------- #
if enable_paged_attn:
batch_size = 1 if isinstance(prompts, str) else len(prompts)
max_total_tokens = max_new_tokens + len(input_ids_list[0])
cache_config = PagedKVCacheConfig(
num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16
)
else:
batch_size = 1 if isinstance(prompts, str) else len(prompts)
initial_capacity = max_new_tokens + len(input_ids_list[0])
cache_config = StaticKVCacheConfig(
max_batch_size=batch_size, max_cache_len=initial_capacity
)
model.reset_cache(cache_config)
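# Worked example (hypothetical numbers): input_len 12 with max_new_tokens 100
# gives max_total_tokens 112, so each request is granted 112 // 16 + 1 = 8
# blocks of 16 slots.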
# ---------------------------------------------------------------------------- #
-# 自回归生成
+# Generate
# ---------------------------------------------------------------------------- #
print(input_contents[0], end="", flush=True)
input_ids_infini = infinicore.from_list(input_ids_list)
......@@ -211,7 +230,7 @@ if __name__ == "__main__":
max_new_tokens = args.max_new_tokens
backend = args.backend
tp = args.tp
enable_paged_attn = args.enable_paged_attn
if backend != "cpp":
raise ValueError(f"Unsupported backend: {backend}.")
......@@ -223,4 +242,5 @@ if __name__ == "__main__":
max_new_tokens,
infini_device=infini_device,
tp=tp,
enable_paged_attn=enable_paged_attn,
)
......@@ -21,5 +21,7 @@ class AutoConfig:
if config_dict["model_type"] == "llama":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "qwen2":
return LlamaConfig(**config_dict)
raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
-from .cache import CacheConfig, StaticKVCacheConfig
+from .cache import CacheConfig, StaticKVCacheConfig, PagedKVCacheConfig
__all__ = ["CacheConfig", "StaticKVCacheConfig"]
__all__ = ["CacheConfig", "StaticKVCacheConfig", "PagedKVCacheConfig"]
......@@ -16,11 +16,11 @@ class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
def __init__(
self,
-max_kv_memory_bytes: int,
+num_blocks: int,
block_size: int = 16,
):
_infinilm.PagedKVCacheConfig.__init__(
self,
-max_kv_memory_bytes,
+num_blocks,
block_size,
)
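# Example (hypothetical sizing): serving up to 4 concurrent requests of at
# most 1024 tokens each with the default block_size of 16 needs
# (1024 // 16) * 4 = 256 blocks:
#
#     cfg = PagedKVCacheConfig(num_blocks=256, block_size=16)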