Commit 8297a0b7 authored by PanZezhong

issue/248 optimize: use flash-attn only in prefill

parent ae210024
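
For orientation, below is a minimal, self-contained sketch of the kernel selection this commit introduces in LlamaAttention::forward_paged_: flash attention (mha_varlen_) is now used only during prefill, while decode always uses the paged decode kernel. The enum and helper function here are hypothetical illustrations, not part of the InfiniLM/infinicore API; argument plumbing is omitted (see the second hunk below for the real calls).

// Toy model (not the real InfiniLM API): which attention kernel the new
// dispatch picks, given the phase and the configured backend. Mirrors the
// restructured if/else in LlamaAttention::forward_paged_ shown below.
#include <iostream>
#include <string>

enum class AttentionBackend { FlashAttn, Paged };  // hypothetical stand-in

// Returns the name of the kernel the updated code would invoke.
std::string select_attention_kernel(bool is_prefill, AttentionBackend backend) {
    if (is_prefill) {
        // Flash attention (mha_varlen_) is used only for the prefill phase.
        return backend == AttentionBackend::FlashAttn ? "mha_varlen_"
                                                      : "paged_attention_prefill_";
    }
    // Decode always uses the paged decode kernel, regardless of backend.
    return "paged_attention_";
}

int main() {
    std::cout << select_attention_kernel(true,  AttentionBackend::FlashAttn) << "\n"; // mha_varlen_
    std::cout << select_attention_kernel(true,  AttentionBackend::Paged)     << "\n"; // paged_attention_prefill_
    std::cout << select_attention_kernel(false, AttentionBackend::FlashAttn) << "\n"; // paged_attention_
    return 0;
}
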
@@ -41,12 +41,12 @@ void PagedCompiler::compile() {
     InfinilmModel::Input input;
     input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
     input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
-    input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
+    input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice());
     set_zeros(input.input_ids.value());
     set_zeros(input.position_ids.value());
     set_zeros(input.total_sequence_lengths.value());
-    std::vector<int64_t> total_sequence_lengths_vec(b, 1);
-    infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
+    std::vector<int32_t> total_sequence_lengths_vec(b, 1);
+    infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false);
     input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
     std::vector<int32_t> input_offsets_vec(b + 1, 0);
     for (size_t i = 0; i <= b; i++) {
@@ -304,43 +304,42 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
     // 6. Compute attention
     infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
-    if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
-        infinicore::op::mha_varlen_(
-            attn_output,
-            q_reshaped,
-            k_total->permute({0, 2, 1, 3}),
-            v_total->permute({0, 2, 1, 3}),
-            input_offsets.value(),
-            cu_seqlens.value(),
-            block_tables.value(),
-            max_position_embeddings_,
-            max_position_embeddings_,
-            std::nullopt,
-            scaling_);
-    } else {
-        if (is_prefill) {
-            infinicore::op::paged_attention_prefill_(
+    if (is_prefill) {
+        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
+            infinicore::op::mha_varlen_(
                 attn_output,
                 q_reshaped,
-                k_total,
-                v_total,
-                block_tables.value(),
-                total_sequence_lengths.value(),
+                k_total->permute({0, 2, 1, 3}),
+                v_total->permute({0, 2, 1, 3}),
                 input_offsets.value(),
+                cu_seqlens.value(),
+                block_tables.value(),
+                max_position_embeddings_,
+                max_position_embeddings_,
                 std::nullopt,
                 scaling_);
         } else {
-            infinicore::op::paged_attention_(
+            infinicore::op::paged_attention_prefill_(
                 attn_output,
                 q_reshaped,
                 k_total,
                 v_total,
                 block_tables.value(),
                 total_sequence_lengths.value(),
+                input_offsets.value(),
                 std::nullopt,
                 scaling_);
         }
+    } else {
+        infinicore::op::paged_attention_(
+            attn_output,
+            q_reshaped,
+            k_total,
+            v_total,
+            block_tables.value(),
+            total_sequence_lengths.value(),
+            std::nullopt,
+            scaling_);
     }
     // 7. Project output
@@ -193,10 +193,10 @@ class InferEngine(_infinilm.InferEngine):
         slot_mapping = None
         past_kv_lengths = infinicore.from_list(
-            [past_seq_len] * batch_size, dtype=infinicore.int64
+            [past_seq_len] * batch_size, dtype=infinicore.int32
         )
         total_kv_lengths = infinicore.from_list(
-            [past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
+            [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
         )
         cu_seqlens = infinicore.from_list(
             [(past_seq_len + seq_len) * i for i in range(batch_size + 1)],