Unverified Commit d09de04c authored by thatPepe, committed by GitHub

Merge pull request #250 from InfiniTensor/issue/248

Issue/248 support flash-attention
parents f67956fe 5dc85bf4
......@@ -160,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
```
> Note: `--cache_dir` should point to the parent directory that contains dataset subdirectories such as `ceval___ceval-exam` and `cais___mmlu`, not directly to those subdirectories.
- Experimental features
- Warm Up
```bash
python examples/bench.py --nvidia --model=<model-path> --warmup
```
- Paged Attention
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn
```
- CUDA Graph
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn --enable-graph
```
- Select the attention backend (the flash-attention backend requires the corresponding configuration and build in InfiniCore first)
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn [--attn=default | --attn=flash-attn]
```
#pragma once
#include <stdexcept>
#include <string>
namespace infinilm::backends {
enum class AttentionBackend {
Default,
FlashAttn,
};
inline AttentionBackend parse_attention_backend(const std::string &backend) {
if (backend == "default") {
return AttentionBackend::Default;
}
if (backend == "flash-attn") {
return AttentionBackend::FlashAttn;
}
throw std::invalid_argument(
"Invalid attention_backend: " + backend + ". Valid options are: default, flash-attn");
}
} // namespace infinilm::backends
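
The header above is new in this PR and carries no accompanying prose, so here is a minimal, illustrative sketch of how the parser behaves. The include path and the `main` wrapper are assumptions; only `parse_attention_backend` and the enum come from the header itself.

```cpp
// Illustrative only: the include path below is an assumption about the source layout.
#include "backends/attention_backends.hpp"

#include <iostream>
#include <stdexcept>

int main() {
    using infinilm::backends::AttentionBackend;
    using infinilm::backends::parse_attention_backend;

    // The only accepted strings are "default" and "flash-attn",
    // matching the --attn values shown in the README excerpt above.
    AttentionBackend backend = parse_attention_backend("flash-attn");
    std::cout << (backend == AttentionBackend::FlashAttn ? "FlashAttn" : "Default") << "\n";

    // Any other string throws std::invalid_argument; the Python bindings added
    // later in this PR call this parser, so the error propagates to Python callers.
    try {
        parse_attention_backend("paged");
    } catch (const std::invalid_argument &e) {
        std::cerr << e.what() << "\n";
    }
    return 0;
}
```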
......@@ -101,7 +101,7 @@ StaticKVCache::update(size_t layer_idx,
v,
past_sequence_lengths);
#else
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
size_t cache_pos = reinterpret_cast<int32_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
......@@ -213,9 +213,9 @@ PagedKVCache::get_contiguous_kv(
const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets,
size_t request_id) {
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I32);
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I32);
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I32);
auto nreq = block_tables->size(0);
auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
......@@ -227,9 +227,9 @@ PagedKVCache::get_contiguous_kv(
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
auto req = request_id;
auto cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
auto input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
auto cache_lens_ptr = reinterpret_cast<const int32_t *>(cache_lens_cpu->data());
auto input_offsets_ptr = reinterpret_cast<const int32_t *>(input_offsets_cpu->data());
int32_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
auto full_k = infinicore::Tensor::empty(
{num_rank_k_heads_, (size_t)total_len, k_dim_},
......@@ -243,7 +243,7 @@ PagedKVCache::get_contiguous_kv(
size_t r = total_len % block_size_;
for (size_t b = 0; b < nblocks; b++) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
full_k->narrow({{1, b * block_size_, block_size_}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
......@@ -252,7 +252,7 @@ PagedKVCache::get_contiguous_kv(
}
if (r > 0) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
full_k->narrow({{1, nblocks * block_size_, r}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
......
......@@ -34,26 +34,27 @@ void PagedCompiler::compile() {
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
{nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b;
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_offsets.value());
std::vector<int64_t> input_offsets_vec(b + 1, 0);
std::vector<int32_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
std::vector<int32_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i;
}
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.cu_seqlens = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
infinicore::context::memcpyH2D(input.cu_seqlens.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value());
......@@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
......
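
For orientation, the decode-graph placeholders captured by the loop above end up with the following shapes, dtypes, and fill values (shown for `b = 2`). They are only capture-time placeholders and are overwritten via `copy_from` in `get_compiled` before replay.

```text
input_ids               I64  {1, 2}             zeros
position_ids            I64  {2}                zeros
total_sequence_lengths  I32  {2}                {1, 1}
input_offsets           I32  {3}                {0, 1, 2}
cu_seqlens              I32  {3}                {0, 1, 2}
block_tables            I32  {2, block_per_req} strided view into block_tables_holder_ (zeros)
slot_mapping            I64  {2}                zeros
```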
......@@ -23,9 +23,11 @@ InferEngine::InferEngine(
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
bool enable_graph_compiling,
backends::AttentionBackend attention_backend) // Changed parameter
: communication_group_(distributed_config, device_type),
legacy_model_config_(config) {
legacy_model_config_(config),
attention_backend_(attention_backend) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
......@@ -39,7 +41,8 @@ InferEngine::InferEngine(
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
enable_graph_compiling,
attention_backend_));
}
// Compile the model on all workers
......@@ -51,8 +54,9 @@ InferEngine::InferEngine(
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
: communication_group_(distributed_config, device_type) {
bool enable_graph_compiling,
backends::AttentionBackend attention_backend) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
......@@ -69,7 +73,8 @@ InferEngine::InferEngine(
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
enable_graph_compiling,
attention_backend_));
}
// Compile the model on all workers
this->compile();
......@@ -117,6 +122,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
to_device(past_sequence_lengths), // @todo: on device in the future
to_device(total_sequence_lengths),
to_device(input_offsets),
to_device(cu_seqlens),
to_device(block_tables),
to_device(slot_mapping),
};
......@@ -169,7 +175,7 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for (auto &worker : workers_) {
worker->wait();
}
cache_config_ = new_config->unique_copy();
this->compile();
}
......
......@@ -37,14 +37,16 @@ public:
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false);
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
InferEngine(
const std::string &model_path = "",
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false);
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
......@@ -73,6 +75,7 @@ protected:
std::unique_ptr<cache::CacheConfig> cache_config_;
const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
};
} // namespace infinilm::engine
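
As a hedged usage sketch of the extended constructor declared above: the include path, model path, and the all-default arguments are placeholders; only the parameter order and the new `attention_backend` argument follow the declaration.

```cpp
// Illustrative only: the include path and model path are placeholders.
#include "engine/infer_engine.hpp"

void build_flash_attn_engine() {
    using namespace infinilm;

    engine::InferEngine eng(
        /*model_path=*/"/models/9G7B_MHA",
        engine::distributed::DistConfig(),           // single-rank defaults
        infinicore::context::getDevice().getType(),  // current device type
        /*cache_config=*/nullptr,
        /*enable_graph_compiling=*/false,
        backends::AttentionBackend::FlashAttn);      // new parameter in this PR
}
```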
......@@ -26,9 +26,11 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
bool enable_graph_compiling,
backends::AttentionBackend attention_backend)
: legacy_model_config_(model_config),
rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
......@@ -53,9 +55,11 @@ RankWorker::RankWorker(
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
bool enable_graph_compiling,
backends::AttentionBackend attention_backend)
: model_config_(model_config),
rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
......@@ -234,10 +238,18 @@ void RankWorker::thread_loop() {
// Create model using factory (may be expensive)
if (model_config_ == nullptr) {
model_ = InfinilmModelFactory::createModel(legacy_model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
model_ = InfinilmModelFactory::createModel(
legacy_model_config_,
rank_info_,
pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
attention_backend_);
} else {
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
model_ = InfinilmModelFactory::createModel(
model_config_,
rank_info_,
pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
attention_backend_);
}
if (!model_) {
......@@ -339,7 +351,7 @@ void RankWorker::thread_loop() {
const auto &batch_size{logits_shape[0]};
auto n_req = local_args.input_offsets.value()->size(0) - 1;
int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data();
int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data();
auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
......
#pragma once
#include "../backends/attention_backends.hpp"
#include "../cache/cache.hpp"
#include "../config/model_config.hpp"
#include "../models/model_factory.hpp"
......@@ -37,8 +38,10 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths;
/// Total lengths for each request sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests]`.
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
......@@ -61,13 +64,15 @@ public:
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
bool enable_graph_compiling,
backends::AttentionBackend attention_backend);
RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
bool enable_graph_compiling,
backends::AttentionBackend attention_backend);
// Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name,
......@@ -107,6 +112,9 @@ private:
std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::Cache> cache_;
// Backends
backends::AttentionBackend attention_backend_;
// Graph Compiling
bool enable_graph_compiling_;
std::unique_ptr<GraphCompiler> compiler_;
......
......@@ -27,6 +27,8 @@ public:
std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continuous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
......
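
Since this PR introduces `cu_seqlens` alongside the existing `input_offsets`, a small made-up example may help keep them apart. The shapes and meanings follow the doc comments above; the numbers are purely illustrative.

```cpp
// Hypothetical prefill batch of two requests:
//   request 0: 4 cached tokens + 3 new tokens (total 7)
//   request 1: 6 cached tokens + 5 new tokens (total 11)
#include <cstdint>
#include <vector>

int main() {
    // Offsets of each request's new tokens in the continuously batched input, [num_requests + 1].
    std::vector<int32_t> input_offsets{0, 3, 8};
    // Total (cached + new) length per request, [num_requests].
    std::vector<int32_t> total_sequence_lengths{7, 11};
    // Cumulative total sequence lengths, [num_requests + 1] -- what cu_seqlens carries.
    std::vector<int32_t> cu_seqlens{0, 7, 18};

    (void)input_offsets;
    (void)total_sequence_lengths;
    (void)cu_seqlens;
    return 0;
}
```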
......@@ -4,6 +4,7 @@
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"
#include <algorithm>
......@@ -31,7 +32,8 @@ namespace infinilm::models::llama {
LlamaAttention::LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info)
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: layer_idx_(layer_idx),
hidden_size_(config.hidden_size),
num_attention_heads_(config.num_attention_heads),
......@@ -41,7 +43,9 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
use_bias_(config.attention_bias),
use_output_bias_(config.attention_output_bias),
use_qk_norm_(config.qk_norm),
max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
max_position_embeddings_(config.max_position_embeddings),
rank_info_(rank_info),
attention_backend_(attention_backend) {
const auto &dtype{config.dtype};
int tp_rank = rank_info.tp_rank;
......@@ -75,7 +79,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info)
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: model_config_(model_config),
layer_idx_(layer_idx),
hidden_size_(model_config->get<size_t>("hidden_size")),
......@@ -86,7 +91,8 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
use_bias_(model_config->get_or<bool>("attention_bias", true)),
use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
rank_info_(rank_info) {
rank_info_(rank_info),
attention_backend_(attention_backend) {
const auto &dtype{model_config_->get_dtype()};
int tp_rank = rank_info.tp_rank;
......@@ -203,7 +209,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
......@@ -238,6 +244,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
ASSERT(block_tables.has_value());
......@@ -298,17 +305,31 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) {
infinicore::op::paged_attention_prefill_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
input_offsets.value(),
std::nullopt,
scaling_);
if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
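// Flash-attention prefill path (new in this PR): variable-length MHA where, per the
// Input doc comments, input_offsets holds the per-request offsets of the new (query)
// tokens and cu_seqlens the cumulative total (cached + new) sequence lengths.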
infinicore::op::mha_varlen_(
attn_output,
q_reshaped,
k_total->permute({0, 2, 1, 3}),
v_total->permute({0, 2, 1, 3}),
input_offsets.value(),
cu_seqlens.value(),
block_tables.value(),
max_position_embeddings_,
max_position_embeddings_,
std::nullopt,
scaling_);
} else {
infinicore::op::paged_attention_prefill_(
attn_output,
q_reshaped,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
input_offsets.value(),
std::nullopt,
scaling_);
}
} else {
infinicore::op::paged_attention_(
attn_output,
......@@ -322,7 +343,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
}
// 7. Project output
attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
attn_output
= attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
return o_proj_->forward(attn_output);
}
......@@ -332,6 +354,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) {
......@@ -340,7 +363,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
} else {
output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
......
#pragma once
#include "../../backends/attention_backends.hpp"
#include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp"
#include "../../engine/distributed/distributed.hpp"
......@@ -52,12 +53,14 @@ public:
LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/**
* @brief Forward pass: compute attention
......@@ -73,6 +76,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......@@ -104,6 +108,7 @@ private:
std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......@@ -132,6 +137,8 @@ private:
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
float scaling_;
backends::AttentionBackend attention_backend_;
};
} // namespace infinilm::models::llama
......@@ -19,7 +19,8 @@ namespace infinilm::models::llama {
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx), rank_info_(rank_info) {
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) : layer_idx_(layer_idx), rank_info_(rank_info) {
const auto &dtype{config.dtype};
// Initialize layer normalization layers
......@@ -29,14 +30,15 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
dtype, device);
// Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_);
INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_, attention_backend);
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
}
LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
// Initialize layer normalization layers
INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
......@@ -45,7 +47,7 @@ LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConf
dtype, device);
// Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_);
INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_, attention_backend);
INFINICORE_NN_MODULE_INIT(mlp, model_config_, device, rank_info_);
}
......@@ -57,13 +59,15 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Attention layer normalization
input_layernorm_->forward_inplace(hidden_states, residual);
// 2. Self-attention
hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
hidden_states = self_attn_->forward(
hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 3. Post-attention layer normalization
post_attention_layernorm_->forward_inplace(hidden_states, residual);
......
......@@ -48,12 +48,14 @@ public:
LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/**
* @brief Forward pass: process one decoder layer
......@@ -73,6 +75,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......
......@@ -17,13 +17,14 @@ namespace infinilm::models::llama {
*/
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info) {
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) {
// Initialize module's device_ member
device_ = device;
const auto &dtype{config.dtype};
// Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, rank_info);
INFINICORE_NN_MODULE_INIT(model, config, device, rank_info, attention_backend);
// Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens
......@@ -34,14 +35,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info) {
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) {
// Initialize module's device_ member
device_ = device;
const auto &dtype{model_config->get_dtype()};
// Initialize base model
INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info);
INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info, attention_backend);
// Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens
// For now, we create a separate linear layer
......@@ -56,12 +58,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto past_sequence_lengths = input.past_sequence_lengths;
auto total_sequence_length = input.total_sequence_lengths;
auto input_offsets = input.input_offsets;
auto cu_seqlens = input.cu_seqlens;
auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states
auto hidden_states = model_->forward(
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping);
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
......
......@@ -42,11 +42,13 @@ public:
*/
LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/**
* @brief Forward pass: compute language modeling logits
......
......@@ -20,7 +20,8 @@ namespace infinilm::models::llama {
*/
LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: config_(config), rank_info_(rank_info) {
const auto &dtype{config.dtype};
// Initialize token embeddings
......@@ -34,7 +35,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
layers_.reserve(config.num_hidden_layers);
for (size_t i = 0; i < config.num_hidden_layers; ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), config, device, i, rank_info));
"layers." + std::to_string(i), config, device, i, rank_info, attention_backend));
}
// Initialize final layer normalization
......@@ -56,7 +57,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: model_config_(model_config), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
// Initialize token embeddings
......@@ -69,7 +71,7 @@ LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_conf
layers_.reserve(model_config_->get<size_t>("num_hidden_layers"));
for (size_t i = 0; i < model_config_->get<size_t>("num_hidden_layers"); ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), model_config_, device, i, rank_info));
"layers." + std::to_string(i), model_config_, device, i, rank_info, attention_backend));
}
// Initialize final layer normalization
INFINICORE_NN_MODULE_INIT(norm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
......@@ -92,6 +94,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
......@@ -109,6 +112,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
past_sequence_lengths,
total_sequence_lengths,
input_offsets,
cu_seqlens,
block_tables,
slot_mapping);
}
......
......@@ -51,11 +51,13 @@ public:
*/
LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/**
* @brief Forward pass: process input through the model
......@@ -73,6 +75,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const;
......
......@@ -17,12 +17,13 @@ namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) {
const cache::CacheConfig *cache,
backends::AttentionBackend attention_backend) {
std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, rank_info);
llama_config, rank_info.device, rank_info, attention_backend);
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
......@@ -37,12 +38,13 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) {
const cache::CacheConfig *cache,
backends::AttentionBackend attention_backend) {
std::shared_ptr<InfinilmModel> model;
if (true) {
model = std::make_shared<models::llama::LlamaForCausalLM>(
model_config, rank_info.device, rank_info);
model_config, rank_info.device, rank_info, attention_backend);
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
......
......@@ -3,6 +3,7 @@
#include "../config/model_config.hpp"
#include "infinilm_model.hpp"
#include "../backends/attention_backends.hpp"
#include "../engine/distributed/distributed.hpp"
namespace infinilm {
......@@ -23,11 +24,13 @@ public:
static std::shared_ptr<InfinilmModel> createModel(
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
const cache::CacheConfig *cache = nullptr,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
static std::shared_ptr<InfinilmModel> createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
const cache::CacheConfig *cache = nullptr,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
};
} // namespace infinilm
......@@ -36,19 +36,22 @@ inline void bind_infer_engine(py::module &m) {
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
bool enable_graph_compiling,
const std::string &attention_backend) {
return std::make_shared<InferEngine>(
cfg,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
enable_graph_compiling,
infinilm::backends::parse_attention_backend(attention_backend));
}),
py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
py::arg("enable_graph_compiling") = false,
py::arg("attention_backend") = "default")
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
......@@ -63,11 +66,14 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr;
})
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
infer_engine
......@@ -76,19 +82,22 @@ inline void bind_infer_engine(py::module &m) {
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
bool enable_graph_compiling,
const std::string &attention_backend) {
return std::make_shared<InferEngine>(
model_path,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
enable_graph_compiling,
infinilm::backends::parse_attention_backend(attention_backend));
}),
py::arg("model_path") = "",
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
py::arg("enable_graph_compiling") = false,
py::arg("attention_backend") = "default")
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
......@@ -103,8 +112,10 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
......@@ -118,6 +129,7 @@ inline void bind_infer_engine(py::module &m) {
std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) {
......@@ -127,6 +139,7 @@ inline void bind_infer_engine(py::module &m) {
std::move(past_sequence_lengths),
std::move(total_sequence_lengths),
std::move(input_offsets),
std::move(cu_seqlens),
std::move(block_tables),
std::move(slot_mapping),
};
......@@ -167,6 +180,7 @@ inline void bind_infer_engine(py::module &m) {
py::arg("past_sequence_lengths") = std::nullopt,
py::arg("total_sequence_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt,
py::arg("cu_seqlens") = std::nullopt,
py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids)
......@@ -174,6 +188,7 @@ inline void bind_infer_engine(py::module &m) {
.def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
.def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
.def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
.def_readwrite("temperature", &InferEngine::Input::temperature)
......