Commit 8297a0b7 authored by PanZezhong's avatar PanZezhong
Browse files

issue/248 optimize: use flash-attn only in prefill

parent ae210024
...@@ -41,12 +41,12 @@ void PagedCompiler::compile() { ...@@ -41,12 +41,12 @@ void PagedCompiler::compile() {
InfinilmModel::Input input; InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_ids.value()); set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value()); set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value()); set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1); std::vector<int32_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
std::vector<int32_t> input_offsets_vec(b + 1, 0); std::vector<int32_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) { for (size_t i = 0; i <= b; i++) {
......
...@@ -304,43 +304,42 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -304,43 +304,42 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
// 6. Compute attention // 6. Compute attention
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device()); infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (attention_backend_ == backends::AttentionBackend::FlashAttn) { if (is_prefill) {
infinicore::op::mha_varlen_( if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
attn_output, infinicore::op::mha_varlen_(
q_reshaped,
k_total->permute({0, 2, 1, 3}),
v_total->permute({0, 2, 1, 3}),
input_offsets.value(),
cu_seqlens.value(),
block_tables.value(),
max_position_embeddings_,
max_position_embeddings_,
std::nullopt,
scaling_);
} else {
if (is_prefill) {
infinicore::op::paged_attention_prefill_(
attn_output, attn_output,
q_reshaped, q_reshaped,
k_total, k_total->permute({0, 2, 1, 3}),
v_total, v_total->permute({0, 2, 1, 3}),
block_tables.value(),
total_sequence_lengths.value(),
input_offsets.value(), input_offsets.value(),
cu_seqlens.value(),
block_tables.value(),
max_position_embeddings_,
max_position_embeddings_,
std::nullopt, std::nullopt,
scaling_); scaling_);
} else { } else {
infinicore::op::paged_attention_( infinicore::op::paged_attention_prefill_(
attn_output, attn_output,
q_reshaped, q_reshaped,
k_total, k_total,
v_total, v_total,
block_tables.value(), block_tables.value(),
total_sequence_lengths.value(), total_sequence_lengths.value(),
input_offsets.value(),
std::nullopt, std::nullopt,
scaling_); scaling_);
} }
} else {
infinicore::op::paged_attention_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
std::nullopt,
scaling_);
} }
// 7. Project output // 7. Project output
......
...@@ -193,10 +193,10 @@ class InferEngine(_infinilm.InferEngine): ...@@ -193,10 +193,10 @@ class InferEngine(_infinilm.InferEngine):
slot_mapping = None slot_mapping = None
past_kv_lengths = infinicore.from_list( past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64 [past_seq_len] * batch_size, dtype=infinicore.int32
) )
total_kv_lengths = infinicore.from_list( total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64 [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
) )
cu_seqlens = infinicore.from_list( cu_seqlens = infinicore.from_list(
[(past_seq_len + seq_len) * i for i in range(batch_size + 1)], [(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment