Commit 831e8a67 authored by PanZezhong's avatar PanZezhong
Browse files

issue/168 support fixed paged attention api

parent e48b5b0d
...@@ -80,12 +80,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> ...@@ -80,12 +80,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx, StaticKVCache::update(size_t layer_idx,
const infinicore::Tensor &k, const infinicore::Tensor &k,
const infinicore::Tensor &v, const infinicore::Tensor &v,
const infinicore::Tensor &cache_lengths) { const infinicore::Tensor &past_sequence_lengths) {
ASSERT(layer_idx < rank_num_layers_); ASSERT(layer_idx < rank_num_layers_);
auto batch_size = k->size(0); auto batch_size = k->size(0);
auto update_len = k->size(2); auto update_len = k->size(2);
size_t cache_pos = reinterpret_cast<int64_t *>(cache_lengths->to(infinicore::Device::cpu())->data())[0]; size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len; auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_); ASSERT(result_len <= cache_len_);
......
...@@ -61,7 +61,7 @@ public: ...@@ -61,7 +61,7 @@ public:
update(size_t layer_idx, update(size_t layer_idx,
const infinicore::Tensor &k, const infinicore::Tensor &k,
const infinicore::Tensor &v, const infinicore::Tensor &v,
const infinicore::Tensor &cache_lengths); const infinicore::Tensor &past_sequence_lengths);
~StaticKVCache() override = default; ~StaticKVCache() override = default;
......
...@@ -56,44 +56,23 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng ...@@ -56,44 +56,23 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------ //------------------------------------------------------
// forward // forward
//------------------------------------------------------ //------------------------------------------------------
infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const { infinilm::InfinilmModel::Input
InferEngine::Input::to_model_input(infinicore::Device device) const {
std::optional<infinicore::Tensor> position_ids_on_device; auto to_device = [&](const std::optional<infinicore::Tensor> &t)
if (position_ids.has_value()) { -> std::optional<infinicore::Tensor> {
position_ids_on_device = position_ids.value()->to(device); return t.has_value() ? t.value()->to(device) : t;
} };
std::optional<infinicore::Tensor> cache_lengths_on_device;
if (cache_lengths.has_value()) {
if (block_tables.has_value()) {
cache_lengths_on_device = cache_lengths.value()->to(device);
} else { // @todo: only paged kv cache support device tensor so far
cache_lengths_on_device = cache_lengths.value();
}
}
std::optional<infinicore::Tensor> input_offsets_on_device;
if (input_offsets.has_value()) {
input_offsets_on_device = input_offsets.value()->to(device);
}
std::optional<infinicore::Tensor> block_tables_on_device;
if (block_tables.has_value()) {
block_tables_on_device = block_tables.value()->to(device);
}
std::optional<infinicore::Tensor> slot_mapping_on_device;
if (slot_mapping.has_value()) {
slot_mapping_on_device = slot_mapping.value()->to(device);
}
return { return {
input_ids, // @todo: on device in the future input_ids, // @todo: on device in the future
position_ids_on_device, to_device(position_ids),
cache_lengths_on_device, past_sequence_lengths, // @todo: on device in the future
input_offsets_on_device, to_device(total_sequence_lengths),
block_tables_on_device, to_device(input_offsets),
slot_mapping_on_device}; to_device(block_tables),
to_device(slot_mapping),
};
} }
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) { InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
......
...@@ -29,7 +29,9 @@ public: ...@@ -29,7 +29,9 @@ public:
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids; std::optional<infinicore::Tensor> position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`. /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
std::optional<infinicore::Tensor> cache_lengths; std::optional<infinicore::Tensor> past_sequence_lengths;
/// Total Lengths for each request sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`. /// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets; std::optional<infinicore::Tensor> input_offsets;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
......
...@@ -23,7 +23,9 @@ public: ...@@ -23,7 +23,9 @@ public:
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std::optional<infinicore::Tensor> position_ids; std::optional<infinicore::Tensor> position_ids;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`. /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
std::optional<infinicore::Tensor> cache_lengths; std::optional<infinicore::Tensor> past_sequence_lengths;
/// Total Lengths for each request sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`. /// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets; std::optional<infinicore::Tensor> input_offsets;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
......
...@@ -57,7 +57,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -57,7 +57,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths) const { std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths) const {
// Input shape: [batch, seq_len, hidden_size] // Input shape: [batch, seq_len, hidden_size]
auto hidden_states_mutable = hidden_states; auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape(); auto shape = hidden_states->shape();
...@@ -105,7 +106,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -105,7 +106,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
k_total = k_permuted; k_total = k_permuted;
v_total = v_permuted; v_total = v_permuted;
} else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) { } else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value()); auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, past_sequence_lengths.value());
k_total = k_total_tmp; k_total = k_total_tmp;
v_total = v_total_tmp; v_total = v_total_tmp;
} else { } else {
...@@ -141,7 +142,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -141,7 +142,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache, std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
...@@ -157,7 +158,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -157,7 +158,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
// Only support batchsize==1, all requests should be flattened along seqlen dimension // Only support batchsize==1, all requests should be flattened along seqlen dimension
ASSERT_EQ(batch_size, 1); ASSERT_EQ(batch_size, 1);
// Decode only if total_len == num_requests // Decode only if total_len == num_requests
bool is_prefill = (seq_len != cache_lengths.value()->shape()[0]); bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
// 1. Project Q, K, V // 1. Project Q, K, V
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
...@@ -204,7 +205,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -204,7 +205,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
k_total, k_total,
v_total, v_total,
block_tables.value(), block_tables.value(),
cache_lengths.value(), total_sequence_lengths.value(),
input_offsets.value(), input_offsets.value(),
std::nullopt, std::nullopt,
scaling_); scaling_);
...@@ -216,7 +217,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -216,7 +217,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
k_total, k_total,
v_total, v_total,
block_tables.value(), block_tables.value(),
cache_lengths.value(), total_sequence_lengths.value(),
std::nullopt, std::nullopt,
scaling_); scaling_);
} }
...@@ -229,7 +230,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -229,7 +230,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<cache::Cache> kv_cache, std::shared_ptr<cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
...@@ -239,10 +241,10 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -239,10 +241,10 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
infinicore::Tensor output; infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) { if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, cache_lengths, input_offsets, block_tables, slot_mapping); output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
} else { } else {
output = forward_(hidden_states, position_ids, kv_cache, cache_lengths); output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
} }
return output; return output;
} }
......
...@@ -51,7 +51,8 @@ public: ...@@ -51,7 +51,8 @@ public:
infinicore::Tensor forward(const infinicore::Tensor &hidden_states, infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
...@@ -76,12 +77,13 @@ private: ...@@ -76,12 +77,13 @@ private:
infinicore::Tensor forward_(const infinicore::Tensor &hidden_states, infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths) const; std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths) const;
infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states, infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache, std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
......
...@@ -26,7 +26,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, ...@@ -26,7 +26,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
...@@ -37,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s ...@@ -37,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto normed_states = input_layernorm_->forward(hidden_states); auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection // 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_lengths, input_offsets, block_tables, slot_mapping); auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output // Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output); auto output = infinicore::op::add(residual, attn_output);
......
...@@ -49,7 +49,8 @@ public: ...@@ -49,7 +49,8 @@ public:
infinicore::Tensor forward(const infinicore::Tensor &hidden_states, infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mappin) const; std::optional<infinicore::Tensor> slot_mappin) const;
......
...@@ -28,13 +28,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, ...@@ -28,13 +28,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto input_ids = input.input_ids.value(); auto input_ids = input.input_ids.value();
auto position_ids = input.position_ids.value(); auto position_ids = input.position_ids.value();
auto cache_lengths = input.cache_lengths; auto past_sequence_lengths = input.past_sequence_lengths;
auto total_sequence_length = input.total_sequence_lengths;
auto input_offsets = input.input_offsets; auto input_offsets = input.input_offsets;
auto block_tables = input.block_tables; auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping; auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states // 1. Forward through base model to get hidden states
auto hidden_states = model_->forward(input_ids, position_ids, cache_lengths, input_offsets, block_tables, slot_mapping); auto hidden_states = model_->forward(
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits // 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states); auto logits = lm_head_->forward(hidden_states);
......
...@@ -45,7 +45,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -45,7 +45,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
...@@ -55,7 +56,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -55,7 +56,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers // 2. Process through all decoder layers
size_t num_layers = layers_.size(); size_t num_layers = layers_.size();
for (size_t i = 0; i < num_layers; ++i) { for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_offsets, block_tables, slot_mapping); hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
} }
return norm_->forward(hidden_states); return norm_->forward(hidden_states);
......
...@@ -48,13 +48,15 @@ public: ...@@ -48,13 +48,15 @@ public:
* @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batch is used, * @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batch is used,
* and tokens from all requests are concatenated along seq_len dimension. * and tokens from all requests are concatenated along seq_len dimension.
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param cache_lengths Cache positions tensor of shape [n_req] * @param past_sequence_lengths Cache positions tensor of shape [n_req]
* @param total_sequence_lengths Total sequence lengths tensor of shape [n_req]
* @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req + 1] * @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req + 1]
* @return Output tensor of shape [batch, seq_len, hidden_size] * @return Output tensor of shape [batch, seq_len, hidden_size]
*/ */
infinicore::Tensor forward(const infinicore::Tensor &input_ids, infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
......
...@@ -80,28 +80,48 @@ inline void bind_infer_engine(py::module &m) { ...@@ -80,28 +80,48 @@ inline void bind_infer_engine(py::module &m) {
py::init([]( py::init([](
std::optional<infinicore::Tensor> input_ids, std::optional<infinicore::Tensor> input_ids,
std::optional<infinicore::Tensor> position_ids, std::optional<infinicore::Tensor> position_ids,
std::optional<infinicore::Tensor> cache_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping, std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) { py::kwargs kwargs) {
auto input{InferEngine::Input{ InferEngine::Input input{
std::move(input_ids), std::move(input_ids),
std::move(position_ids), std::move(position_ids),
std::move(cache_lengths), std::move(past_sequence_lengths),
std::move(total_sequence_lengths),
std::move(input_offsets), std::move(input_offsets),
std::move(block_tables), std::move(block_tables),
std::move(slot_mapping)}}; std::move(slot_mapping),
};
if (kwargs) { // Explicit defaults
if (kwargs.contains("temperature")) { input.temperature = 1.0f;
input.temperature = kwargs["temperature"].cast<float>(); input.top_p = 1.0f;
} input.top_k = 1;
if (kwargs.contains("top_k")) {
input.top_k = kwargs["top_k"].cast<int>(); // Allowed keyword arguments
static const std::unordered_set<std::string> allowed_kwargs = {
"temperature",
"top_p",
"top_k",
};
for (auto &item : kwargs) {
const std::string key = py::cast<std::string>(item.first);
if (allowed_kwargs.find(key) == allowed_kwargs.end()) {
throw py::value_error(
"InferEngine.Input got an unexpected keyword argument '" + key + "'");
} }
if (kwargs.contains("top_p")) {
input.top_p = kwargs["top_p"].cast<float>(); if (key == "temperature") {
input.temperature = py::cast<float>(item.second);
} else if (key == "top_p") {
input.top_p = py::cast<float>(item.second);
} else if (key == "top_k") {
input.top_k = py::cast<int>(item.second);
} }
} }
...@@ -109,16 +129,21 @@ inline void bind_infer_engine(py::module &m) { ...@@ -109,16 +129,21 @@ inline void bind_infer_engine(py::module &m) {
}), }),
py::arg("input_ids") = std::nullopt, py::arg("input_ids") = std::nullopt,
py::arg("position_ids") = std::nullopt, py::arg("position_ids") = std::nullopt,
py::arg("cache_lengths") = std::nullopt, py::arg("past_sequence_lengths") = std::nullopt,
py::arg("total_sequence_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt, py::arg("input_offsets") = std::nullopt,
py::arg("block_tables") = std::nullopt, py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt) py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids) .def_readwrite("input_ids", &InferEngine::Input::input_ids)
.def_readwrite("position_ids", &InferEngine::Input::position_ids) .def_readwrite("position_ids", &InferEngine::Input::position_ids)
.def_readwrite("cache_lengths", &InferEngine::Input::cache_lengths) .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
.def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets) .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("block_tables", &InferEngine::Input::block_tables) .def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping); .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
.def_readwrite("temperature", &InferEngine::Input::temperature)
.def_readwrite("top_k", &InferEngine::Input::top_k)
.def_readwrite("top_p", &InferEngine::Input::top_p);
py::class_<InferEngine::Output>(infer_engine, "Output") py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor"); .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
......
...@@ -53,7 +53,8 @@ class InferEngine(_infinilm.InferEngine): ...@@ -53,7 +53,8 @@ class InferEngine(_infinilm.InferEngine):
input_ids, input_ids,
*, *,
position_ids=None, position_ids=None,
cache_lengths=None, past_kv_lengths=None,
total_kv_lengths=None,
input_offsets=None, input_offsets=None,
block_tables=None, block_tables=None,
slot_mapping=None, slot_mapping=None,
...@@ -64,7 +65,12 @@ class InferEngine(_infinilm.InferEngine): ...@@ -64,7 +65,12 @@ class InferEngine(_infinilm.InferEngine):
# TODO: Remove `_underlying` and simplify the corresponding code. # TODO: Remove `_underlying` and simplify the corresponding code.
input_ids = input_ids._underlying if input_ids is not None else None input_ids = input_ids._underlying if input_ids is not None else None
position_ids = position_ids._underlying if position_ids is not None else None position_ids = position_ids._underlying if position_ids is not None else None
cache_lengths = cache_lengths._underlying if cache_lengths is not None else None past_kv_lengths = (
past_kv_lengths._underlying if past_kv_lengths is not None else None
)
# Unwrap total_kv_lengths based on its own None-ness. Guarding on
# past_kv_lengths would raise AttributeError when only past_kv_lengths is
# supplied, and would silently discard total_kv_lengths when only it is supplied.
total_kv_lengths = (
total_kv_lengths._underlying if total_kv_lengths is not None else None
)
input_offsets = input_offsets._underlying if input_offsets is not None else None input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None block_tables = block_tables._underlying if block_tables is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
...@@ -75,7 +81,8 @@ class InferEngine(_infinilm.InferEngine): ...@@ -75,7 +81,8 @@ class InferEngine(_infinilm.InferEngine):
super().Input( super().Input(
input_ids, input_ids,
position_ids=position_ids, position_ids=position_ids,
cache_lengths=cache_lengths, past_sequence_lengths=past_kv_lengths,
total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
...@@ -87,7 +94,14 @@ class InferEngine(_infinilm.InferEngine): ...@@ -87,7 +94,14 @@ class InferEngine(_infinilm.InferEngine):
.output_ids .output_ids
) )
def generate(self, input_ids, generation_config, *, _measure_and_log_time=False): def generate(
self,
input_ids,
generation_config,
*,
_measure_and_log_time=False,
paged_block_size=16,
):
if generation_config.eos_token_id is None: if generation_config.eos_token_id is None:
eos_token_id = self.config.eos_token_id eos_token_id = self.config.eos_token_id
else: else:
...@@ -119,31 +133,30 @@ class InferEngine(_infinilm.InferEngine): ...@@ -119,31 +133,30 @@ class InferEngine(_infinilm.InferEngine):
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size, list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
dtype=infinicore.int64, dtype=infinicore.int64,
) )
cache_lengths = infinicore.from_list( block_tables_list = [
[past_seq_len] * batch_size, dtype=infinicore.int64 [
) i * batch_size + b
for i in range(
(past_seq_len + seq_len + paged_block_size - 1)
// paged_block_size
)
]
for b in range(batch_size)
]
slot_mapping_list = [
(((past_seq_len + i) // paged_block_size) * batch_size + b)
* paged_block_size
+ (past_seq_len + i) % paged_block_size
for b in range(batch_size)
for i in range(seq_len)
]
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
)
block_tables = infinicore.from_list( block_tables = infinicore.from_list(
[ block_tables_list,
[
i * batch_size + b
for i in range((past_seq_len + seq_len + 15) // 16)
]
for b in range(batch_size)
],
dtype=infinicore.int64, dtype=infinicore.int64,
) )
slot_mapping = infinicore.from_list( slot_mapping = infinicore.from_list(
[ slot_mapping_list,
((past_seq_len + i + 15) // 16) * batch_size
+ b
+ (past_seq_len + i + 15) % 16
for i in range(seq_len)
for b in range(batch_size)
],
dtype=infinicore.int64, dtype=infinicore.int64,
) )
else: else:
...@@ -155,21 +168,25 @@ class InferEngine(_infinilm.InferEngine): ...@@ -155,21 +168,25 @@ class InferEngine(_infinilm.InferEngine):
dtype=infinicore.int64, dtype=infinicore.int64,
) )
cache_lengths = infinicore.from_list(
[past_seq_len], dtype=infinicore.int64
)
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
)
block_tables = None block_tables = None
slot_mapping = None slot_mapping = None
past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64
)
total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
)
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
)
output_id = self( output_id = self(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cache_lengths=cache_lengths, past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment