Commit e8245b7d authored by wooway777

issue/116 - using batched rope

parent 81081f3c
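The change replaces the post-cache RoPE pass over a flattened [seq_len, bs * n_head, head_dim] view with a single batched application on the [bs, seq_len, n_head, head_dim] activations, before they are permuted for the KV cache. As background, here is a minimal standalone sketch of what a batched RoPE pass computes over that layout; the function, the interleaved pairing convention, and the flat position-id layout are illustrative assumptions, not the infinicore implementation.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch (not the infinicore API): apply RoPE in one
// pass over a [batch, seq_len, n_head, head_dim] tensor stored row-major.
// Assumes interleaved pairing of dims (2i, 2i+1) and pos_ids laid out as
// [batch * seq_len]; the repo's rotary_emb_ may differ on both counts.
void batched_rope(std::vector<float> &x, const std::vector<int> &pos_ids,
                  size_t batch, size_t seq_len, size_t n_head, size_t head_dim,
                  float base = 10000.0f) {
    for (size_t b = 0; b < batch; ++b) {
        for (size_t s = 0; s < seq_len; ++s) {
            float pos = static_cast<float>(pos_ids[b * seq_len + s]);
            for (size_t h = 0; h < n_head; ++h) {
                float *v = x.data() + (((b * seq_len + s) * n_head) + h) * head_dim;
                for (size_t i = 0; i + 1 < head_dim; i += 2) {
                    // theta_i = pos * base^(-i/head_dim), the standard RoPE angle
                    float theta = pos * std::pow(base, -static_cast<float>(i) /
                                                           static_cast<float>(head_dim));
                    float c = std::cos(theta), sn = std::sin(theta);
                    float x0 = v[i], x1 = v[i + 1];
                    v[i]     = x0 * c - x1 * sn; // rotate the (x0, x1) pair
                    v[i + 1] = x0 * sn + x1 * c;
                }
            }
        }
    }
}

Doing this once on the [bs, seq_len, n_head, head_dim] tensor is what lets the diff below drop the per-layout view/permute gymnastics that the old code needed after the cache update.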
@@ -94,12 +94,17 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         throw std::runtime_error("Unexpected position_ids shape");
     }
-    // 4. Prepare KV caches
+    // 4. Apply RoPE to Q and K
+    auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
+    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
+    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim]
+    // 5. Prepare KV caches
     // Convert to [batch, n_head, seq_len, head_dim] for cache
     // Ensure contiguous after permute for F16 compatibility with cache operations
-    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
-    auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
-    auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
+    q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
+    auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
     infinilm::cache::DynamicCache *external_cache = static_cast<infinilm::cache::DynamicCache *>(kv_cache);
     infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
     infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
@@ -113,12 +118,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     }
     auto total_seq_len = k_total->shape()[2];
-    // 5. Apply RoPE to full batch
-    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_q_head, head_dim]
-    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})->view({batch_size * num_key_value_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
-    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
-    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);
     // 6. Compute attention
     size_t ngroup = num_attention_heads_ / num_key_value_heads_;
     auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
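The trailing context lines show the grouped-query attention step this feeds into: with ngroup = num_attention_heads_ / num_key_value_heads_, viewing Q as [bs * n_kv_head, ngroup * seq_len, head_dim] lines each KV head up with all ngroup of its query heads, so the score computation can run as one batched matmul against k_total. A shape-bookkeeping sketch under that assumption, with illustrative head counts not taken from the repo:

#include <cassert>
#include <cstddef>

// Shape arithmetic only; mirrors the view in the diff, assuming q_reshaped is
// a contiguous [bs, n_q_head, seq_len, head_dim] tensor. Values are made up.
int main() {
    size_t bs = 2, n_q_head = 32, n_kv_head = 8, seq_len = 5, head_dim = 128;
    size_t ngroup = n_q_head / n_kv_head; // query heads sharing one KV head

    // view: [bs, n_q_head, seq_len, head_dim] -> [bs * n_kv_head, ngroup * seq_len, head_dim]
    // is legal because the element counts match:
    assert(bs * n_q_head * seq_len * head_dim ==
           (bs * n_kv_head) * (ngroup * seq_len) * head_dim);

    // Each of the bs * n_kv_head batch slots now holds the ngroup query heads
    // that share one KV head, so Q @ K^T is a single batched GEMM against
    // K of shape [bs * n_kv_head, total_seq_len, head_dim].
    return 0;
}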