Commit 7668db4f authored by PanZezhong's avatar PanZezhong
Browse files

issue/248 support flash-attention lib

parent f67956fe
...@@ -34,7 +34,7 @@ void PagedCompiler::compile() { ...@@ -34,7 +34,7 @@ void PagedCompiler::compile() {
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end()); size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear(); compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty( block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice()); {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(block_tables_holder_); set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) { for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b; size_t block_per_req = nblocks / b;
...@@ -47,13 +47,14 @@ void PagedCompiler::compile() { ...@@ -47,13 +47,14 @@ void PagedCompiler::compile() {
set_zeros(input.total_sequence_lengths.value()); set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1); std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice()); input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_offsets.value()); std::vector<int32_t> input_offsets_vec(b + 1, 0);
std::vector<int64_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) { for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i; input_offsets_vec[i] = i;
} }
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false); infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.cu_seqlens = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
infinicore::context::memcpyH2D(input.cu_seqlens.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1}); input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value()); set_zeros(input.slot_mapping.value());
...@@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input & ...@@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &
graph_input.position_ids.value()->copy_from(input.position_ids.value()); graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
......
...@@ -117,6 +117,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { ...@@ -117,6 +117,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
to_device(past_sequence_lengths), // @todo: on device in the future to_device(past_sequence_lengths), // @todo: on device in the future
to_device(total_sequence_lengths), to_device(total_sequence_lengths),
to_device(input_offsets), to_device(input_offsets),
to_device(cu_seqlens),
to_device(block_tables), to_device(block_tables),
to_device(slot_mapping), to_device(slot_mapping),
}; };
......
...@@ -339,7 +339,7 @@ void RankWorker::thread_loop() { ...@@ -339,7 +339,7 @@ void RankWorker::thread_loop() {
const auto &batch_size{logits_shape[0]}; const auto &batch_size{logits_shape[0]};
auto n_req = local_args.input_offsets.value()->size(0) - 1; auto n_req = local_args.input_offsets.value()->size(0) - 1;
int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data(); int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data();
auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)}; auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
......
...@@ -37,8 +37,10 @@ public: ...@@ -37,8 +37,10 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths; std::optional<infinicore::Tensor> past_sequence_lengths;
/// Total lengths for each request sequence, of shape `[num_requests]`. /// Total lengths for each request sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> total_sequence_lengths; std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`. /// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets; std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables; std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache. /// Slot ids for each token `[seq]`. Used for paged cache.
......
...@@ -27,6 +27,8 @@ public: ...@@ -27,6 +27,8 @@ public:
std::optional<infinicore::Tensor> total_sequence_lengths; std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`. /// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets; std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables; std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache. /// Slot ids for each token `[seq]`. Used for paged cache.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp" #include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp" #include "infinicore/ops/mul.hpp"
#include <algorithm> #include <algorithm>
...@@ -238,6 +239,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -238,6 +239,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache, std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
ASSERT(block_tables.has_value()); ASSERT(block_tables.has_value());
...@@ -297,32 +299,46 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -297,32 +299,46 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
// 6. Compute attention // 6. Compute attention
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device()); infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) { // if (is_prefill) {
infinicore::op::paged_attention_prefill_( // infinicore::op::paged_attention_prefill_(
// attn_output,
// q_reshaped,
// k_total,
// v_total,
// block_tables.value(),
// total_sequence_lengths.value(),
// input_offsets.value(),
// std::nullopt,
// scaling_);
// } else {
// infinicore::op::paged_attention_(
// attn_output,
// q_reshaped,
// k_total,
// v_total,
// block_tables.value(),
// total_sequence_lengths.value(),
// std::nullopt,
// scaling_);
// }
infinicore::op::mha_varlen_(
attn_output, attn_output,
q_reshaped, q_reshaped,
k_total, k_total->permute({0, 2, 1, 3}),
v_total, v_total->permute({0, 2, 1, 3}),
block_tables.value(),
total_sequence_lengths.value(),
input_offsets.value(), input_offsets.value(),
std::nullopt, cu_seqlens.value(),
scaling_);
} else {
infinicore::op::paged_attention_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(), block_tables.value(),
total_sequence_lengths.value(), max_position_embeddings_,
max_position_embeddings_,
std::nullopt, std::nullopt,
scaling_); scaling_);
}
// 7. Project output // 7. Project output
attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_}); attn_output
= attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
return o_proj_->forward(attn_output); return o_proj_->forward(attn_output);
} }
...@@ -332,6 +348,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -332,6 +348,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) { if (!rotary_emb_) {
...@@ -340,7 +357,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -340,7 +357,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
infinicore::Tensor output; infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) { if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping); output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
} else { } else {
output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths); output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
......
...@@ -73,6 +73,7 @@ public: ...@@ -73,6 +73,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
...@@ -104,6 +105,7 @@ private: ...@@ -104,6 +105,7 @@ private:
std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache, std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
......
...@@ -57,13 +57,15 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states, ...@@ -57,13 +57,15 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Attention layer normalization // 1. Attention layer normalization
input_layernorm_->forward_inplace(hidden_states, residual); input_layernorm_->forward_inplace(hidden_states, residual);
// 2. Self-attention // 2. Self-attention
hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); hidden_states = self_attn_->forward(
hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 3. Post-attention layer normalization // 3. Post-attention layer normalization
post_attention_layernorm_->forward_inplace(hidden_states, residual); post_attention_layernorm_->forward_inplace(hidden_states, residual);
......
...@@ -73,6 +73,7 @@ public: ...@@ -73,6 +73,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
......
...@@ -56,12 +56,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { ...@@ -56,12 +56,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto past_sequence_lengths = input.past_sequence_lengths; auto past_sequence_lengths = input.past_sequence_lengths;
auto total_sequence_length = input.total_sequence_lengths; auto total_sequence_length = input.total_sequence_lengths;
auto input_offsets = input.input_offsets; auto input_offsets = input.input_offsets;
auto cu_seqlens = input.cu_seqlens;
auto block_tables = input.block_tables; auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping; auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states // 1. Forward through base model to get hidden states
auto hidden_states = model_->forward( auto hidden_states = model_->forward(
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping); input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits // 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states); auto logits = lm_head_->forward(hidden_states);
......
...@@ -92,6 +92,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -92,6 +92,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
...@@ -109,6 +110,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -109,6 +110,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
past_sequence_lengths, past_sequence_lengths,
total_sequence_lengths, total_sequence_lengths,
input_offsets, input_offsets,
cu_seqlens,
block_tables, block_tables,
slot_mapping); slot_mapping);
} }
......
...@@ -73,6 +73,7 @@ public: ...@@ -73,6 +73,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
......
...@@ -118,6 +118,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -118,6 +118,7 @@ inline void bind_infer_engine(py::module &m) {
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping, std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) { py::kwargs kwargs) {
...@@ -127,6 +128,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -127,6 +128,7 @@ inline void bind_infer_engine(py::module &m) {
std::move(past_sequence_lengths), std::move(past_sequence_lengths),
std::move(total_sequence_lengths), std::move(total_sequence_lengths),
std::move(input_offsets), std::move(input_offsets),
std::move(cu_seqlens),
std::move(block_tables), std::move(block_tables),
std::move(slot_mapping), std::move(slot_mapping),
}; };
...@@ -167,6 +169,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -167,6 +169,7 @@ inline void bind_infer_engine(py::module &m) {
py::arg("past_sequence_lengths") = std::nullopt, py::arg("past_sequence_lengths") = std::nullopt,
py::arg("total_sequence_lengths") = std::nullopt, py::arg("total_sequence_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt, py::arg("input_offsets") = std::nullopt,
py::arg("cu_seqlens") = std::nullopt,
py::arg("block_tables") = std::nullopt, py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt) py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids) .def_readwrite("input_ids", &InferEngine::Input::input_ids)
...@@ -174,6 +177,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -174,6 +177,7 @@ inline void bind_infer_engine(py::module &m) {
.def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths) .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
.def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths) .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets) .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
.def_readwrite("block_tables", &InferEngine::Input::block_tables) .def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping) .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
.def_readwrite("temperature", &InferEngine::Input::temperature) .def_readwrite("temperature", &InferEngine::Input::temperature)
......
...@@ -57,6 +57,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -57,6 +57,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=None, past_kv_lengths=None,
total_kv_lengths=None, total_kv_lengths=None,
input_offsets=None, input_offsets=None,
cu_seqlens=None,
block_tables=None, block_tables=None,
slot_mapping=None, slot_mapping=None,
temperature=None, temperature=None,
...@@ -74,6 +75,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -74,6 +75,7 @@ class InferEngine(_infinilm.InferEngine):
) )
input_offsets = input_offsets._underlying if input_offsets is not None else None input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None block_tables = block_tables._underlying if block_tables is not None else None
cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
return infinicore.Tensor( return infinicore.Tensor(
...@@ -85,6 +87,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -85,6 +87,7 @@ class InferEngine(_infinilm.InferEngine):
past_sequence_lengths=past_kv_lengths, past_sequence_lengths=past_kv_lengths,
total_sequence_lengths=total_kv_lengths, total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
temperature=temperature, temperature=temperature,
...@@ -135,7 +138,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -135,7 +138,7 @@ class InferEngine(_infinilm.InferEngine):
] ]
block_tables = infinicore.from_list( block_tables = infinicore.from_list(
block_tables_list, block_tables_list,
dtype=infinicore.int64, dtype=infinicore.int32,
) )
for iter in range(0, generation_config.max_new_tokens): for iter in range(0, generation_config.max_new_tokens):
...@@ -193,9 +196,11 @@ class InferEngine(_infinilm.InferEngine): ...@@ -193,9 +196,11 @@ class InferEngine(_infinilm.InferEngine):
total_kv_lengths = infinicore.from_list( total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64 [past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
) )
cu_seqlens = infinicore.from_list(
[(past_seq_len + seq_len) * i for i in range(batch_size + 1)], dtype=infinicore.int32
)
input_offsets = infinicore.from_list( input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64 [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
) )
output_id = self( output_id = self(
...@@ -204,6 +209,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -204,6 +209,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=past_kv_lengths, past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths, total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
temperature=generation_config.temperature, temperature=generation_config.temperature,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment