Unverified Commit d09de04c authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #250 from InfiniTensor/issue/248

Issue/248 support flash-attention
parents f67956fe 5dc85bf4
...@@ -160,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA ...@@ -160,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/ python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/
``` ```
> 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录 > 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录
- 试验性功能
- Warm Up
```bash
python examples/bench.py --nvidia --model=<model-path> --warmup
```
- Paged Attention
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn
```
- CUDA Graph
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn --enable-graph
```
- 选择attention后端 (使用flash attention后端需要先在InfiniCore完成相关配置和编译)
```bash
python examples/bench.py --nvidia --model=<model-path> --enable-paged-attn [--attn=default | --attn=flash-attn]
```
#pragma once
#include <stdexcept>
#include <string>
namespace infinilm::backends {
/// Attention implementation selectable at engine construction time.
enum class AttentionBackend {
    Default,   ///< built-in paged/standard attention path
    FlashAttn, ///< flash-attention path (requires InfiniCore support)
};
/// @brief Map a backend name string to its AttentionBackend value.
/// @param backend  Accepted names: "default", "flash-attn".
/// @return The matching enumerator.
/// @throws std::invalid_argument for any unrecognized name.
inline AttentionBackend parse_attention_backend(const std::string &backend) {
    return backend == "default"      ? AttentionBackend::Default
           : backend == "flash-attn" ? AttentionBackend::FlashAttn
                                     : throw std::invalid_argument(
                                           "Invalid attention_backend: " + backend + ". Valid options are: default, flash-attn");
}
} // namespace infinilm::backends
...@@ -101,7 +101,7 @@ StaticKVCache::update(size_t layer_idx, ...@@ -101,7 +101,7 @@ StaticKVCache::update(size_t layer_idx,
v, v,
past_sequence_lengths); past_sequence_lengths);
#else #else
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; size_t cache_pos = reinterpret_cast<int32_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len; auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_); ASSERT(result_len <= cache_len_);
...@@ -213,9 +213,9 @@ PagedKVCache::get_contiguous_kv( ...@@ -213,9 +213,9 @@ PagedKVCache::get_contiguous_kv(
const infinicore::Tensor cache_lens, const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets, const infinicore::Tensor input_offsets,
size_t request_id) { size_t request_id) {
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64); ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I32);
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64); ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I32);
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64); ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I32);
auto nreq = block_tables->size(0); auto nreq = block_tables->size(0);
auto block_tables_cpu = block_tables->to(infinicore::Device::cpu()); auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
...@@ -227,9 +227,9 @@ PagedKVCache::get_contiguous_kv( ...@@ -227,9 +227,9 @@ PagedKVCache::get_contiguous_kv(
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx); auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
auto req = request_id; auto req = request_id;
auto cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data()); auto cache_lens_ptr = reinterpret_cast<const int32_t *>(cache_lens_cpu->data());
auto input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data()); auto input_offsets_ptr = reinterpret_cast<const int32_t *>(input_offsets_cpu->data());
int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]); int32_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
auto full_k = infinicore::Tensor::empty( auto full_k = infinicore::Tensor::empty(
{num_rank_k_heads_, (size_t)total_len, k_dim_}, {num_rank_k_heads_, (size_t)total_len, k_dim_},
...@@ -243,7 +243,7 @@ PagedKVCache::get_contiguous_kv( ...@@ -243,7 +243,7 @@ PagedKVCache::get_contiguous_kv(
size_t r = total_len % block_size_; size_t r = total_len % block_size_;
for (size_t b = 0; b < nblocks; b++) { for (size_t b = 0; b < nblocks; b++) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data())); size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
full_k->narrow({{1, b * block_size_, block_size_}}) full_k->narrow({{1, b * block_size_, block_size_}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)); ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
...@@ -252,7 +252,7 @@ PagedKVCache::get_contiguous_kv( ...@@ -252,7 +252,7 @@ PagedKVCache::get_contiguous_kv(
} }
if (r > 0) { if (r > 0) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data())); size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
full_k->narrow({{1, nblocks * block_size_, r}}) full_k->narrow({{1, nblocks * block_size_, r}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}})); ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
......
...@@ -34,26 +34,27 @@ void PagedCompiler::compile() { ...@@ -34,26 +34,27 @@ void PagedCompiler::compile() {
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end()); size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear(); compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty( block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice()); {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(block_tables_holder_); set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) { for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b; size_t block_per_req = nblocks / b;
InfinilmModel::Input input; InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_ids.value()); set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value()); set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value()); set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1); std::vector<int32_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice()); input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_offsets.value()); std::vector<int32_t> input_offsets_vec(b + 1, 0);
std::vector<int64_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) { for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i; input_offsets_vec[i] = i;
} }
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false); infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.cu_seqlens = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
infinicore::context::memcpyH2D(input.cu_seqlens.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1}); input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value()); set_zeros(input.slot_mapping.value());
...@@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input & ...@@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &
graph_input.position_ids.value()->copy_from(input.position_ids.value()); graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
......
...@@ -23,9 +23,11 @@ InferEngine::InferEngine( ...@@ -23,9 +23,11 @@ InferEngine::InferEngine(
const distributed::DistConfig &distributed_config, const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type, infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config, const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter bool enable_graph_compiling,
backends::AttentionBackend attention_backend) // Changed parameter
: communication_group_(distributed_config, device_type), : communication_group_(distributed_config, device_type),
legacy_model_config_(config) { legacy_model_config_(config),
attention_backend_(attention_backend) {
if (cache_config != nullptr) { if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy(); cache_config_ = cache_config->unique_copy();
} }
...@@ -39,7 +41,8 @@ InferEngine::InferEngine( ...@@ -39,7 +41,8 @@ InferEngine::InferEngine(
communication_group_.get_rank_info(r), communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr, cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(), barrier_.get(),
enable_graph_compiling)); enable_graph_compiling,
attention_backend_));
} }
// Compile the model on all workers // Compile the model on all workers
...@@ -51,8 +54,9 @@ InferEngine::InferEngine( ...@@ -51,8 +54,9 @@ InferEngine::InferEngine(
const distributed::DistConfig &distributed_config, const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type, infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config, const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter bool enable_graph_compiling,
: communication_group_(distributed_config, device_type) { backends::AttentionBackend attention_backend) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
if (cache_config != nullptr) { if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy(); cache_config_ = cache_config->unique_copy();
} }
...@@ -69,7 +73,8 @@ InferEngine::InferEngine( ...@@ -69,7 +73,8 @@ InferEngine::InferEngine(
communication_group_.get_rank_info(r), communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr, cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(), barrier_.get(),
enable_graph_compiling)); enable_graph_compiling,
attention_backend_));
} }
// Compile the model on all workers // Compile the model on all workers
this->compile(); this->compile();
...@@ -117,6 +122,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { ...@@ -117,6 +122,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
to_device(past_sequence_lengths), // @todo: on device in the future to_device(past_sequence_lengths), // @todo: on device in the future
to_device(total_sequence_lengths), to_device(total_sequence_lengths),
to_device(input_offsets), to_device(input_offsets),
to_device(cu_seqlens),
to_device(block_tables), to_device(block_tables),
to_device(slot_mapping), to_device(slot_mapping),
}; };
...@@ -169,7 +175,7 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) { ...@@ -169,7 +175,7 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->wait(); worker->wait();
} }
cache_config_ = new_config->unique_copy();
this->compile(); this->compile();
} }
......
...@@ -37,14 +37,16 @@ public: ...@@ -37,14 +37,16 @@ public:
const distributed::DistConfig &distributed_config = distributed::DistConfig(), const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr, const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false); bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
InferEngine( InferEngine(
const std::string &model_path = "", const std::string &model_path = "",
const distributed::DistConfig &distributed_config = distributed::DistConfig(), const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr, const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false); bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
// Load a parameter to all workers (each can extract its shard inside RankWorker) // Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param); void load_param(const std::string &name, const infinicore::Tensor &param);
...@@ -73,6 +75,7 @@ protected: ...@@ -73,6 +75,7 @@ protected:
std::unique_ptr<cache::CacheConfig> cache_config_; std::unique_ptr<cache::CacheConfig> cache_config_;
const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config(); const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::config::ModelConfig> model_config_; std::shared_ptr<infinilm::config::ModelConfig> model_config_;
backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
}; };
} // namespace infinilm::engine } // namespace infinilm::engine
...@@ -26,9 +26,11 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, ...@@ -26,9 +26,11 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config, const cache::CacheConfig *cache_config,
RankBarrier *barrier, RankBarrier *barrier,
bool enable_graph_compiling) bool enable_graph_compiling,
backends::AttentionBackend attention_backend)
: legacy_model_config_(model_config), : legacy_model_config_(model_config),
rank_info_(rank_info), rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling), enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT), job_cmd_(Command::INIT),
has_job_(false), has_job_(false),
...@@ -53,9 +55,11 @@ RankWorker::RankWorker( ...@@ -53,9 +55,11 @@ RankWorker::RankWorker(
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config, const cache::CacheConfig *cache_config,
RankBarrier *barrier, RankBarrier *barrier,
bool enable_graph_compiling) bool enable_graph_compiling,
backends::AttentionBackend attention_backend)
: model_config_(model_config), : model_config_(model_config),
rank_info_(rank_info), rank_info_(rank_info),
attention_backend_(attention_backend),
enable_graph_compiling_(enable_graph_compiling), enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT), job_cmd_(Command::INIT),
has_job_(false), has_job_(false),
...@@ -234,10 +238,18 @@ void RankWorker::thread_loop() { ...@@ -234,10 +238,18 @@ void RankWorker::thread_loop() {
// Create model using factory (may be expensive) // Create model using factory (may be expensive)
if (model_config_ == nullptr) { if (model_config_ == nullptr) {
model_ = InfinilmModelFactory::createModel(legacy_model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr); model_ = InfinilmModelFactory::createModel(
legacy_model_config_,
rank_info_,
pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
attention_backend_);
} else { } else {
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr); model_ = InfinilmModelFactory::createModel(
model_config_,
rank_info_,
pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
attention_backend_);
} }
if (!model_) { if (!model_) {
...@@ -339,7 +351,7 @@ void RankWorker::thread_loop() { ...@@ -339,7 +351,7 @@ void RankWorker::thread_loop() {
const auto &batch_size{logits_shape[0]}; const auto &batch_size{logits_shape[0]};
auto n_req = local_args.input_offsets.value()->size(0) - 1; auto n_req = local_args.input_offsets.value()->size(0) - 1;
int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data(); int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data();
auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)}; auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
......
#pragma once #pragma once
#include "../backends/attention_backends.hpp"
#include "../cache/cache.hpp" #include "../cache/cache.hpp"
#include "../config/model_config.hpp" #include "../config/model_config.hpp"
#include "../models/model_factory.hpp" #include "../models/model_factory.hpp"
...@@ -37,8 +38,10 @@ public: ...@@ -37,8 +38,10 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths; std::optional<infinicore::Tensor> past_sequence_lengths;
/// ToTal Lengths for each request sequence, of shape `[num_requests]`. /// ToTal Lengths for each request sequence, of shape `[num_requests]`.
std::optional<infinicore::Tensor> total_sequence_lengths; std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`. /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets; std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables; std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache. /// Slot ids for each token `[seq]`. Used for paged cache.
...@@ -61,13 +64,15 @@ public: ...@@ -61,13 +64,15 @@ public:
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config, const cache::CacheConfig *cache_config,
RankBarrier *barrier, RankBarrier *barrier,
bool enable_graph_compiling); bool enable_graph_compiling,
backends::AttentionBackend attention_backend);
RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config, RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config, const cache::CacheConfig *cache_config,
RankBarrier *barrier, RankBarrier *barrier,
bool enable_graph_compiling); bool enable_graph_compiling,
backends::AttentionBackend attention_backend);
// Submit a parameter load job and wait until the load completes on the worker thread. // Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name, void load_param(const std::string &name,
...@@ -107,6 +112,9 @@ private: ...@@ -107,6 +112,9 @@ private:
std::shared_ptr<InfinilmModel> model_; std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::Cache> cache_; std::shared_ptr<cache::Cache> cache_;
// Backends
backends::AttentionBackend attention_backend_;
// Graph Compiling // Graph Compiling
bool enable_graph_compiling_; bool enable_graph_compiling_;
std::unique_ptr<GraphCompiler> compiler_; std::unique_ptr<GraphCompiler> compiler_;
......
...@@ -27,6 +27,8 @@ public: ...@@ -27,6 +27,8 @@ public:
std::optional<infinicore::Tensor> total_sequence_lengths; std::optional<infinicore::Tensor> total_sequence_lengths;
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets; std::optional<infinicore::Tensor> input_offsets;
/// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> cu_seqlens;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std::optional<infinicore::Tensor> block_tables; std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache. /// Slot ids for each token `[seq]`. Used for paged cache.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp" #include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp" #include "infinicore/ops/mul.hpp"
#include <algorithm> #include <algorithm>
...@@ -31,7 +32,8 @@ namespace infinilm::models::llama { ...@@ -31,7 +32,8 @@ namespace infinilm::models::llama {
LlamaAttention::LlamaAttention(const LlamaConfig &config, LlamaAttention::LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: layer_idx_(layer_idx), : layer_idx_(layer_idx),
hidden_size_(config.hidden_size), hidden_size_(config.hidden_size),
num_attention_heads_(config.num_attention_heads), num_attention_heads_(config.num_attention_heads),
...@@ -41,7 +43,9 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -41,7 +43,9 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
use_bias_(config.attention_bias), use_bias_(config.attention_bias),
use_output_bias_(config.attention_output_bias), use_output_bias_(config.attention_output_bias),
use_qk_norm_(config.qk_norm), use_qk_norm_(config.qk_norm),
max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) { max_position_embeddings_(config.max_position_embeddings),
rank_info_(rank_info),
attention_backend_(attention_backend) {
const auto &dtype{config.dtype}; const auto &dtype{config.dtype};
int tp_rank = rank_info.tp_rank; int tp_rank = rank_info.tp_rank;
...@@ -75,7 +79,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -75,7 +79,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: model_config_(model_config), : model_config_(model_config),
layer_idx_(layer_idx), layer_idx_(layer_idx),
hidden_size_(model_config->get<size_t>("hidden_size")), hidden_size_(model_config->get<size_t>("hidden_size")),
...@@ -86,7 +91,8 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo ...@@ -86,7 +91,8 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
use_bias_(model_config->get_or<bool>("attention_bias", true)), use_bias_(model_config->get_or<bool>("attention_bias", true)),
use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)), use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")), max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
rank_info_(rank_info) { rank_info_(rank_info),
attention_backend_(attention_backend) {
const auto &dtype{model_config_->get_dtype()}; const auto &dtype{model_config_->get_dtype()};
int tp_rank = rank_info.tp_rank; int tp_rank = rank_info.tp_rank;
...@@ -203,7 +209,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -203,7 +209,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
->contiguous() ->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else { } else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
...@@ -238,6 +244,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -238,6 +244,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache, std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
ASSERT(block_tables.has_value()); ASSERT(block_tables.has_value());
...@@ -298,6 +305,20 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -298,6 +305,20 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device()); infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
if (is_prefill) { if (is_prefill) {
if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
infinicore::op::mha_varlen_(
attn_output,
q_reshaped,
k_total->permute({0, 2, 1, 3}),
v_total->permute({0, 2, 1, 3}),
input_offsets.value(),
cu_seqlens.value(),
block_tables.value(),
max_position_embeddings_,
max_position_embeddings_,
std::nullopt,
scaling_);
} else {
infinicore::op::paged_attention_prefill_( infinicore::op::paged_attention_prefill_(
attn_output, attn_output,
q_reshaped, q_reshaped,
...@@ -308,7 +329,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -308,7 +329,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
input_offsets.value(), input_offsets.value(),
std::nullopt, std::nullopt,
scaling_); scaling_);
}
} else { } else {
infinicore::op::paged_attention_( infinicore::op::paged_attention_(
attn_output, attn_output,
...@@ -322,7 +343,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -322,7 +343,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
} }
// 7. Project output // 7. Project output
attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_}); attn_output
= attn_output->view({1, seq_len, num_attention_heads_ * head_dim_});
return o_proj_->forward(attn_output); return o_proj_->forward(attn_output);
} }
...@@ -332,6 +354,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -332,6 +354,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
if (!rotary_emb_) { if (!rotary_emb_) {
...@@ -340,7 +363,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat ...@@ -340,7 +363,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
infinicore::Tensor output; infinicore::Tensor output;
if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) { if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping); output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
} else { } else {
output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths); output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
......
#pragma once #pragma once
#include "../../backends/attention_backends.hpp"
#include "../../cache/kv_cache.hpp" #include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp" #include "../../config/model_config.hpp"
#include "../../engine/distributed/distributed.hpp" #include "../../engine/distributed/distributed.hpp"
...@@ -52,12 +53,14 @@ public: ...@@ -52,12 +53,14 @@ public:
LlamaAttention(const LlamaConfig &config, LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/** /**
* @brief Forward pass: compute attention * @brief Forward pass: compute attention
...@@ -73,6 +76,7 @@ public: ...@@ -73,6 +76,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
...@@ -104,6 +108,7 @@ private: ...@@ -104,6 +108,7 @@ private:
std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache, std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
...@@ -132,6 +137,8 @@ private: ...@@ -132,6 +137,8 @@ private:
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility) size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
float scaling_; float scaling_;
backends::AttentionBackend attention_backend_;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -19,7 +19,8 @@ namespace infinilm::models::llama { ...@@ -19,7 +19,8 @@ namespace infinilm::models::llama {
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx), rank_info_(rank_info) { engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) : layer_idx_(layer_idx), rank_info_(rank_info) {
const auto &dtype{config.dtype}; const auto &dtype{config.dtype};
// Initialize layer normalization layers // Initialize layer normalization layers
...@@ -29,14 +30,15 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, ...@@ -29,14 +30,15 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
dtype, device); dtype, device);
// Initialize attention and MLP modules // Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_); INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_, attention_backend);
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_); INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
} }
LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) { engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()}; const auto &dtype{model_config_->get_dtype()};
// Initialize layer normalization layers // Initialize layer normalization layers
INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"), INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
...@@ -45,7 +47,7 @@ LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConf ...@@ -45,7 +47,7 @@ LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConf
dtype, device); dtype, device);
// Initialize attention and MLP modules // Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_); INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_, attention_backend);
INFINICORE_NN_MODULE_INIT(mlp, model_config_, device, rank_info_); INFINICORE_NN_MODULE_INIT(mlp, model_config_, device, rank_info_);
} }
...@@ -57,13 +59,15 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states, ...@@ -57,13 +59,15 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Attention layer normalization // 1. Attention layer normalization
input_layernorm_->forward_inplace(hidden_states, residual); input_layernorm_->forward_inplace(hidden_states, residual);
// 2. Self-attention // 2. Self-attention
hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); hidden_states = self_attn_->forward(
hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 3. Post-attention layer normalization // 3. Post-attention layer normalization
post_attention_layernorm_->forward_inplace(hidden_states, residual); post_attention_layernorm_->forward_inplace(hidden_states, residual);
......
...@@ -48,12 +48,14 @@ public: ...@@ -48,12 +48,14 @@ public:
LlamaDecoderLayer(const LlamaConfig &config, LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/** /**
* @brief Forward pass: process one decoder layer * @brief Forward pass: process one decoder layer
...@@ -73,6 +75,7 @@ public: ...@@ -73,6 +75,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mappin) const; std::optional<infinicore::Tensor> slot_mappin) const;
......
...@@ -17,13 +17,14 @@ namespace infinilm::models::llama { ...@@ -17,13 +17,14 @@ namespace infinilm::models::llama {
*/ */
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) { engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) {
// Initialize module's device_ member // Initialize module's device_ member
device_ = device; device_ = device;
const auto &dtype{config.dtype}; const auto &dtype{config.dtype};
// Initialize base model // Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, rank_info); INFINICORE_NN_MODULE_INIT(model, config, device, rank_info, attention_backend);
// Initialize language modeling head // Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
...@@ -34,14 +35,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, ...@@ -34,14 +35,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) { engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend) {
// Initialize module's device_ member // Initialize module's device_ member
device_ = device; device_ = device;
const auto &dtype{model_config->get_dtype()}; const auto &dtype{model_config->get_dtype()};
// Initialize base model // Initialize base model
INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info); INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info, attention_backend);
// Initialize language modeling head // Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
// For now, we create a separate linear layer // For now, we create a separate linear layer
...@@ -56,12 +58,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { ...@@ -56,12 +58,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto past_sequence_lengths = input.past_sequence_lengths; auto past_sequence_lengths = input.past_sequence_lengths;
auto total_sequence_length = input.total_sequence_lengths; auto total_sequence_length = input.total_sequence_lengths;
auto input_offsets = input.input_offsets; auto input_offsets = input.input_offsets;
auto cu_seqlens = input.cu_seqlens;
auto block_tables = input.block_tables; auto block_tables = input.block_tables;
auto slot_mapping = input.slot_mapping; auto slot_mapping = input.slot_mapping;
// 1. Forward through base model to get hidden states // 1. Forward through base model to get hidden states
auto hidden_states = model_->forward( auto hidden_states = model_->forward(
input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping); input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, cu_seqlens, block_tables, slot_mapping);
// 2. Apply language modeling head to get logits // 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states); auto logits = lm_head_->forward(hidden_states);
......
...@@ -42,11 +42,13 @@ public: ...@@ -42,11 +42,13 @@ public:
*/ */
LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/** /**
* @brief Forward pass: compute language modeling logits * @brief Forward pass: compute language modeling logits
......
...@@ -20,7 +20,8 @@ namespace infinilm::models::llama { ...@@ -20,7 +20,8 @@ namespace infinilm::models::llama {
*/ */
LlamaModel::LlamaModel(const LlamaConfig &config, LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: config_(config), rank_info_(rank_info) { : config_(config), rank_info_(rank_info) {
const auto &dtype{config.dtype}; const auto &dtype{config.dtype};
// Initialize token embeddings // Initialize token embeddings
...@@ -34,7 +35,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -34,7 +35,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
layers_.reserve(config.num_hidden_layers); layers_.reserve(config.num_hidden_layers);
for (size_t i = 0; i < config.num_hidden_layers; ++i) { for (size_t i = 0; i < config.num_hidden_layers; ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>( layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), config, device, i, rank_info)); "layers." + std::to_string(i), config, device, i, rank_info, attention_backend));
} }
// Initialize final layer normalization // Initialize final layer normalization
...@@ -56,7 +57,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -56,7 +57,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info,
backends::AttentionBackend attention_backend)
: model_config_(model_config), rank_info_(rank_info) { : model_config_(model_config), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()}; const auto &dtype{model_config_->get_dtype()};
// Initialize token embeddings // Initialize token embeddings
...@@ -69,7 +71,7 @@ LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_conf ...@@ -69,7 +71,7 @@ LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_conf
layers_.reserve(model_config_->get<size_t>("num_hidden_layers")); layers_.reserve(model_config_->get<size_t>("num_hidden_layers"));
for (size_t i = 0; i < model_config_->get<size_t>("num_hidden_layers"); ++i) { for (size_t i = 0; i < model_config_->get<size_t>("num_hidden_layers"); ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>( layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), model_config_, device, i, rank_info)); "layers." + std::to_string(i), model_config_, device, i, rank_info, attention_backend));
} }
// Initialize final layer normalization // Initialize final layer normalization
INFINICORE_NN_MODULE_INIT(norm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"), INFINICORE_NN_MODULE_INIT(norm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
...@@ -92,6 +94,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -92,6 +94,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
...@@ -109,6 +112,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -109,6 +112,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
past_sequence_lengths, past_sequence_lengths,
total_sequence_lengths, total_sequence_lengths,
input_offsets, input_offsets,
cu_seqlens,
block_tables, block_tables,
slot_mapping); slot_mapping);
} }
......
...@@ -51,11 +51,13 @@ public: ...@@ -51,11 +51,13 @@ public:
*/ */
LlamaModel(const LlamaConfig &config, LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config, LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
/** /**
* @brief Forward pass: process input through the model * @brief Forward pass: process input through the model
...@@ -73,6 +75,7 @@ public: ...@@ -73,6 +75,7 @@ public:
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const; std::optional<infinicore::Tensor> slot_mapping) const;
......
...@@ -17,12 +17,13 @@ namespace infinilm { ...@@ -17,12 +17,13 @@ namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info, engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) { const cache::CacheConfig *cache,
backends::AttentionBackend attention_backend) {
std::shared_ptr<InfinilmModel> model; std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) { if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr; const auto &llama_config = *llama_config_ptr;
model = std::make_shared<models::llama::LlamaForCausalLM>( model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, rank_info); llama_config, rank_info.device, rank_info, attention_backend);
} else { } else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
} }
...@@ -37,12 +38,13 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( ...@@ -37,12 +38,13 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config, std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info, engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) { const cache::CacheConfig *cache,
backends::AttentionBackend attention_backend) {
std::shared_ptr<InfinilmModel> model; std::shared_ptr<InfinilmModel> model;
if (true) { if (true) {
model = std::make_shared<models::llama::LlamaForCausalLM>( model = std::make_shared<models::llama::LlamaForCausalLM>(
model_config, rank_info.device, rank_info); model_config, rank_info.device, rank_info, attention_backend);
} else { } else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../config/model_config.hpp" #include "../config/model_config.hpp"
#include "infinilm_model.hpp" #include "infinilm_model.hpp"
#include "../backends/attention_backends.hpp"
#include "../engine/distributed/distributed.hpp" #include "../engine/distributed/distributed.hpp"
namespace infinilm { namespace infinilm {
...@@ -23,11 +24,13 @@ public: ...@@ -23,11 +24,13 @@ public:
static std::shared_ptr<InfinilmModel> createModel( static std::shared_ptr<InfinilmModel> createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr); const cache::CacheConfig *cache = nullptr,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
static std::shared_ptr<InfinilmModel> createModel( static std::shared_ptr<InfinilmModel> createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config, std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr); const cache::CacheConfig *cache = nullptr,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
}; };
} // namespace infinilm } // namespace infinilm
...@@ -36,19 +36,22 @@ inline void bind_infer_engine(py::module &m) { ...@@ -36,19 +36,22 @@ inline void bind_infer_engine(py::module &m) {
const distributed::DistConfig &dist, const distributed::DistConfig &dist,
infinicore::Device::Type dev, infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg, std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) { bool enable_graph_compiling,
const std::string &attention_backend) {
return std::make_shared<InferEngine>( return std::make_shared<InferEngine>(
cfg, cfg,
dist, dist,
dev, dev,
cache_cfg ? cache_cfg.get() : nullptr, cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling); enable_graph_compiling,
infinilm::backends::parse_attention_backend(attention_backend));
}), }),
py::arg("config"), py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(), py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(), py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false) py::arg("enable_graph_compiling") = false,
py::arg("attention_backend") = "default")
.def("load_param", &InferEngine::load_param, .def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"), py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)") "Load a parameter tensor into all workers (each worker picks its shard)")
...@@ -63,11 +66,14 @@ inline void bind_infer_engine(py::module &m) { ...@@ -63,11 +66,14 @@ inline void bind_infer_engine(py::module &m) {
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") .def(
.def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("get_cache_config", [](const InferEngine &self) { .def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
auto cfg = self.get_cache_config(); auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); }) return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr;
})
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; }); .def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
infer_engine infer_engine
...@@ -76,19 +82,22 @@ inline void bind_infer_engine(py::module &m) { ...@@ -76,19 +82,22 @@ inline void bind_infer_engine(py::module &m) {
const distributed::DistConfig &dist, const distributed::DistConfig &dist,
infinicore::Device::Type dev, infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg, std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) { bool enable_graph_compiling,
const std::string &attention_backend) {
return std::make_shared<InferEngine>( return std::make_shared<InferEngine>(
model_path, model_path,
dist, dist,
dev, dev,
cache_cfg ? cache_cfg.get() : nullptr, cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling); enable_graph_compiling,
infinilm::backends::parse_attention_backend(attention_backend));
}), }),
py::arg("model_path") = "", py::arg("model_path") = "",
py::arg("distributed_config") = distributed::DistConfig(), py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(), py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false) py::arg("enable_graph_compiling") = false,
py::arg("attention_backend") = "default")
.def("load_param", &InferEngine::load_param, .def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"), py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)") "Load a parameter tensor into all workers (each worker picks its shard)")
...@@ -103,8 +112,10 @@ inline void bind_infer_engine(py::module &m) { ...@@ -103,8 +112,10 @@ inline void bind_infer_engine(py::module &m) {
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") .def(
.def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) { .def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config(); auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); }) return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
...@@ -118,6 +129,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -118,6 +129,7 @@ inline void bind_infer_engine(py::module &m) {
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
std::optional<infinicore::Tensor> total_sequence_lengths, std::optional<infinicore::Tensor> total_sequence_lengths,
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> cu_seqlens,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping, std::optional<infinicore::Tensor> slot_mapping,
py::kwargs kwargs) { py::kwargs kwargs) {
...@@ -127,6 +139,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -127,6 +139,7 @@ inline void bind_infer_engine(py::module &m) {
std::move(past_sequence_lengths), std::move(past_sequence_lengths),
std::move(total_sequence_lengths), std::move(total_sequence_lengths),
std::move(input_offsets), std::move(input_offsets),
std::move(cu_seqlens),
std::move(block_tables), std::move(block_tables),
std::move(slot_mapping), std::move(slot_mapping),
}; };
...@@ -167,6 +180,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -167,6 +180,7 @@ inline void bind_infer_engine(py::module &m) {
py::arg("past_sequence_lengths") = std::nullopt, py::arg("past_sequence_lengths") = std::nullopt,
py::arg("total_sequence_lengths") = std::nullopt, py::arg("total_sequence_lengths") = std::nullopt,
py::arg("input_offsets") = std::nullopt, py::arg("input_offsets") = std::nullopt,
py::arg("cu_seqlens") = std::nullopt,
py::arg("block_tables") = std::nullopt, py::arg("block_tables") = std::nullopt,
py::arg("slot_mapping") = std::nullopt) py::arg("slot_mapping") = std::nullopt)
.def_readwrite("input_ids", &InferEngine::Input::input_ids) .def_readwrite("input_ids", &InferEngine::Input::input_ids)
...@@ -174,6 +188,7 @@ inline void bind_infer_engine(py::module &m) { ...@@ -174,6 +188,7 @@ inline void bind_infer_engine(py::module &m) {
.def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths) .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
.def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths) .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
.def_readwrite("input_offsets", &InferEngine::Input::input_offsets) .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
.def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
.def_readwrite("block_tables", &InferEngine::Input::block_tables) .def_readwrite("block_tables", &InferEngine::Input::block_tables)
.def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping) .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
.def_readwrite("temperature", &InferEngine::Input::temperature) .def_readwrite("temperature", &InferEngine::Input::temperature)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment