Commit 693d74d3 authored by PanZezhong, committed by wooway777

issue/143 use add_rmsnorm, nt flash attn, nt kv caching

parent b5a809a0
@@ -85,26 +85,38 @@ StaticKVCache::update(size_t layer_idx,
auto batch_size = k->size(0);
auto update_len = k->size(2);
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
ASSERT_EQ(batch_size, rank_batch_size_);
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
k_cache_update->copy_from(k);
v_cache_update->copy_from(v);
auto k_total = k_cache_layer->narrow({{2, 0, result_len}});
auto v_total = v_cache_layer->narrow({{2, 0, result_len}});
auto device = k_cache_layer->device();
if (device.getType() == infinicore::Device::Type::NVIDIA
|| device.getType() == infinicore::Device::Type::ILUVATAR
|| device.getType() == infinicore::Device::Type::METAX
|| device.getType() == infinicore::Device::Type::MOORE
|| device.getType() == infinicore::Device::Type::CAMBRICON) {
infinicore::op::kv_caching_(
k_cache_layer,
v_cache_layer,
k,
v,
past_sequence_lengths);
} else {
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
k_cache_update->copy_from(k);
v_cache_update->copy_from(v);
}
return {k_total, v_total};
return {k_cache_layer, v_cache_layer};
}
// ==========================
......
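The reworked `StaticKVCache::update` dispatches to a fused `infinicore::op::kv_caching_` kernel on NVIDIA/ILUVATAR/METAX/MOORE/CAMBRICON devices and keeps the narrow-and-copy path as a fallback; in both cases the whole cache layer is now returned instead of only the valid prefix. Below is a minimal NumPy sketch of the fallback's semantics; the helper name and toy shapes are illustrative assumptions, not project APIs.

```python
import numpy as np

def kv_cache_update(k_cache_layer, v_cache_layer, k, v, past_len):
    """Append k/v ([bs, n_kv_head, update_len, head_dim]) at position past_len.

    Mirrors the fallback branch: the cache is written in place (the narrow +
    copy_from equivalent) and the *whole* cache layer is returned, matching
    the new `return {k_cache_layer, v_cache_layer};`.
    """
    update_len = k.shape[2]
    assert past_len + update_len <= k_cache_layer.shape[2], "cache overflow"
    k_cache_layer[:, :, past_len:past_len + update_len] = k
    v_cache_layer[:, :, past_len:past_len + update_len] = v
    return k_cache_layer, v_cache_layer

# toy usage (shapes are illustrative)
bs, n_kv_head, cache_len, head_dim = 1, 2, 8, 4
k_cache = np.zeros((bs, n_kv_head, cache_len, head_dim), dtype=np.float32)
v_cache = np.zeros_like(k_cache)
k_new = np.ones((bs, n_kv_head, 3, head_dim), dtype=np.float32)
kv_cache_update(k_cache, v_cache, k_new, k_new, past_len=0)
```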
@@ -112,8 +112,8 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, max_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, max_seq_len, head_dim]
if (kv_cache == nullptr) {
k_total = k_permuted;
v_total = v_permuted;
@@ -124,27 +124,42 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
} else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
}
auto total_seq_len = k_total->shape()[2];
// 6. Compute attention
size_t ngroup = num_attention_heads_ / num_key_value_heads_;
auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
infinicore::Tensor attn_output;
if (q_reshaped->device().getType() == infinicore::Device::Type::NVIDIA
|| q_reshaped->device().getType() == infinicore::Device::Type::METAX
|| q_reshaped->device().getType() == infinicore::Device::Type::MOORE
|| q_reshaped->device().getType() == infinicore::Device::Type::ILUVATAR
|| q_reshaped->device().getType() == infinicore::Device::Type::CAMBRICON) {
attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
attn_output = attn_output->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
// 6. Compute attention
size_t ngroup = num_attention_heads_ / num_key_value_heads_;
auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
}
auto output = o_proj_->forward(attn_output);
......
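In the attention hunk, devices with a fused kernel now call `infinicore::op::flash_attention` on the full cache layers plus `total_sequence_lengths`, while the CPU path first narrows the caches to the valid prefix and then runs the original matmul, causal-softmax, matmul sequence over grouped-query heads. The NumPy sketch below mirrors only that fallback math; the function name, toy shapes, and the bottom-right alignment of the causal mask are assumptions, and the real operators run on infinicore tensors.

```python
import numpy as np

def gqa_attention(q, k_total, v_total, total_seq_len, scaling):
    """q: [bs, n_q_head, seq_len, head_dim]; k_total/v_total: [bs, n_kv_head, max_len, head_dim]."""
    bs, n_q_head, seq_len, head_dim = q.shape
    n_kv_head = k_total.shape[1]
    ngroup = n_q_head // n_kv_head
    # fallback path: narrow the full cache layers to the valid prefix first
    k = k_total[:, :, :total_seq_len]
    v = v_total[:, :, :total_seq_len]
    Q = q.reshape(bs * n_kv_head, ngroup * seq_len, head_dim)
    K = k.reshape(bs * n_kv_head, total_seq_len, head_dim)
    V = v.reshape(bs * n_kv_head, total_seq_len, head_dim)
    scores = scaling * (Q @ K.transpose(0, 2, 1))        # [bs*n_kv_head, ng*seq_len, total]
    scores = scores.reshape(bs * n_q_head, seq_len, total_seq_len)
    # causal mask aligned so the last query row attends to every cached position
    past = total_seq_len - seq_len
    for j in range(seq_len):
        scores[:, j, past + j + 1:] = -np.inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    weights = weights.reshape(bs * n_kv_head, ngroup * seq_len, total_seq_len)
    out = (weights @ V).reshape(bs, n_q_head, seq_len, head_dim)
    return out.transpose(0, 2, 1, 3).reshape(bs, seq_len, n_q_head * head_dim)

# toy usage: 8 query heads sharing 2 KV heads, 3 new tokens, 5 valid cache positions
q = np.random.rand(1, 8, 3, 16).astype(np.float32)
kc = np.random.rand(1, 2, 8, 16).astype(np.float32)
vc = np.random.rand(1, 2, 8, 16).astype(np.float32)
out = gqa_attention(q, kc, vc, total_seq_len=5, scaling=16 ** -0.5)  # -> [1, 3, 128]
```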
@@ -23,38 +23,29 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
}
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// Save residual for attention
auto residual = hidden_states;
// 1. Pre-attention layer normalization
auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
// Save residual for MLP
residual = output;
std::tuple<infinicore::Tensor, infinicore::Tensor>
LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
infinicore::Tensor &residual,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Attention layer normalization
input_layernorm_->forward_inplace(hidden_states, residual);
// 2. Self-attention
hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// 3. Post-attention layer normalization
normed_states = post_attention_layernorm_->forward(output);
post_attention_layernorm_->forward_inplace(hidden_states, residual);
// 4. MLP with residual connection
auto mlp_output = mlp_->forward(normed_states);
// 4. MLP
hidden_states = mlp_->forward(hidden_states);
// Add residual: output = output + mlp_output
output = infinicore::op::add(residual, mlp_output);
return output;
return std::make_tuple(hidden_states, residual);
}
} // namespace infinilm::models::llama
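The decoder layer now threads a `(hidden_states, residual)` pair through in-place norm calls instead of explicit `op::add` residual additions. A plausible reading of `forward_inplace`, modelled on the common fused add+RMSNorm pattern, is sketched below in NumPy; the contract shown here is an assumption (the actual infinicore semantics may differ) and all helper names are illustrative.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    var = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(var + eps) * weight

def fused_add_rms_norm(hidden, residual, weight, eps=1e-6):
    """Assumed forward_inplace contract: residual accumulates the pre-norm sum,
    hidden receives the normalised value; on the first call residual is None."""
    residual = hidden if residual is None else hidden + residual
    return rms_norm(residual, weight, eps), residual

def decoder_layer(hidden, residual, w_ln1, w_ln2, attn, mlp):
    # 1. attention layer norm (folds the incoming residual in)
    hidden, residual = fused_add_rms_norm(hidden, residual, w_ln1)
    # 2. self-attention
    hidden = attn(hidden)
    # 3. post-attention layer norm (folds the attention output in)
    hidden, residual = fused_add_rms_norm(hidden, residual, w_ln2)
    # 4. MLP; its output is added by the *next* layer's (or the final) norm
    hidden = mlp(hidden)
    return hidden, residual

# toy usage with identity attention/MLP stand-ins
h = np.random.rand(1, 4, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
h, r = decoder_layer(h, None, w, w, attn=lambda x: x, mlp=lambda x: x)
```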
@@ -41,19 +41,23 @@ public:
/**
* @brief Forward pass: process one decoder layer
*
* @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
* @param hidden_states [batch, seq_len, hidden_size], will be modified
* @param residual [batch, seq_len, hidden_size], will be modified
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional KV cache for incremental decoding
* @return Output tensor of shape [batch, seq_len, hidden_size]
* Updated residual tensor of shape [batch, seq_len, hidden_size]
*/
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
std::tuple<infinicore::Tensor, infinicore::Tensor>
forward(infinicore::Tensor &hidden_states,
infinicore::Tensor &residual,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
/**
* @brief Get the layer index
......
@@ -55,11 +55,23 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers
size_t num_layers = layers_.size();
infinicore::Tensor residual;
for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
layers_.at(i)->forward(
hidden_states,
residual,
position_ids,
kv_cache_,
past_sequence_lengths,
total_sequence_lengths,
input_offsets,
block_tables,
slot_mapping);
}
return norm_->forward(hidden_states);
norm_->forward_inplace(hidden_states, residual);
return hidden_states;
}
void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
......
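At the model level the same pair is carried across all layers, and the final `norm_->forward_inplace(hidden_states, residual)` folds the last residual back in before returning. A toy NumPy sketch of that loop, using stand-in layers and the same assumed fused-norm semantics as above:

```python
import numpy as np

def final_norm_inplace(hidden, residual, eps=1e-6):
    # assumed norm_->forward_inplace semantics: add the residual, then RMS-normalise
    x = hidden if residual is None else hidden + residual
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def model_forward(hidden, layers):
    residual = None
    for layer in layers:
        # each layer consumes and returns the (hidden, residual) pair
        hidden, residual = layer(hidden, residual)
    return final_norm_inplace(hidden, residual)

# toy usage: two stand-in "layers" that just perturb the activations
layers = [lambda h, r: (0.5 * h, h if r is None else h + r)] * 2
out = model_forward(np.random.rand(1, 4, 8).astype(np.float32), layers)
```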
import sys
import os
import argparse
import time
import re
import csv
@@ -8,7 +7,7 @@ import numpy as np
import infinicore
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.cache import StaticKVCacheConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig
from datasets import load_dataset, Dataset
@@ -56,6 +55,7 @@ class InfiniLMBenchmark(BaseBenchmark):
ndev=1,
backend="cpp",
benchmark="ceval",
enable_paged_attn=False,
):
import transformers
@@ -124,7 +124,9 @@ class InfiniLMBenchmark(BaseBenchmark):
model_dir_path,
device=self.device,
distributed_config=DistConfig(ndev),
cache_config=StaticKVCacheConfig(),
cache_config=(
PagedKVCacheConfig(128) if enable_paged_attn else StaticKVCacheConfig()
),
)
# Enable KV cache for generation
@@ -673,6 +675,7 @@ def test():
max_new_tokens = 500
output_csv = None
cache_dir = None
enable_paged_attn = False
i = 3
while i < len(sys.argv):
@@ -703,6 +706,9 @@ def test():
elif sys.argv[i] == "--cache_dir" and i + 1 < len(sys.argv):
cache_dir = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--enable_paged_attn":
enable_paged_attn = True
i += 1
else:
i += 1
@@ -757,16 +763,13 @@ def test():
subject_list = ["all"]
# Create model based on backend (create once, reuse for all subjects)
if backend != "010":
if backend == "torch":
model = TorchBenchmark(model_path, device_type_str, benchmark)
else:
model = InfiniLMBenchmark(
model_path, device_type_str, ndev, backend, benchmark
)
if backend == "torch":
model = TorchBenchmark(model_path, device_type_str, benchmark)
else:
print(f"test 010 backend by scripts/test_ceval.py")
exit(0)
model = InfiniLMBenchmark(
model_path, device_type_str, ndev, backend, benchmark, enable_paged_attn
)
# Define helper functions for loading datasets
if benchmark == "ceval":
......
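On the benchmark side, the new `--enable_paged_attn` flag is a bare switch parsed from `sys.argv` that swaps the engine's cache config from `StaticKVCacheConfig()` to `PagedKVCacheConfig(128)` (128 being the block size used in the diff). A runnable sketch of that selection; the two config classes below are stand-ins so the example runs without infinilm installed, the real ones live in `infinilm.cache`.

```python
import sys


class StaticKVCacheConfig:  # stand-in for infinilm.cache.StaticKVCacheConfig
    pass


class PagedKVCacheConfig:  # stand-in for infinilm.cache.PagedKVCacheConfig
    def __init__(self, block_size):
        self.block_size = block_size


def pick_cache_config(argv):
    enable_paged_attn = "--enable_paged_attn" in argv  # bare flag, no value
    return PagedKVCacheConfig(128) if enable_paged_attn else StaticKVCacheConfig()


if __name__ == "__main__":
    print(type(pick_cache_config(sys.argv)).__name__)
```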