Commit 693d74d3 authored by PanZezhong, committed by wooway777
Browse files

issue/143 use add_rmsnorm, nt flash attn, nt kv caching

parent b5a809a0
...@@ -85,26 +85,38 @@ StaticKVCache::update(size_t layer_idx, ...@@ -85,26 +85,38 @@ StaticKVCache::update(size_t layer_idx,
auto batch_size = k->size(0); auto batch_size = k->size(0);
auto update_len = k->size(2); auto update_len = k->size(2);
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
ASSERT_EQ(batch_size, rank_batch_size_); ASSERT_EQ(batch_size, rank_batch_size_);
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0); auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0); auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}}); auto device = k_cache_layer->device();
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
if (device.getType() == infinicore::Device::Type::NVIDIA
k_cache_update->copy_from(k); || device.getType() == infinicore::Device::Type::ILUVATAR
v_cache_update->copy_from(v); || device.getType() == infinicore::Device::Type::METAX
|| device.getType() == infinicore::Device::Type::MOORE
auto k_total = k_cache_layer->narrow({{2, 0, result_len}}); || device.getType() == infinicore::Device::Type::CAMBRICON) {
auto v_total = v_cache_layer->narrow({{2, 0, result_len}}); infinicore::op::kv_caching_(
k_cache_layer,
v_cache_layer,
k,
v,
past_sequence_lengths);
} else {
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
k_cache_update->copy_from(k);
v_cache_update->copy_from(v);
}
return {k_total, v_total}; return {k_cache_layer, v_cache_layer};
} }
// ========================== // ==========================
......
...@@ -112,8 +112,8 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -112,8 +112,8 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor k_total; // [bs, n_kv_head, max_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor v_total; // [bs, n_kv_head, max_seq_len, head_dim]
if (kv_cache == nullptr) { if (kv_cache == nullptr) {
k_total = k_permuted; k_total = k_permuted;
v_total = v_permuted; v_total = v_permuted;
...@@ -124,27 +124,42 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -124,27 +124,42 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
} else { } else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type"); throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
} }
auto total_seq_len = k_total->shape()[2];
// 6. Compute attention infinicore::Tensor attn_output;
size_t ngroup = num_attention_heads_ / num_key_value_heads_; if (q_reshaped->device().getType() == infinicore::Device::Type::NVIDIA
auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_}); || q_reshaped->device().getType() == infinicore::Device::Type::METAX
auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_}); || q_reshaped->device().getType() == infinicore::Device::Type::MOORE
auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_}); || q_reshaped->device().getType() == infinicore::Device::Type::ILUVATAR
|| q_reshaped->device().getType() == infinicore::Device::Type::CAMBRICON) {
attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
attn_output = attn_output->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len] // 6. Compute attention
size_t ngroup = num_attention_heads_ / num_key_value_heads_;
auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len] auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len}); auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim] auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_}) auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
->permute({0, 2, 1, 3})
->contiguous() attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] ->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
}
auto output = o_proj_->forward(attn_output); auto output = o_proj_->forward(attn_output);
......
...@@ -23,38 +23,29 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, ...@@ -23,38 +23,29 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_); INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
} }
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, std::tuple<infinicore::Tensor, infinicore::Tensor>
const infinicore::Tensor &position_ids, LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
std::shared_ptr<infinilm::cache::Cache> kv_cache, infinicore::Tensor &residual,
std::optional<infinicore::Tensor> past_sequence_lengths, const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> total_sequence_lengths, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> input_offsets,
// Save residual for attention std::optional<infinicore::Tensor> block_tables,
auto residual = hidden_states; std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Attention layer normalization
// 1. Pre-attention layer normalization input_layernorm_->forward_inplace(hidden_states, residual);
auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention
// 2. Self-attention with residual connection hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
// Save residual for MLP
residual = output;
// 3. Post-attention layer normalization // 3. Post-attention layer normalization
normed_states = post_attention_layernorm_->forward(output); post_attention_layernorm_->forward_inplace(hidden_states, residual);
// 4. MLP with residual connection // 4. MLP
auto mlp_output = mlp_->forward(normed_states); hidden_states = mlp_->forward(hidden_states);
// Add residual: output = output + mlp_output return std::make_tuple(hidden_states, residual);
output = infinicore::op::add(residual, mlp_output);
return output;
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -41,19 +41,23 @@ public: ...@@ -41,19 +41,23 @@ public:
/** /**
* @brief Forward pass: process one decoder layer * @brief Forward pass: process one decoder layer
* *
* @param hidden_states Input tensor of shape [batch, seq_len, hidden_size] * @param hidden_states [batch, seq_len, hidden_size], will be modified
* @param residual [batch, seq_len, hidden_size], will be modified
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional KV cache for incremental decoding * @param kv_cache Optional KV cache for incremental decoding
* @return Output tensor of shape [batch, seq_len, hidden_size] * @return Output tensor of shape [batch, seq_len, hidden_size]
* Updated residual tensor of shape [batch, seq_len, hidden_size]
*/ */
infinicore::Tensor forward(const infinicore::Tensor &hidden_states, std::tuple<infinicore::Tensor, infinicore::Tensor>
const infinicore::Tensor &position_ids, forward(infinicore::Tensor &hidden_states,
std::shared_ptr<infinilm::cache::Cache> kv_cache, infinicore::Tensor &residual,
std::optional<infinicore::Tensor> past_sequence_lengths, const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> total_sequence_lengths, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> slot_mappin) const; std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mappin) const;
/** /**
* @brief Get the layer index * @brief Get the layer index
......
...@@ -55,11 +55,23 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -55,11 +55,23 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers // 2. Process through all decoder layers
size_t num_layers = layers_.size(); size_t num_layers = layers_.size();
infinicore::Tensor residual;
for (size_t i = 0; i < num_layers; ++i) { for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); layers_.at(i)->forward(
hidden_states,
residual,
position_ids,
kv_cache_,
past_sequence_lengths,
total_sequence_lengths,
input_offsets,
block_tables,
slot_mapping);
} }
return norm_->forward(hidden_states); norm_->forward_inplace(hidden_states, residual);
return hidden_states;
} }
void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
......
import sys import sys
import os import os
import argparse
import time import time
import re import re
import csv import csv
...@@ -8,7 +7,7 @@ import numpy as np ...@@ -8,7 +7,7 @@ import numpy as np
import infinicore import infinicore
from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig from infinilm.distributed import DistConfig
from infinilm.cache import StaticKVCacheConfig from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.infer_engine import GenerationConfig, InferEngine from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig from infinilm.cache import StaticKVCacheConfig
from datasets import load_dataset, Dataset from datasets import load_dataset, Dataset
...@@ -56,6 +55,7 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -56,6 +55,7 @@ class InfiniLMBenchmark(BaseBenchmark):
ndev=1, ndev=1,
backend="cpp", backend="cpp",
benchmark="ceval", benchmark="ceval",
enable_paged_attn=False,
): ):
import transformers import transformers
...@@ -124,7 +124,9 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -124,7 +124,9 @@ class InfiniLMBenchmark(BaseBenchmark):
model_dir_path, model_dir_path,
device=self.device, device=self.device,
distributed_config=DistConfig(ndev), distributed_config=DistConfig(ndev),
cache_config=StaticKVCacheConfig(), cache_config=(
PagedKVCacheConfig(128) if enable_paged_attn else StaticKVCacheConfig()
),
) )
# Enable KV cache for generation # Enable KV cache for generation
...@@ -673,6 +675,7 @@ def test(): ...@@ -673,6 +675,7 @@ def test():
max_new_tokens = 500 max_new_tokens = 500
output_csv = None output_csv = None
cache_dir = None cache_dir = None
enable_paged_attn = False
i = 3 i = 3
while i < len(sys.argv): while i < len(sys.argv):
...@@ -703,6 +706,9 @@ def test(): ...@@ -703,6 +706,9 @@ def test():
elif sys.argv[i] == "--cache_dir" and i + 1 < len(sys.argv): elif sys.argv[i] == "--cache_dir" and i + 1 < len(sys.argv):
cache_dir = sys.argv[i + 1] cache_dir = sys.argv[i + 1]
i += 2 i += 2
elif sys.argv[i] == "--enable_paged_attn":
enable_paged_attn = True
i += 1
else: else:
i += 1 i += 1
...@@ -757,16 +763,13 @@ def test(): ...@@ -757,16 +763,13 @@ def test():
subject_list = ["all"] subject_list = ["all"]
# Create model based on backend (create once, reuse for all subjects) # Create model based on backend (create once, reuse for all subjects)
if backend != "010":
if backend == "torch": if backend == "torch":
model = TorchBenchmark(model_path, device_type_str, benchmark) model = TorchBenchmark(model_path, device_type_str, benchmark)
else:
model = InfiniLMBenchmark(
model_path, device_type_str, ndev, backend, benchmark
)
else: else:
print(f"test 010 backend by scripts/test_ceval.py") model = InfiniLMBenchmark(
exit(0) model_path, device_type_str, ndev, backend, benchmark, enable_paged_attn
)
# Define helper functions for loading datasets # Define helper functions for loading datasets
if benchmark == "ceval": if benchmark == "ceval":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment