Commit ae210024 authored by PanZezhong

issue/248 add arg for flash-attn backend

parent 7668db4f
#pragma once

#include <stdexcept>
#include <string>

namespace infinilm::backends {

enum class AttentionBackend {
    Default,
    FlashAttn,
};

inline AttentionBackend parse_attention_backend(const std::string &backend) {
    if (backend == "default") {
        return AttentionBackend::Default;
    }
    if (backend == "flash-attn") {
        return AttentionBackend::FlashAttn;
    }
    throw std::invalid_argument(
        "Invalid attention_backend: " + backend + ". Valid options are: default, flash-attn");
}

} // namespace infinilm::backends
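For orientation, here is a minimal usage sketch (not part of this commit) of how the new parser and enum are meant to be consumed when constructing an engine. The include paths, the namespace qualification, and the positional argument order of the InferEngine constructor are assumptions based on the headers changed below.

// Hypothetical call-site sketch; include paths and namespaces are assumptions.
#include "backends/attention_backends.hpp"
#include "engine/infer_engine.hpp"

#include <iostream>

int main() {
    using namespace infinilm;

    // Only "default" and "flash-attn" are accepted; anything else makes
    // parse_attention_backend throw std::invalid_argument.
    backends::AttentionBackend backend = backends::AttentionBackend::Default;
    try {
        backend = backends::parse_attention_backend("flash-attn");
    } catch (const std::invalid_argument &e) {
        std::cerr << e.what() << '\n';
        return 1;
    }

    // The parsed enum is threaded through InferEngine down to each
    // LlamaAttention layer (see the diffs below).
    engine::InferEngine engine(
        /*model_path=*/"/path/to/model",
        engine::distributed::DistConfig(),
        infinicore::context::getDevice().getType(),
        /*cache_config=*/nullptr,
        /*enable_graph_compiling=*/false,
        backend);
    return 0;
}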
@@ -23,9 +23,11 @@ InferEngine::InferEngine(
     const distributed::DistConfig &distributed_config,
     infinicore::Device::Type device_type,
     const cache::CacheConfig *cache_config,
-    bool enable_graph_compiling) // Changed parameter
+    bool enable_graph_compiling,
+    backends::AttentionBackend attention_backend) // Changed parameter
     : communication_group_(distributed_config, device_type),
-      legacy_model_config_(config) {
+      legacy_model_config_(config),
+      attention_backend_(attention_backend) {
     if (cache_config != nullptr) {
         cache_config_ = cache_config->unique_copy();
     }
@@ -39,7 +41,8 @@ InferEngine::InferEngine(
             communication_group_.get_rank_info(r),
             cache_config_ != nullptr ? cache_config_.get() : nullptr,
             barrier_.get(),
-            enable_graph_compiling));
+            enable_graph_compiling,
+            attention_backend_));
     }
     // Compile the model on all workers
@@ -51,8 +54,9 @@ InferEngine::InferEngine(
     const distributed::DistConfig &distributed_config,
     infinicore::Device::Type device_type,
     const cache::CacheConfig *cache_config,
-    bool enable_graph_compiling) // Changed parameter
-    : communication_group_(distributed_config, device_type) {
+    bool enable_graph_compiling,
+    backends::AttentionBackend attention_backend) // Changed parameter
+    : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
     if (cache_config != nullptr) {
         cache_config_ = cache_config->unique_copy();
     }
@@ -69,7 +73,8 @@ InferEngine::InferEngine(
             communication_group_.get_rank_info(r),
             cache_config_ != nullptr ? cache_config_.get() : nullptr,
             barrier_.get(),
-            enable_graph_compiling));
+            enable_graph_compiling,
+            attention_backend_));
     }
     // Compile the model on all workers
     this->compile();
...
@@ -37,14 +37,16 @@ public:
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
         const cache::CacheConfig *cache_config = nullptr,
-        bool enable_graph_compiling = false);
+        bool enable_graph_compiling = false,
+        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     InferEngine(
         const std::string &model_path = "",
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
         const cache::CacheConfig *cache_config = nullptr,
-        bool enable_graph_compiling = false);
+        bool enable_graph_compiling = false,
+        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     // Load a parameter to all workers (each can extract its shard inside RankWorker)
     void load_param(const std::string &name, const infinicore::Tensor &param);
@@ -73,6 +75,7 @@ protected:
     std::unique_ptr<cache::CacheConfig> cache_config_;
     const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
     std::shared_ptr<infinilm::config::ModelConfig> model_config_;
+    backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default;
 };

 } // namespace infinilm::engine
@@ -26,9 +26,11 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
                        const distributed::RankInfo &rank_info,
                        const cache::CacheConfig *cache_config,
                        RankBarrier *barrier,
-                       bool enable_graph_compiling)
+                       bool enable_graph_compiling,
+                       backends::AttentionBackend attention_backend)
     : legacy_model_config_(model_config),
       rank_info_(rank_info),
+      attention_backend_(attention_backend),
       enable_graph_compiling_(enable_graph_compiling),
       job_cmd_(Command::INIT),
       has_job_(false),
@@ -53,9 +55,11 @@ RankWorker::RankWorker(
     const distributed::RankInfo &rank_info,
     const cache::CacheConfig *cache_config,
     RankBarrier *barrier,
-    bool enable_graph_compiling)
+    bool enable_graph_compiling,
+    backends::AttentionBackend attention_backend)
     : model_config_(model_config),
       rank_info_(rank_info),
+      attention_backend_(attention_backend),
       enable_graph_compiling_(enable_graph_compiling),
       job_cmd_(Command::INIT),
       has_job_(false),
@@ -234,10 +238,18 @@ void RankWorker::thread_loop() {
             // Create model using factory (may be expensive)
             if (model_config_ == nullptr) {
-                model_ = InfinilmModelFactory::createModel(legacy_model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
+                model_ = InfinilmModelFactory::createModel(
+                    legacy_model_config_,
+                    rank_info_,
+                    pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
+                    attention_backend_);
             } else {
-                model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
+                model_ = InfinilmModelFactory::createModel(
+                    model_config_,
+                    rank_info_,
+                    pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
+                    attention_backend_);
             }
             if (!model_) {
...
 #pragma once

+#include "../backends/attention_backends.hpp"
 #include "../cache/cache.hpp"
 #include "../config/model_config.hpp"
 #include "../models/model_factory.hpp"
@@ -63,13 +64,15 @@ public:
                const distributed::RankInfo &rank_info,
                const cache::CacheConfig *cache_config,
                RankBarrier *barrier,
-               bool enable_graph_compiling);
+               bool enable_graph_compiling,
+               backends::AttentionBackend attention_backend);

     RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                const distributed::RankInfo &rank_info,
                const cache::CacheConfig *cache_config,
                RankBarrier *barrier,
-               bool enable_graph_compiling);
+               bool enable_graph_compiling,
+               backends::AttentionBackend attention_backend);

     // Submit a parameter load job and wait until the load completes on the worker thread.
     void load_param(const std::string &name,
@@ -109,6 +112,9 @@ private:
     std::shared_ptr<InfinilmModel> model_;
     std::shared_ptr<cache::Cache> cache_;

+    // Backends
+    backends::AttentionBackend attention_backend_;
+
     // Graph Compiling
     bool enable_graph_compiling_;
     std::unique_ptr<GraphCompiler> compiler_;
...
@@ -32,7 +32,8 @@ namespace infinilm::models::llama {
 LlamaAttention::LlamaAttention(const LlamaConfig &config,
                                const infinicore::Device &device,
                                size_t layer_idx,
-                               engine::distributed::RankInfo rank_info)
+                               engine::distributed::RankInfo rank_info,
+                               backends::AttentionBackend attention_backend)
     : layer_idx_(layer_idx),
       hidden_size_(config.hidden_size),
       num_attention_heads_(config.num_attention_heads),
@@ -42,7 +43,9 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
       use_bias_(config.attention_bias),
       use_output_bias_(config.attention_output_bias),
       use_qk_norm_(config.qk_norm),
-      max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
+      max_position_embeddings_(config.max_position_embeddings),
+      rank_info_(rank_info),
+      attention_backend_(attention_backend) {
     const auto &dtype{config.dtype};
     int tp_rank = rank_info.tp_rank;
@@ -76,7 +79,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
 LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                const infinicore::Device &device,
                                size_t layer_idx,
-                               engine::distributed::RankInfo rank_info)
+                               engine::distributed::RankInfo rank_info,
+                               backends::AttentionBackend attention_backend)
     : model_config_(model_config),
       layer_idx_(layer_idx),
       hidden_size_(model_config->get<size_t>("hidden_size")),
@@ -87,7 +91,8 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
       use_bias_(model_config->get_or<bool>("attention_bias", true)),
       use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
       max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
-      rank_info_(rank_info) {
+      rank_info_(rank_info),
+      attention_backend_(attention_backend) {
     const auto &dtype{model_config_->get_dtype()};
     int tp_rank = rank_info.tp_rank;
@@ -299,42 +304,44 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
     // 6. Compute attention
     infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());

-    // if (is_prefill) {
-    //     infinicore::op::paged_attention_prefill_(
-    //         attn_output,
-    //         q_reshaped,
-    //         k_total,
-    //         v_total,
-    //         block_tables.value(),
-    //         total_sequence_lengths.value(),
-    //         input_offsets.value(),
-    //         std::nullopt,
-    //         scaling_);
-
-    // } else {
-    //     infinicore::op::paged_attention_(
-    //         attn_output,
-    //         q_reshaped,
-    //         k_total,
-    //         v_total,
-    //         block_tables.value(),
-    //         total_sequence_lengths.value(),
-    //         std::nullopt,
-    //         scaling_);
-    // }
-
-    infinicore::op::mha_varlen_(
-        attn_output,
-        q_reshaped,
-        k_total->permute({0, 2, 1, 3}),
-        v_total->permute({0, 2, 1, 3}),
-        input_offsets.value(),
-        cu_seqlens.value(),
-        block_tables.value(),
-        max_position_embeddings_,
-        max_position_embeddings_,
-        std::nullopt,
-        scaling_);
+    if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
+        infinicore::op::mha_varlen_(
+            attn_output,
+            q_reshaped,
+            k_total->permute({0, 2, 1, 3}),
+            v_total->permute({0, 2, 1, 3}),
+            input_offsets.value(),
+            cu_seqlens.value(),
+            block_tables.value(),
+            max_position_embeddings_,
+            max_position_embeddings_,
+            std::nullopt,
+            scaling_);
+    } else {
+        if (is_prefill) {
+            infinicore::op::paged_attention_prefill_(
+                attn_output,
+                q_reshaped,
+                k_total,
+                v_total,
+                block_tables.value(),
+                total_sequence_lengths.value(),
+                input_offsets.value(),
+                std::nullopt,
+                scaling_);
+        } else {
+            infinicore::op::paged_attention_(
+                attn_output,
+                q_reshaped,
+                k_total,
+                v_total,
+                block_tables.value(),
+                total_sequence_lengths.value(),
+                std::nullopt,
+                scaling_);
+        }
+    }

     // 7. Project output
     attn_output
...
 #pragma once

+#include "../../backends/attention_backends.hpp"
 #include "../../cache/kv_cache.hpp"
 #include "../../config/model_config.hpp"
 #include "../../engine/distributed/distributed.hpp"
@@ -52,12 +53,14 @@ public:
     LlamaAttention(const LlamaConfig &config,
                    const infinicore::Device &device,
                    size_t layer_idx,
-                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+                   backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                    const infinicore::Device &device,
                    size_t layer_idx,
-                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+                   backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     /**
      * @brief Forward pass: compute attention
@@ -134,6 +137,8 @@ private:
     size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
     float scaling_;
+
+    backends::AttentionBackend attention_backend_;
 };

 } // namespace infinilm::models::llama
@@ -19,7 +19,8 @@ namespace infinilm::models::llama {
 LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
                                      const infinicore::Device &device,
                                      size_t layer_idx,
-                                     engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx), rank_info_(rank_info) {
+                                     engine::distributed::RankInfo rank_info,
+                                     backends::AttentionBackend attention_backend) : layer_idx_(layer_idx), rank_info_(rank_info) {
     const auto &dtype{config.dtype};

     // Initialize layer normalization layers
@@ -29,14 +30,15 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
                               dtype, device);

     // Initialize attention and MLP modules
-    INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_);
+    INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_, attention_backend);
     INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
 }

 LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                      const infinicore::Device &device,
                                      size_t layer_idx,
-                                     engine::distributed::RankInfo rank_info) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
+                                     engine::distributed::RankInfo rank_info,
+                                     backends::AttentionBackend attention_backend) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
     const auto &dtype{model_config_->get_dtype()};

     // Initialize layer normalization layers
     INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
@@ -45,7 +47,7 @@ LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConf
                               dtype, device);

     // Initialize attention and MLP modules
-    INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_);
+    INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_, attention_backend);
     INFINICORE_NN_MODULE_INIT(mlp, model_config_, device, rank_info_);
 }
...
@@ -48,12 +48,14 @@ public:
     LlamaDecoderLayer(const LlamaConfig &config,
                       const infinicore::Device &device,
                       size_t layer_idx,
-                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+                      backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                       const infinicore::Device &device,
                       size_t layer_idx,
-                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+                      backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     /**
      * @brief Forward pass: process one decoder layer
...
@@ -17,13 +17,14 @@ namespace infinilm::models::llama {
  */
 LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
                                    const infinicore::Device &device,
-                                   engine::distributed::RankInfo rank_info) {
+                                   engine::distributed::RankInfo rank_info,
+                                   backends::AttentionBackend attention_backend) {
     // Initialize module's device_ member
     device_ = device;
     const auto &dtype{config.dtype};

     // Initialize base model
-    INFINICORE_NN_MODULE_INIT(model, config, device, rank_info);
+    INFINICORE_NN_MODULE_INIT(model, config, device, rank_info, attention_backend);

     // Initialize language modeling head
     // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
@@ -34,14 +35,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
 LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                    const infinicore::Device &device,
-                                   engine::distributed::RankInfo rank_info) {
+                                   engine::distributed::RankInfo rank_info,
+                                   backends::AttentionBackend attention_backend) {
     // Initialize module's device_ member
     device_ = device;
     const auto &dtype{model_config->get_dtype()};

     // Initialize base model
-    INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info);
+    INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info, attention_backend);

     // Initialize language modeling head
     // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
     // For now, we create a separate linear layer
...
@@ -42,11 +42,13 @@ public:
      */
     LlamaForCausalLM(const LlamaConfig &config,
                      const infinicore::Device &device,
-                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+                     backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                      const infinicore::Device &device,
-                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+                     backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     /**
      * @brief Forward pass: compute language modeling logits
...
@@ -20,7 +20,8 @@ namespace infinilm::models::llama {
  */
 LlamaModel::LlamaModel(const LlamaConfig &config,
                        const infinicore::Device &device,
-                       engine::distributed::RankInfo rank_info)
+                       engine::distributed::RankInfo rank_info,
+                       backends::AttentionBackend attention_backend)
     : config_(config), rank_info_(rank_info) {
     const auto &dtype{config.dtype};

     // Initialize token embeddings
@@ -34,7 +35,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
     layers_.reserve(config.num_hidden_layers);
     for (size_t i = 0; i < config.num_hidden_layers; ++i) {
         layers_.push_back(this->register_module<LlamaDecoderLayer>(
-            "layers." + std::to_string(i), config, device, i, rank_info));
+            "layers." + std::to_string(i), config, device, i, rank_info, attention_backend));
     }

     // Initialize final layer normalization
@@ -56,7 +57,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
 LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                        const infinicore::Device &device,
-                       engine::distributed::RankInfo rank_info)
+                       engine::distributed::RankInfo rank_info,
+                       backends::AttentionBackend attention_backend)
     : model_config_(model_config), rank_info_(rank_info) {
     const auto &dtype{model_config_->get_dtype()};

     // Initialize token embeddings
@@ -69,7 +71,7 @@ LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_conf
     layers_.reserve(model_config_->get<size_t>("num_hidden_layers"));
     for (size_t i = 0; i < model_config_->get<size_t>("num_hidden_layers"); ++i) {
         layers_.push_back(this->register_module<LlamaDecoderLayer>(
-            "layers." + std::to_string(i), model_config_, device, i, rank_info));
+            "layers." + std::to_string(i), model_config_, device, i, rank_info, attention_backend));
     }

     // Initialize final layer normalization
     INFINICORE_NN_MODULE_INIT(norm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
...
@@ -51,11 +51,13 @@ public:
      */
     LlamaModel(const LlamaConfig &config,
                const infinicore::Device &device,
-               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+               backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                const infinicore::Device &device,
-               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
+               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+               backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     /**
      * @brief Forward pass: process input through the model
...
@@ -17,12 +17,13 @@ namespace infinilm {
 std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
     const InfinilmModel::Config &config,
     engine::distributed::RankInfo rank_info,
-    const cache::CacheConfig *cache) {
+    const cache::CacheConfig *cache,
+    backends::AttentionBackend attention_backend) {
     std::shared_ptr<InfinilmModel> model;
     if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
         const auto &llama_config = *llama_config_ptr;
         model = std::make_shared<models::llama::LlamaForCausalLM>(
-            llama_config, rank_info.device, rank_info);
+            llama_config, rank_info.device, rank_info, attention_backend);
     } else {
         throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
     }
@@ -37,12 +38,13 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
 std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
     std::shared_ptr<infinilm::config::ModelConfig> model_config,
     engine::distributed::RankInfo rank_info,
-    const cache::CacheConfig *cache) {
+    const cache::CacheConfig *cache,
+    backends::AttentionBackend attention_backend) {
     std::shared_ptr<InfinilmModel> model;
     if (true) {
         model = std::make_shared<models::llama::LlamaForCausalLM>(
-            model_config, rank_info.device, rank_info);
+            model_config, rank_info.device, rank_info, attention_backend);
     } else {
         throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
     }
...
@@ -3,6 +3,7 @@
 #include "../config/model_config.hpp"
 #include "infinilm_model.hpp"

+#include "../backends/attention_backends.hpp"
 #include "../engine/distributed/distributed.hpp"

 namespace infinilm {
@@ -23,11 +24,13 @@ public:
     static std::shared_ptr<InfinilmModel> createModel(
         const InfinilmModel::Config &config,
         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
-        const cache::CacheConfig *cache = nullptr);
+        const cache::CacheConfig *cache = nullptr,
+        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);

     static std::shared_ptr<InfinilmModel> createModel(
         std::shared_ptr<infinilm::config::ModelConfig> model_config,
         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
-        const cache::CacheConfig *cache = nullptr);
+        const cache::CacheConfig *cache = nullptr,
+        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
 };

 } // namespace infinilm
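As a quick illustration of the factory-level API above, here is a hypothetical helper (not part of the commit; the include path and helper name are assumptions) that builds a model with the flash-attn backend while leaving rank info at its default and passing no cache config.

// Hypothetical helper sketch; include path and helper name are assumptions.
#include "models/model_factory.hpp"

#include <memory>

namespace infinilm {

// Build a model through the factory with the flash-attn attention backend;
// rank info stays at its default and no cache config is supplied.
inline std::shared_ptr<InfinilmModel> create_flash_attn_model(
    const InfinilmModel::Config &config) {
    return InfinilmModelFactory::createModel(
        config,
        engine::distributed::RankInfo(),
        /*cache=*/nullptr,
        backends::AttentionBackend::FlashAttn);
}

} // namespace infinilm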
@@ -36,19 +36,22 @@ inline void bind_infer_engine(py::module &m) {
                      const distributed::DistConfig &dist,
                      infinicore::Device::Type dev,
                      std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
-                     bool enable_graph_compiling) {
+                     bool enable_graph_compiling,
+                     const std::string &attention_backend) {
                      return std::make_shared<InferEngine>(
                          cfg,
                          dist,
                          dev,
                          cache_cfg ? cache_cfg.get() : nullptr,
-                         enable_graph_compiling);
+                         enable_graph_compiling,
+                         infinilm::backends::parse_attention_backend(attention_backend));
                  }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
             py::arg("cache_config") = py::none(),
-            py::arg("enable_graph_compiling") = false)
+            py::arg("enable_graph_compiling") = false,
+            py::arg("attention_backend") = "default")
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
@@ -63,8 +66,10 @@ inline void bind_infer_engine(py::module &m) {
                }
                return state_dict_tp_all;
            })
-        .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
-        .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
+        .def(
+            "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
+        .def(
+            "reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
        .def("get_cache_config", [](const InferEngine &self) {
            auto cfg = self.get_cache_config();
            return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
@@ -76,19 +81,22 @@ inline void bind_infer_engine(py::module &m) {
                      const distributed::DistConfig &dist,
                      infinicore::Device::Type dev,
                      std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
-                     bool enable_graph_compiling) {
+                     bool enable_graph_compiling,
+                     const std::string &attention_backend) {
                      return std::make_shared<InferEngine>(
                          model_path,
                          dist,
                          dev,
                          cache_cfg ? cache_cfg.get() : nullptr,
-                         enable_graph_compiling);
+                         enable_graph_compiling,
+                         infinilm::backends::parse_attention_backend(attention_backend));
                  }),
             py::arg("model_path") = "",
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
             py::arg("cache_config") = py::none(),
-            py::arg("enable_graph_compiling") = false)
+            py::arg("enable_graph_compiling") = false,
+            py::arg("attention_backend") = "default")
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
@@ -103,8 +111,10 @@ inline void bind_infer_engine(py::module &m) {
                }
                return state_dict_tp_all;
            })
-        .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
-        .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
+        .def(
+            "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
+        .def(
+            "reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
        .def("get_cache_config", [](const InferEngine &self) {
            auto cfg = self.get_cache_config();
            return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
...
@@ -29,6 +29,7 @@ class InferEngine(_infinilm.InferEngine):
         distributed_config=DistConfig(1),
         cache_config=None,
         enable_graph_compiling=False,
+        attention_backend="flash-attn",
     ):
         self.config = AutoConfig.from_pretrained(model_path)
@@ -41,6 +42,7 @@ class InferEngine(_infinilm.InferEngine):
             device._underlying.type,
             cache_config,
             enable_graph_compiling,
+            attention_backend,
         )
         self.use_cache = False
@@ -197,7 +199,8 @@ class InferEngine(_infinilm.InferEngine):
             [past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
         )
         cu_seqlens = infinicore.from_list(
-            [(past_seq_len + seq_len) * i for i in range(batch_size + 1)], dtype=infinicore.int32
+            [(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
+            dtype=infinicore.int32,
         )
         input_offsets = infinicore.from_list(
             [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
@@ -209,7 +212,7 @@ class InferEngine(_infinilm.InferEngine):
             past_kv_lengths=past_kv_lengths,
             total_kv_lengths=total_kv_lengths,
             input_offsets=input_offsets,
-            cu_seqlens = cu_seqlens,
+            cu_seqlens=cu_seqlens,
             block_tables=block_tables,
             slot_mapping=slot_mapping,
             temperature=generation_config.temperature,
...