Unverified Commit a4ced800 authored by thatPepe's avatar thatPepe Committed by GitHub

Merge pull request #205 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 96ecf490 04c37f3f
@@ -29,3 +29,5 @@ __pycache__/
*.txt
*.http
*.nsys-rep
[submodule "third_party/spdlog"]
path = third_party/spdlog
url = https://github.com/gabime/spdlog.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
@@ -71,7 +71,7 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
- Single inference test
- llama example
```bash
python examples/llama.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --ali] --model_path=<path/to/model_dir>
```
- For example:
```bash
...
@@ -85,26 +85,36 @@ StaticKVCache::update(size_t layer_idx,
auto batch_size = k->size(0);
auto update_len = k->size(2);
ASSERT_EQ(batch_size, rank_batch_size_);
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto device = k_cache_layer->device();
if (device.getType() == infinicore::Device::Type::NVIDIA
|| device.getType() == infinicore::Device::Type::ILUVATAR
|| device.getType() == infinicore::Device::Type::METAX) {
infinicore::op::kv_caching_(
k_cache_layer,
v_cache_layer,
k,
v,
past_sequence_lengths);
} else {
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
k_cache_update->copy_from(k);
v_cache_update->copy_from(v);
}
return {k_cache_layer, v_cache_layer};
}
// ==========================
...
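For context: on NVIDIA, Iluvatar, and Metax the hunk above replaces the host-side read of `cache_pos` (a device-to-host sync on `past_sequence_lengths`) with the fused on-device `infinicore::op::kv_caching_` call, which keeps the cache update capturable in a CUDA graph. Below is a host-side reference sketch of the append semantics the op is assumed to implement; shapes, per-request offsets, and the use of `float` are inferred from the surrounding code, not taken from the op's actual kernel:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side reference for the fused kv_caching_ op (sketch only).
// Assumed layouts: cache [batch, heads, cache_len, head_dim],
// update [batch, heads, update_len, head_dim], past_sequence_lengths [batch].
void kv_caching_reference(float *cache, const float *update,
                          const int64_t *past_sequence_lengths,
                          size_t batch, size_t heads,
                          size_t cache_len, size_t update_len, size_t head_dim) {
    for (size_t b = 0; b < batch; ++b) {
        // Presumably a per-request write offset, read on device (no D2H copy).
        size_t pos = static_cast<size_t>(past_sequence_lengths[b]);
        for (size_t h = 0; h < heads; ++h) {
            for (size_t t = 0; t < update_len; ++t) {
                const float *src = update + ((b * heads + h) * update_len + t) * head_dim;
                float *dst = cache + ((b * heads + h) * cache_len + (pos + t)) * head_dim;
                std::copy(src, src + head_dim, dst);
            }
        }
    }
}
```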
#include "model_config.hpp"
namespace infinilm::config {
ModelConfig::ModelConfig(const std::string &path) {
std::ifstream file(path);
if (file.is_open()) {
file >> config_json;
file.close();
} else {
throw std::runtime_error("Could not open config file: " + path);
}
this->quant_config = QuantConfig(config_json["quantization_config"]);
}
infinicore::quantization::QuantScheme
ModelConfig::get_quant_scheme() const {
// QuantConfig already returns QuantScheme::NONE when no method is configured.
return quant_config.get_quant_scheme();
}
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::get_rope_scaling() const {
if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) {
return nullptr;
}
const auto &rope_scaling = config_json["rope_scaling"];
if (!rope_scaling.is_object()) {
throw std::runtime_error("rope_scaling must be an object");
}
if (!rope_scaling.contains("type")) {
throw std::runtime_error("rope_scaling must contain 'type' field");
}
std::string type_str = rope_scaling["type"].get<std::string>();
if (type_str == "longrope") {
// Required fields for LongRopeConfig
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
throw std::runtime_error(
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
}
auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();
float factor = 1.0f;
if (rope_scaling.contains("factor")) {
factor = rope_scaling["factor"].get<float>();
}
return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
std::move(short_factor),
std::move(long_factor),
original_max_position_embeddings,
factor);
} else if (type_str == "default" || type_str == "none") {
// Default scaling, no scaling applied
return nullptr;
} else {
throw std::runtime_error("Unsupported rope_scaling type: " + type_str);
}
}
infinicore::DataType
ModelConfig::get_dtype() const {
try {
std::string dtype_str = this->get<std::string>("torch_dtype");
if (dtype_str == "float32") {
return infinicore::DataType::F32;
} else if (dtype_str == "float16") {
return infinicore::DataType::F16;
} else if (dtype_str == "bfloat16") {
return infinicore::DataType::BF16;
} else if (dtype_str == "int8") {
return infinicore::DataType::I8;
} else {
throw std::runtime_error("Unsupported dtype string: " + dtype_str);
}
} catch (const std::exception &e) {
throw std::runtime_error("Error getting dtype from config: " + std::string(e.what()));
}
}
} // namespace infinilm::config
#pragma once
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "quant_config.hpp"
#include <fstream>
#include <memory>
#include <string>
namespace infinilm::config {
class ModelConfig {
// Model config is implemented using nlohmann/json and is primarily used for advanced configuration
// beyond the standard model config. It is initialized via ModelConfig(const std::string& path)
// and passed through the InferEngine during inference.
public:
ModelConfig() = default;
// Not Implemented
// ModelConfig(const nlohmann::json &json) : config_json(json) {};
ModelConfig(const std::string &path);
// Template Function to get a value by key with type safety
template <typename T>
T get(const std::string &key) const {
if (!config_json.contains(key)) {
throw std::out_of_range("Key '" + key + "' not found in config.");
}
try {
return config_json.at(key).get<T>();
} catch (const nlohmann::json::type_error &e) {
throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
}
}
template <typename T>
T get_or(const std::string &key, const T &default_value) const {
if (!config_json.contains(key) || config_json.at(key).is_null()) {
return default_value;
}
try {
return config_json.at(key).get<T>();
} catch (const nlohmann::json::type_error &) {
// If type conversion fails, return default value
return default_value;
}
}
size_t get_kv_dim() const {
return get<size_t>("hidden_size") * get<size_t>("num_key_value_heads") / get<size_t>("num_attention_heads");
}
size_t get_head_dim() const {
if (config_json.contains("head_dim")) {
return get<size_t>("head_dim");
}
return get<size_t>("hidden_size") / get<size_t>("num_attention_heads");
}
QuantConfig get_quant_config() const {
return quant_config;
}
std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization_method() const {
return quant_config.get_quantization_method();
}
infinicore::DataType get_dtype() const;
infinicore::quantization::QuantScheme get_quant_scheme() const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
private:
nlohmann::json config_json;
QuantConfig quant_config;
};
} // namespace infinilm::config
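A minimal usage sketch of this class; the path and config keys below are illustrative, not taken from this PR:

```cpp
#include "model_config.hpp"

int main() {
    infinilm::config::ModelConfig cfg("/path/to/model_dir/config.json");

    // get<T> throws on a missing key or failed conversion; get_or falls back.
    auto hidden = cfg.get<size_t>("hidden_size");
    auto tied = cfg.get_or<bool>("tie_word_embeddings", false);

    auto head_dim = cfg.get_head_dim(); // "head_dim" if present, else hidden_size / num_attention_heads
    auto dtype = cfg.get_dtype();       // mapped from the "torch_dtype" string
    (void)hidden; (void)tied; (void)head_dim; (void)dtype;
}
```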
#include "quant_config.hpp"
namespace infinilm::config {
QuantConfig::QuantConfig(const nlohmann::json &json) : quantization_config(json) {
this->quantization_method = get_quantization_method();
}
std::shared_ptr<infinicore::quantization::BaseQuantization>
QuantConfig::get_quantization_method() const {
if (quantization_config.is_null()) {
// Default case when no quantization config is present
return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config);
}
// Determine the quantization scheme from the JSON config; add other schemes as needed
if (quantization_config["quant_method"] == "compressed-tensors") {
return std::make_shared<infinicore::quantization::CompressedTensors>(quantization_config);
} else if (quantization_config["quant_method"] == "awq") {
return std::make_shared<infinicore::quantization::AWQ>(quantization_config);
} else {
return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config);
}
}
} // namespace infinilm::config
#pragma once
#include "infinicore/quantization.hpp"
#include "nlohmann/json.hpp"
#include <memory>
namespace infinilm::config {
class QuantConfig {
// QuantConfig is used to store and parse the "quantization" field from config.json.
// This is currently a basic version and will be extended in the future.
public:
QuantConfig() = default;
QuantConfig(const nlohmann::json &json);
std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization_method() const;
infinicore::quantization::QuantScheme get_quant_scheme() const {
if (quantization_method != nullptr) {
return quantization_method->get_quant_scheme();
} else {
return infinicore::quantization::QuantScheme::NONE;
}
}
private:
nlohmann::json quantization_config;
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_method;
};
} // namespace infinilm::config
#include "general_compiler.hpp"
namespace infinilm::engine {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
}
void GeneralCompiler::compile() {
static_batching_compiler_->compile();
paged_compiler_->compile();
}
GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
// Try each compiler and return the first valid result.
auto result = static_batching_compiler_->get_compiled(input);
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
return result;
}
return paged_compiler_->get_compiled(input);
}
} // namespace infinilm::engine
#pragma once
#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"
namespace infinilm::engine {
class GeneralCompiler : public GraphCompiler {
public:
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
void compile() override;
Compiled get_compiled(const InfinilmModel::Input &input) override;
private:
std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
std::unique_ptr<PagedCompiler> paged_compiler_;
};
} // namespace infinilm::engine
#pragma once
#include "../../models/infinilm_model.hpp"
#include "../rank_barrier.hpp"
namespace infinilm::engine {
class GraphCompiler {
public:
using Compiled = std::tuple<
std::shared_ptr<infinicore::graph::Graph>,
std::shared_ptr<InfinilmModel::Output>>;
explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : model_(model), barrier_(barrier) {}
virtual ~GraphCompiler() = default;
virtual void compile() = 0;
virtual Compiled get_compiled(const InfinilmModel::Input &input) = 0;
protected:
std::shared_ptr<InfinilmModel> model_;
RankBarrier *barrier_;
};
} // namespace infinilm::engine
#include "paged_compiler.hpp"
namespace {
// TODO: replace with Tensor::zeros when it is available
inline void set_zeros(infinicore::Tensor &tensor) {
std::vector<uint8_t> zeros(tensor->nbytes(), 0);
infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false);
}
} // namespace
namespace infinilm::engine {
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
for (size_t b = 1; b < 32; b++) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 32; b < 64; b += 8) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 64; b < 128; b += 16) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 128; b < 256; b += 32) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 256; b <= 512; b += 64) {
decode_batch_sizes_.push_back(b);
}
}
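The ladder above precompiles every decode batch size up to 31, then progressively coarser strides up to 512. `get_compiled` below does an exact lookup in `compiled_map_decode_`, so a caller wanting graph replay for an arbitrary live batch would presumably pad it up to the next compiled size; a hypothetical helper (not part of this PR) might look like:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical: pick the smallest precompiled bucket that can hold `batch`.
// Assumes `buckets` is ascending, as generated by the constructor above.
// Returns 0 when the batch exceeds the largest compiled size (eager fallback).
size_t next_compiled_batch(const std::vector<size_t> &buckets, size_t batch) {
    auto it = std::lower_bound(buckets.begin(), buckets.end(), batch);
    return it == buckets.end() ? 0 : *it;
}
```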
void PagedCompiler::compile() {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
size_t nblocks = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())->num_blocks();
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b;
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_offsets.value());
std::vector<int64_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i;
}
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value());
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
}
PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
size_t batch_size = input.block_tables.value()->size(0);
size_t block_per_req = input.block_tables.value()->size(1);
// only support decode only batch
if (batch_size != input.input_ids.value()->size(1)) {
return {nullptr, nullptr};
} else {
auto result = compiled_map_decode_.find(batch_size);
if (result == compiled_map_decode_.end()) {
return {nullptr, nullptr};
}
auto &graph_input = result->second.input;
graph_input.input_ids.value()->copy_from(input.input_ids.value());
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
auto graph = std::get<0>(result->second.compiled);
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
return std::make_tuple(graph, shared_output);
}
} else {
return {nullptr, nullptr};
}
}
} // namespace infinilm::engine
#pragma once
#include "graph_compiler.hpp"
#include <unordered_map>
namespace infinilm::engine {
class PagedCompiler : public GraphCompiler {
public:
PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
void compile() override;
Compiled get_compiled(const InfinilmModel::Input &input) override;
private:
std::vector<size_t> decode_batch_sizes_;
infinicore::Tensor block_tables_holder_;
struct CompiledResult {
InfinilmModel::Input input;
Compiled compiled;
};
std::unordered_map<
size_t, // num_requests
CompiledResult>
compiled_map_decode_;
};
} // namespace infinilm::engine
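The placeholder tensors stored in each `CompiledResult` are the crux of the replay contract: a recorded graph references whatever buffers were live at capture time, so `get_compiled` copies fresh inputs into those exact buffers before calling `run()`. A self-contained miniature of the pattern, with stand-in types only (not infinicore APIs):

```cpp
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    std::vector<int> placeholder(4, 0); // buffer "captured" by the graph
    std::function<void()> graph = [&placeholder] {
        int sum = 0;
        for (int v : placeholder) sum += v; // "kernels" read the fixed buffer
        std::printf("sum = %d\n", sum);
    };
    std::vector<int> live_input = {1, 2, 3, 4}; // fresh request data
    // Copy into the captured buffer rather than reassigning it: a real graph
    // replays against fixed device addresses, hence copy_from in the code above.
    std::copy(live_input.begin(), live_input.end(), placeholder.begin());
    graph(); // replay: prints sum = 10
}
```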
#include "static_batching_compiler.hpp"
#include "../../cache/cache.hpp"
namespace infinilm::engine {
StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
}
void StaticBatchingCompiler::compile() {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
size_t b = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())->max_batch_size();
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(
const InfinilmModel::Input &input) {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
size_t batch_size = input.input_ids.value()->size(0);
size_t seqlen = input.input_ids.value()->size(1);
auto result = compiled_map_.find(std::make_tuple(batch_size, seqlen));
if (result == compiled_map_.end()) {
return std::make_tuple(nullptr, nullptr);
} else {
auto &graph_input = result->second.input;
graph_input.input_ids.value()->copy_from(input.input_ids.value());
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
auto graph = std::get<0>(result->second.compiled);
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
return std::make_tuple(graph, shared_output);
}
} else {
return std::make_tuple(nullptr, nullptr);
}
}
} // namespace infinilm::engine
#pragma once
#include "graph_compiler.hpp"
#include <unordered_map>
namespace infinilm::engine {
class StaticBatchingCompiler : public GraphCompiler {
public:
StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
void compile() override;
Compiled get_compiled(const InfinilmModel::Input &input) override;
private:
struct TupleHash {
size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
auto h1 = std::hash<size_t>{}(std::get<0>(t));
auto h2 = std::hash<size_t>{}(std::get<1>(t));
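// Boost-style hash_combine: 0x9e3779b97f4a7c15 is 2^64 / golden ratio,
// mixing h1 and h2 so that (a, b) and (b, a) hash to different values.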
return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
}
};
struct CompiledResult {
InfinilmModel::Input input;
Compiled compiled;
};
std::unordered_map<
std::tuple<size_t, size_t>, // (batch_size, seq_len)
CompiledResult,
TupleHash>
compiled_map_;
};
} // namespace infinilm::engine
@@ -6,26 +6,73 @@ namespace infinilm::engine {
//------------------------------------------------------
// Constructor
//------------------------------------------------------
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
InferEngine::InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
: communication_group_(distributed_config, device_type),
legacy_model_config_(config) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
legacy_model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
}
// Compile the model on all workers
this->compile();
}
InferEngine::InferEngine(
const std::string &model_path,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
: communication_group_(distributed_config, device_type) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}
// Load the model config; model_path must point to a directory containing config.json
this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
}
// Compile the model on all workers
this->compile();
}
//------------------------------------------------------
@@ -65,9 +112,9 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
};
return {
to_device(input_ids),
to_device(position_ids),
to_device(past_sequence_lengths),
to_device(total_sequence_lengths),
to_device(input_offsets),
to_device(block_tables),
@@ -88,6 +135,16 @@ InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
return workers_[0]->get_output();
}
void InferEngine::compile() {
for (auto &worker : workers_) {
worker->compile();
}
// Wait for all workers
for (auto &worker : workers_) {
worker->wait();
}
}
//------------------------------------------------------
// Destructor
//------------------------------------------------------
@@ -112,6 +169,8 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for (auto &worker : workers_) {
worker->wait();
}
this->compile();
}
} // namespace infinilm::engine
#pragma once
#include "../config/model_config.hpp"
#include "../models/infinilm_model.hpp" #include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp" #include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp" #include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include "rank_barrier.hpp"
#include "rank_worker.hpp" #include "rank_worker.hpp"
#include <optional> #include <optional>
...@@ -18,11 +20,31 @@ public: ...@@ -18,11 +20,31 @@ public:
using Output = RankWorker::Output; using Output = RankWorker::Output;
// Updated constructor: accept CacheConfig instead of CacheType // Updated constructor: accept CacheConfig instead of CacheType
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false);
InferEngine(
const std::string &model_path = "",
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false);
// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
@@ -33,6 +55,8 @@ public:
// Run a single forward pass on all workers and return the outputs from all ranks
Output forward(const Input &input);
void compile();
void reset_cache(const cache::CacheConfig *new_config);
~InferEngine();
@@ -44,9 +68,11 @@ public:
protected:
std::vector<std::unique_ptr<RankWorker>> workers_;
std::unique_ptr<RankBarrier> barrier_;
distributed::CommunicationGroup communication_group_;
std::unique_ptr<cache::CacheConfig> cache_config_;
const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
};
} // namespace infinilm::engine
#include "rank_barrier.hpp"
namespace infinilm::engine {
RankBarrier::RankBarrier(size_t num_ranks) : thread_count_(num_ranks), arrived_(0), generation_(0) {}
void RankBarrier::wait() {
std::unique_lock<std::mutex> lock(mutex_);
size_t gen = generation_;
if (++arrived_ == thread_count_) {
// last thread
generation_++;
arrived_ = 0;
cv_.notify_all();
} else {
cv_.wait(lock, [&] { return gen != generation_; });
}
}
} // namespace infinilm::engine
#pragma once
#include <condition_variable>
#include <cstddef>
#include <mutex>
namespace infinilm::engine {
class RankBarrier {
public:
explicit RankBarrier(size_t num_ranks);
void wait();
private:
const size_t thread_count_;
size_t arrived_;
size_t generation_;
std::mutex mutex_;
std::condition_variable cv_;
};
} // namespace infinilm::engine
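A minimal usage sketch (hypothetical driver, assuming only the header above): every rank thread blocks in `wait()` until all ranks arrive, and the generation counter makes the barrier reusable across successive rounds.

```cpp
#include "rank_barrier.hpp"
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    constexpr size_t kRanks = 4;
    infinilm::engine::RankBarrier barrier(kRanks);
    std::vector<std::thread> threads;
    for (size_t r = 0; r < kRanks; ++r) {
        threads.emplace_back([&barrier, r] {
            std::printf("rank %zu before barrier\n", r);
            barrier.wait(); // no thread passes until all kRanks have arrived
            std::printf("rank %zu after barrier\n", r);
        });
    }
    for (auto &t : threads) t.join();
}
```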
@@ -10,17 +10,33 @@
namespace infinilm::engine {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
: legacy_model_config_(model_config),
rank_info_(rank_info),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false),
rng_(std::random_device{}()),
barrier_(barrier) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
@@ -32,6 +48,32 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
cv_.wait(lk, [&] { return init_done_; });
}
RankWorker::RankWorker(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
: model_config_(model_config),
rank_info_(rank_info),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false),
rng_(std::random_device{}()),
barrier_(barrier) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);
// Wait until the worker thread finishes initialization (model created)
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_; });
}
std::string RankWorker::info() const {
std::stringstream ss;
@@ -113,6 +155,21 @@ void RankWorker::run(const Input &args) {
cv_.notify_all();
}
//------------------------------------------------------
// compile -- asynchronous
//------------------------------------------------------
void RankWorker::compile() {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot compile");
}
job_cmd_ = Command::COMPILE;
has_job_ = true;
job_done_ = false;
cv_.notify_all();
}
//------------------------------------------------------
// wait -- asynchronous
//------------------------------------------------------
@@ -176,10 +233,20 @@ void RankWorker::thread_loop() {
infinicore::context::setDevice(rank_info_.device);
// Create model using factory (may be expensive)
if (model_config_ == nullptr) {
model_ = InfinilmModelFactory::createModel(legacy_model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
} else {
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
}
if (!model_) {
throw std::runtime_error("Failed to create model");
}
if (enable_graph_compiling_) {
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
}
init_done_ = true;
}
cv_.notify_all();
@@ -223,12 +290,12 @@ void RankWorker::thread_loop() {
try {
model_->load_parameter(local_param_name, local_param);
} catch (const std::exception &e) {
{
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
}
cv_.notify_all();
spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what());
break;
}
@@ -245,9 +312,21 @@ void RankWorker::thread_loop() {
{
std::lock_guard<std::mutex> lk(mutex_);
infinicore::Tensor logits;
// Try to get a compiled graph
if (compiler_ != nullptr) {
auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu()));
if (graph != nullptr && output != nullptr) {
graph->run();
logits = output->logits;
}
}
// Fall back to eager mode
if (!logits) {
auto model_args = local_args.to_model_input(rank_info_.device);
logits = model_->forward(model_args).logits;
}
// Random sampling (rank 0 only)
if (rank_info_.tp_rank == 0) {
auto temperature{local_args.temperature};
@@ -286,9 +365,11 @@ void RankWorker::thread_loop() {
cv_.notify_all();
} catch (const std::exception &e) {
{
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
}
cv_.notify_all();
spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
break;
@@ -296,7 +377,6 @@ void RankWorker::thread_loop() {
} else if (local_cmd == Command::RESET_CACHE) {
try {
model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
@@ -304,17 +384,44 @@ void RankWorker::thread_loop() {
cv_.notify_all();
} catch (const std::exception &e) {
{
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
}
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
break;
}
} else if (local_cmd == Command::COMPILE) {
try {
if (compiler_ != nullptr) {
compiler_->compile();
}
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
{
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
}
cv_.notify_all();
spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
break;
}
} else {
// Shouldn't reach here (no-op)
}
} // while
// Some clean up should be done before exiting the thread
compiler_.reset();
} catch (const std::exception &e) {
// Top-level exception: ensure any waiters are woken and the thread exits cleanly.
{
...