Commit 7668db4f authored by PanZezhong
Browse files

issue/248 support flash-attention lib

parent f67956fe
......@@ -34,7 +34,7 @@ void PagedCompiler::compile() {
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
{nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b;
......@@ -47,13 +47,14 @@ void PagedCompiler::compile() {
set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_offsets.value());
std::vector<int64_t> input_offsets_vec(b + 1, 0);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
std::vector<int32_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i;
}
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.cu_seqlens = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
infinicore::context::memcpyH2D(input.cu_seqlens.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value());
......@@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
......
......@@ -117,6 +117,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
to_device(past_sequence_lengths), // @todo: on device in the future
to_device(total_sequence_lengths),
to_device(input_offsets),
to_device(cu_seqlens),
to_device(block_tables),
to_device(slot_mapping),
};
......
......@@ -339,7 +339,7 @@ void RankWorker::thread_loop() {
const auto &batch_size{logits_shape[0]};
auto n_req = local_args.input_offsets.value()->size(0) - 1;
int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data();
int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data();
auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
......
......@@ -37,8 +37,10 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths;
/// Total lengths for each request sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
......
......@@ -27,6 +27,8 @@ public:
std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
......
......@@ -4,6 +4,7 @@
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"
#include <algorithm>
......@@ -238,6 +239,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
ASSERT(block_tables.has_value());
......@@ -297,32 +299,46 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
// 6. Compute attention
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) {
infinicore::op::paged_attention_prefill_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
input_offsets.value(),
std::nullopt,
scaling_);
} else {
infinicore::op::paged_attention_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
std::nullopt,
scaling_);
}
// if (is_prefill) {
// infinicore::op::paged_attention_prefill_(
// attn_output,
// q_reshaped,
// k_total,
// v_total,
// block_tables.value(),
// total_sequence_lengths.value(),
// input_offsets.value(),
// std::nullopt,
// scaling_);
// } else {
// infinicore::op::paged_attention_(
// attn_output,
// q_reshaped,
// k_total,
// v_total,
// block_tables.value(),
// total_sequence_lengths.value(),
// std::nullopt,
// scaling_);
// }
infinicore::op::mha_varlen_(
attn_output,
q_reshaped,
k_total->permute({0, 2, 1, 3}),
v_total->permute({0, 2, 1, 3}),
input_offsets.value(),
cu_seqlens.value(),
block_tables.value(),
max_position_embeddings_,
max_position_embeddings_,
std::nullopt,
scaling_);
// 7. Project output
attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
attn_output
= attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
return o_proj_->forward(attn_output);
}
......@@ -332,6 +348,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) {
......@@ -340,7 +357,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
} else {
output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
......
......@@ -73,6 +73,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......@@ -104,6 +105,7 @@ private:
std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......
......@@ -57,13 +57,15 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Attention layer normalization
input_layernorm_->forward_inplace(hidden_states, residual);
// 2. Self-attention
hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
hidden_states = self_attn_->forward(
hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 3. Post-attention layer normalization
post_attention_layernorm_->forward_inplace(hidden_states, residual);
......
......@@ -73,6 +73,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mappin) const;
......
......@@ -56,12 +56,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto past_sequence_lengths = input.past_sequence_lengths;
auto total_sequence_length = input.total_sequence_lengths;
auto input_offsets = input.input_offsets;
auto cu_seqlens = input.cu_seqlens;
auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states
auto hidden_states = model_->forward(
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping);
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
......
......@@ -92,6 +92,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
......@@ -109,6 +110,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
past_sequence_lengths,
total_sequence_lengths,
input_offsets,
cu_seqlens,
block_tables,
slot_mapping);
}
......
......@@ -73,6 +73,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......
......@@ -118,6 +118,7 @@ inline void bind_infer_engine(py::module &m) {
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) {
......@@ -127,6 +128,7 @@ inline void bind_infer_engine(py::module &m) {
std::move(past_sequence_lengths),
std::move(total_sequence_lengths),
std::move(input_offsets),
std::move(cu_seqlens),
std::move(block_tables),
std::move(slot_mapping),
};
......@@ -167,6 +169,7 @@ inline void bind_infer_engine(py::module &m) {
py::arg("past_sequence_lengths") = std::nullopt,
py::arg("total_sequence_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt,
py::arg("cu_seqlens") = std::nullopt,
py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids)
......@@ -174,6 +177,7 @@ inline void bind_infer_engine(py::module &m) {
.def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
.def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
.def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
.def_readwrite("temperature", &InferEngine::Input::temperature)
......
......@@ -57,6 +57,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=None,
total_kv_lengths=None,
input_offsets=None,
cu_seqlens=None,
block_tables=None,
slot_mapping=None,
temperature=None,
......@@ -74,6 +75,7 @@ class InferEngine(_infinilm.InferEngine):
)
input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None
cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
return infinicore.Tensor(
......@@ -85,6 +87,7 @@ class InferEngine(_infinilm.InferEngine):
past_sequence_lengths=past_kv_lengths,
total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables,
slot_mapping=slot_mapping,
temperature=temperature,
......@@ -135,7 +138,7 @@ class InferEngine(_infinilm.InferEngine):
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
dtype=infinicore.int32,
)
for iter in range(0, generation_config.max_new_tokens):
......@@ -193,9 +196,11 @@ class InferEngine(_infinilm.InferEngine):
total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
)
cu_seqlens = infinicore.from_list(
[(past_seq_len + seq_len) * i for i in range(batch_size + 1)], dtype=infinicore.int32
)
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
)
output_id = self(
......@@ -204,6 +209,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets,
cu_seqlens = cu_seqlens,
block_tables=block_tables,
slot_mapping=slot_mapping,
temperature=generation_config.temperature,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment