"vscode:/vscode.git/clone" did not exist on "96ddffe8fdabccb10fb693a54dcb88bd5b71bc09"
Commit 8297a0b7 authored by PanZezhong's avatar PanZezhong
Browse files

issue/248 optimize: use flash-attn only in prefill

parent ae210024
......@@ -41,12 +41,12 @@ void PagedCompiler::compile() {
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
std::vector<int32_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
std::vector<int32_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
......
......@@ -304,6 +304,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
// 6. Compute attention
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) {
if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
infinicore::op::mha_varlen_(
attn_output,
......@@ -318,7 +319,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
std::nullopt,
scaling_);
} else {
if (is_prefill) {
infinicore::op::paged_attention_prefill_(
attn_output,
q_reshaped,
......@@ -329,7 +329,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
input_offsets.value(),
std::nullopt,
scaling_);
}
} else {
infinicore::op::paged_attention_(
attn_output,
......@@ -341,7 +341,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
std::nullopt,
scaling_);
}
}
// 7. Project output
attn_output
......
......@@ -193,10 +193,10 @@ class InferEngine(_infinilm.InferEngine):
slot_mapping = None
past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64
[past_seq_len] * batch_size, dtype=infinicore.int32
)
total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
)
cu_seqlens = infinicore.from_list(
[(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment