Unverified Commit 71c70586 authored by qinyiqun's avatar qinyiqun Committed by GitHub
Browse files

demo131 - multiple issues regarding quantization, qy, etc.



* issue/204 - support graph in server scripts

* issue/208 - adapt to ali ppu

* issue/194 - add quantization modify configs accordingly

支持nv w8 1batch 1tp

增加json支持

InfiniLM 增加量化层和global config

以一种比较优雅的方式增加了quant config的支持

修改部分代码结构,删除无用代码

跟随infinicore修改

删除所有的model_config,统一使用global_config

跟随InfiniLM最新代码修改

修改函数参数顺序

改名global config 为model config

Refactor: add new API alongside legacy interfaces with deprecation warnings

添加w4 infinicore相关内容,以及将Quantization config划入InfiniCore

添加w4 infinicore相关内容,以及将Quantization config划入InfiniCore

* issue/175 - qy device support

qy_page_131: add qy device

success qy inference_server.py

* Issue/170 - Add HYGON support and improve device type handling.

* Issue/193: feats for deployment
Signed-off-by: default avatarCeng23333 <441651826@qq.com>

* skip responding eos token
Signed-off-by: default avatarCeng23333 <441651826@qq.com>

* issue/143 use add_rmsnorm, nt flash attn, nt kv caching

* issue/204 - support graph in server scripts

* issue/208 - adapt to ali ppu

* rebase main

* issue/216 feat: support static kv cache in server

* fix llm server cache config

* demo131 - resolve mishandled conflicts

* demo131 - further adjust attn and caching logic

* demo131 - resolve merge requirements

---------
Signed-off-by: default avatarCeng23333 <441651826@qq.com>
Co-authored-by: default avatarwooway777 <wooway777@gmail.com>
Co-authored-by: default avatarxgqdut2016 <kenan_gewei@163.com>
Co-authored-by: default avatargongchensu <zhuyue_134@qq.com>
Co-authored-by: default avatarCeng23333 <441651826@qq.com>
Co-authored-by: default avatarPanZezhong <panzezhong@qiyuanlab.com>
Co-authored-by: default avatarMaYuhang <2902139028@qq.com>
parent ee59b3f5
...@@ -28,10 +28,26 @@ public: ...@@ -28,10 +28,26 @@ public:
* @param config Model configuration * @param config Model configuration
* @param device Device to create tensors on * @param device Device to create tensors on
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: compute language modeling logits * @brief Forward pass: compute language modeling logits
* *
...@@ -45,7 +61,6 @@ public: ...@@ -45,7 +61,6 @@ public:
const cache::CacheConfig *get_cache_config() const override; const cache::CacheConfig *get_cache_config() const override;
// Module information // Module information
const LlamaConfig &config() const { return model_->config(); }
LlamaModel &model() { return *model_; } LlamaModel &model() { return *model_; }
const LlamaModel &model() const { return *model_; } const LlamaModel &model() const { return *model_; }
......
...@@ -3,7 +3,18 @@ ...@@ -3,7 +3,18 @@
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaMLP::LlamaMLP(const LlamaConfig &config, LlamaMLP::LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
...@@ -22,6 +33,43 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config, ...@@ -22,6 +33,43 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config,
dtype, device, tp_rank, tp_size, rank_info.comm); dtype, device, tp_rank, tp_size, rank_info.comm);
} }
/**
 * @brief Construct a LlamaMLP from a generic ModelConfig.
 *
 * Reads hidden/intermediate sizes and the optional "mlp_bias" flag from the
 * config, then initializes the fused gate/up projection according to the
 * quantization scheme carried by the config, followed by the down projection.
 *
 * @param model_config Shared model configuration (also stored as a member).
 * @param device       Device to create tensors on.
 * @param rank_info    Distributed rank information (tensor-parallel rank/size).
 */
LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                   const infinicore::Device &device,
                   engine::distributed::RankInfo rank_info)
    : model_config_(model_config), hidden_size_(model_config->get<size_t>("hidden_size")),
      intermediate_size_(model_config->get<size_t>("intermediate_size")),
      use_bias_(model_config->get_or<bool>("mlp_bias", false)), rank_info_(rank_info) {
    const auto &dtype{model_config_->get_dtype()};
    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    // Initialize the fused gate/up projection; only the concrete layer type
    // differs per quantization scheme.
    const auto quant_scheme = this->model_config_->get_quant_scheme();
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        INFINILM_GATE_UP_LINEAR_W8A8_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                          dtype, device, rank_info_);
        break;
    case infinicore::quantization::QuantScheme::AWQ_W4A16:
        INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                              dtype, device, rank_info_);
        break;
    default:
        // NOTE(review): unknown schemes silently fall back to the unquantized
        // path — confirm this is intended rather than a configuration error.
        INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                     dtype, device, rank_info_);
        break;
    }

    // The down projection is identical for every quantization scheme, so it is
    // initialized once here instead of being duplicated in each switch branch.
    INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);
}
infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const { infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
// 1. Project to gate and up // 1. Project to gate and up
auto hidden_states_mutable = hidden_states; auto hidden_states_mutable = hidden_states;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../layers/fused_linear.hpp" #include "../../layers/fused_linear.hpp"
#include "llama_config.hpp" #include "llama_config.hpp"
#include "../../config/model_config.hpp"
#include "infinicore/device.hpp" #include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
...@@ -33,10 +34,26 @@ public: ...@@ -33,10 +34,26 @@ public:
* @param device Device to create tensors on * @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32) * @param dtype Optional data type for model parameters (defaults to F32)
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaMLP(const LlamaConfig &config, LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: compute MLP output * @brief Forward pass: compute MLP output
* *
...@@ -57,6 +74,8 @@ protected: ...@@ -57,6 +74,8 @@ protected:
size_t hidden_size_; size_t hidden_size_;
size_t intermediate_size_; size_t intermediate_size_;
bool use_bias_; bool use_bias_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -6,7 +6,18 @@ ...@@ -6,7 +6,18 @@
#include <iostream> #include <iostream>
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaModel::LlamaModel(const LlamaConfig &config, LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
...@@ -43,6 +54,39 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -43,6 +54,39 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
} }
} }
/**
 * @brief Construct a LlamaModel from a generic ModelConfig.
 *
 * Builds the token embedding, one decoder layer per "num_hidden_layers",
 * the final RMSNorm, and a rotary-embedding module that is shared across
 * all decoder layers.
 *
 * @param model_config Shared model configuration (also stored as a member).
 * @param device       Device to create tensors on.
 * @param rank_info    Distributed rank information forwarded to sub-modules.
 */
LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                       const infinicore::Device &device,
                       engine::distributed::RankInfo rank_info)
    : model_config_(model_config), rank_info_(rank_info) {
    const auto &dtype{model_config_->get_dtype()};

    // Initialize token embeddings
    INFINICORE_NN_MODULE_INIT(embed_tokens, model_config_->get<size_t>("vocab_size"), model_config_->get<size_t>("hidden_size"),
                              std::nullopt, dtype, device);

    // Initialize decoder layers with layer indices
    // TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments
    // (e.g., via a factory function or lambda that receives the layer index)
    // Currently, we can't use the macro because each layer needs a different layer_idx
    // Hoisted: the layer count is loop-invariant; avoid re-querying the config
    // on every iteration.
    const auto num_layers = model_config_->get<size_t>("num_hidden_layers");
    layers_.reserve(num_layers);
    for (size_t i = 0; i < num_layers; ++i) {
        layers_.push_back(this->register_module<LlamaDecoderLayer>(
            "layers." + std::to_string(i), model_config_, device, i, rank_info));
    }

    // Initialize final layer normalization
    INFINICORE_NN_MODULE_INIT(norm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
                              dtype, device);

    // Initialize Rotary Position Embeddings (shared across all layers)
    // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
    INFINICORE_NN_MODULE_INIT(rotary_emb, model_config_->get_head_dim(), model_config_->get<size_t>("max_position_embeddings"),
                              model_config_->get<double>("rope_theta"), infinicore::nn::RoPE::Algo::GPT_NEOX,
                              dtype, device, model_config_->get_rope_scaling());

    // Share the single rotary-embedding instance with every decoder layer.
    for (auto &layer : layers_) {
        if (layer) {
            layer->set_rotary_emb(rotary_emb_);
        }
    }
}
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
...@@ -79,7 +123,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { ...@@ -79,7 +123,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
kv_cache_ = nullptr; kv_cache_ = nullptr;
return; return;
} }
if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) { if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config);
kv_cache_config && model_config_ == nullptr) {
kv_cache_ = std::make_shared<cache::StaticKVCache>( kv_cache_ = std::make_shared<cache::StaticKVCache>(
config_.head_dim, config_.head_dim,
config_.head_dim, config_.head_dim,
...@@ -90,8 +135,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { ...@@ -90,8 +135,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
config_.dtype, config_.dtype,
*kv_cache_config, *kv_cache_config,
rank_info_); rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) { paged_kv_cache_config && model_config_ == nullptr) {
kv_cache_ = std::make_shared<cache::PagedKVCache>( kv_cache_ = std::make_shared<cache::PagedKVCache>(
config_.head_dim, config_.head_dim,
config_.head_dim, config_.head_dim,
...@@ -101,6 +146,27 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { ...@@ -101,6 +146,27 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
config_.dtype, config_.dtype,
*paged_kv_cache_config, *paged_kv_cache_config,
rank_info_); rank_info_);
} else if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::StaticKVCache>(
model_config_->get_head_dim(),
model_config_->get_head_dim(),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_hidden_layers"),
model_config_->get<size_t>("max_position_embeddings"),
model_config_->get_dtype(),
*kv_cache_config,
rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::PagedKVCache>(
model_config_->get_head_dim(),
model_config_->get_head_dim(),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_hidden_layers"),
model_config_->get_dtype(),
*paged_kv_cache_config,
rank_info_);
} else { } else {
throw std::runtime_error("Unsupported cache type"); throw std::runtime_error("Unsupported cache type");
} }
......
#pragma once #pragma once
#include "../../cache/kv_cache.hpp" #include "../../cache/kv_cache.hpp"
#include "llama_config.hpp"
#include "llama_decoder_layer.hpp" #include "llama_decoder_layer.hpp"
#include "infinicore/nn/embedding.hpp" #include "infinicore/nn/embedding.hpp"
...@@ -38,10 +37,26 @@ public: ...@@ -38,10 +37,26 @@ public:
* @param device Device to create tensors on * @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32) * @param dtype Optional data type for model parameters (defaults to F32)
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaModel(const LlamaConfig &config, LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: process input through the model * @brief Forward pass: process input through the model
* *
...@@ -64,8 +79,7 @@ public: ...@@ -64,8 +79,7 @@ public:
void reset_cache(const cache::CacheConfig *cache_config); void reset_cache(const cache::CacheConfig *cache_config);
// Module information // Module information
const LlamaConfig &config() const { return config_; } size_t num_layers() const { return model_config_->get<size_t>("num_hidden_layers"); }
size_t num_layers() const { return config_.num_hidden_layers; }
protected: protected:
// Token embeddings // Token embeddings
...@@ -86,6 +100,8 @@ protected: ...@@ -86,6 +100,8 @@ protected:
private: private:
LlamaConfig config_; LlamaConfig config_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -2,11 +2,22 @@ ...@@ -2,11 +2,22 @@
#include "llama/llama.hpp" #include "llama/llama.hpp"
namespace infinilm { namespace infinilm {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info, engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) { const cache::CacheConfig *cache) {
std::shared_ptr<InfinilmModel> model; std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) { if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr; const auto &llama_config = *llama_config_ptr;
...@@ -22,4 +33,24 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( ...@@ -22,4 +33,24 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
return model; return model;
} }
/**
 * @brief Create a model instance from a generic ModelConfig.
 *
 * Currently only Llama-family models are supported by this overload, so the
 * model is constructed unconditionally. The previous `if (true)` scaffolding
 * (with an unreachable `else { throw }`) has been removed as dead code.
 * TODO(review): dispatch on the concrete model type once ModelConfig exposes it.
 *
 * @param model_config Shared model configuration (must be non-null).
 * @param rank_info    Distributed rank information (provides the device).
 * @param cache        Optional cache configuration; applied when non-null.
 * @return The constructed model.
 * @throws std::invalid_argument if model_config is null.
 */
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
    std::shared_ptr<infinilm::config::ModelConfig> model_config,
    engine::distributed::RankInfo rank_info,
    const cache::CacheConfig *cache) {
    if (!model_config) {
        throw std::invalid_argument("InfinilmModelFactory::createModel: model_config must not be null");
    }
    std::shared_ptr<InfinilmModel> model = std::make_shared<models::llama::LlamaForCausalLM>(
        model_config, rank_info.device, rank_info);
    if (cache) {
        model->reset_cache(cache);
    }
    return model;
}
} // namespace infinilm } // namespace infinilm
#pragma once #pragma once
#include "../config/model_config.hpp"
#include "infinilm_model.hpp" #include "infinilm_model.hpp"
#include "../engine/distributed/distributed.hpp" #include "../engine/distributed/distributed.hpp"
...@@ -7,9 +8,26 @@ ...@@ -7,9 +8,26 @@
namespace infinilm { namespace infinilm {
class InfinilmModelFactory { class InfinilmModelFactory {
public: public:
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
static std::shared_ptr<InfinilmModel> createModel( static std::shared_ptr<InfinilmModel> createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr); const cache::CacheConfig *cache = nullptr);
static std::shared_ptr<InfinilmModel> createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
}; };
} // namespace infinilm } // namespace infinilm
...@@ -63,20 +63,52 @@ inline void bind_infer_engine(py::module &m) { ...@@ -63,20 +63,52 @@ inline void bind_infer_engine(py::module &m) {
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def( .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
self.reset_cache(cfg ? cfg.get() : nullptr);
},
py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) { .def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config(); auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
infer_engine
.def(py::init([](
const std::string &model_path,
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
return std::make_shared<InferEngine>(
model_path,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
}),
py::arg("model_path") = "",
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
py::dict result;
for (const auto &[name, param] : state_dict_tp) {
result[py::cast(name)] = infinicore::Tensor(param);
}
state_dict_tp_all.append(result);
}
return state_dict_tp_all;
}) })
.def("__repr__", [](const InferEngine &self) { .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
}); .def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
py::class_<InferEngine::Input>(infer_engine, "Input") py::class_<InferEngine::Input>(infer_engine, "Input")
.def( .def(
......
...@@ -137,6 +137,11 @@ def get_args(): ...@@ -137,6 +137,11 @@ def get_args():
action="store_true", action="store_true",
help="Run nvidia test", help="Run nvidia test",
) )
parser.add_argument(
"--qy",
action="store_true",
help="Run qy test",
)
parser.add_argument( parser.add_argument(
"--metax", "--metax",
action="store_true", action="store_true",
...@@ -278,6 +283,13 @@ class TestModel: ...@@ -278,6 +283,13 @@ class TestModel:
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# token编码 # token编码
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -290,7 +302,16 @@ class TestModel: ...@@ -290,7 +302,16 @@ class TestModel:
] ]
# print(input_content, end="", flush=True) # print(input_content, end="", flush=True)
input_ids_list = tokenizer.batch_encode_plus(input_content)["input_ids"] # Support Transformers >= 5.0 for batch_encode_plus deprecation
encoding = tokenizer(
input_content,
padding=True,
truncation=True,
max_length=2048,
return_tensors="pt"
)
input_ids_list = encoding["input_ids"]
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -348,6 +369,8 @@ if __name__ == "__main__": ...@@ -348,6 +369,8 @@ if __name__ == "__main__":
device_str = "cpu" device_str = "cpu"
elif args.nvidia: elif args.nvidia:
device_str = "cuda" device_str = "cuda"
elif args.qy:
device_str = "cuda"
elif args.metax: elif args.metax:
device_str = "cuda" device_str = "cuda"
elif args.moore: elif args.moore:
......
...@@ -27,6 +27,11 @@ def get_args(): ...@@ -27,6 +27,11 @@ def get_args():
action="store_true", action="store_true",
help="Run nvidia test", help="Run nvidia test",
) )
parser.add_argument(
"--qy",
action="store_true",
help="Run qy test",
)
parser.add_argument( parser.add_argument(
"--metax", "--metax",
action="store_true", action="store_true",
...@@ -150,7 +155,6 @@ def test( ...@@ -150,7 +155,6 @@ def test(
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
enable_graph_compiling=enable_graph, enable_graph_compiling=enable_graph,
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Load Weights # Load Weights
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -160,7 +164,6 @@ def test( ...@@ -160,7 +164,6 @@ def test(
# create tokenizer # create tokenizer
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if "llama" == model.config.model_type: if "llama" == model.config.model_type:
backend = getattr(tokenizer, "backend_tokenizer", None) backend = getattr(tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend) target = getattr(backend, "_tokenizer", backend)
...@@ -194,9 +197,19 @@ def test( ...@@ -194,9 +197,19 @@ def test(
for prompt in prompts for prompt in prompts
] ]
input_ids_list = tokenizer.batch_encode_plus(input_contents)[ # input_ids_list = tokenizer.batch_encode_plus(input_contents)[
"input_ids" # "input_ids"
] # List: [[1, 1128, 526, 366, 29892]] # ] # List: [[1, 1128, 526, 366, 29892]]
# Encode each prompt with the public tokenizer call instead of the private
# `_encode_plus` helper: private APIs are not stable across transformers
# releases, while `tokenizer(...)` has the same truncation/special-token
# behavior and returns the same {"input_ids": [...]} mapping per text.
input_ids_list = [
    tokenizer(
        text,
        truncation=True,
        max_length=2048,
        add_special_tokens=True,
    )["input_ids"]
    for text in input_contents
]
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Create KVCache # Create KVCache
...@@ -254,6 +267,8 @@ if __name__ == "__main__": ...@@ -254,6 +267,8 @@ if __name__ == "__main__":
device_str = "cpu" device_str = "cpu"
elif args.nvidia: elif args.nvidia:
device_str = "cuda" device_str = "cuda"
elif args.qy:
device_str = "cuda"
elif args.metax: elif args.metax:
device_str = "cuda" device_str = "cuda"
elif args.moore: elif args.moore:
...@@ -268,7 +283,7 @@ if __name__ == "__main__": ...@@ -268,7 +283,7 @@ if __name__ == "__main__":
device_str = "cuda" device_str = "cuda"
else: else:
print( print(
"Usage: python examples/jiuge.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon | --ali | --hygon] --model_path=<path/to/model_dir>\n" "Usage: python examples/jiuge.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali | --hygon] --model_path=<path/to/model_dir>\n"
"such as, python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0" "such as, python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
) )
sys.exit(1) sys.exit(1)
......
...@@ -25,5 +25,11 @@ class AutoConfig: ...@@ -25,5 +25,11 @@ class AutoConfig:
config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3" config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3"
): ):
return LlamaConfig(**config_dict) return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "minicpm":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "fm9g":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "fm9g7b":
return LlamaConfig(**config_dict)
raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.") raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
...@@ -35,14 +35,21 @@ class InferEngine(_infinilm.InferEngine): ...@@ -35,14 +35,21 @@ class InferEngine(_infinilm.InferEngine):
if device is None: if device is None:
device = infinicore.device() device = infinicore.device()
# super().__init__(
# self.config,
# distributed_config._underlying,
# device._underlying.type,
# cache_config,
# enable_graph_compiling,
# )
super().__init__( super().__init__(
self.config, model_path,
distributed_config._underlying, distributed_config._underlying,
device._underlying.type, device._underlying.type,
cache_config, cache_config,
enable_graph_compiling, enable_graph_compiling,
) )
self.use_cache = False self.use_cache = False
self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig) self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
......
...@@ -18,6 +18,7 @@ from infinilm.llm.llm import ( ...@@ -18,6 +18,7 @@ from infinilm.llm.llm import (
EngineConfig, EngineConfig,
) )
from infinilm.llm.scheduler import Scheduler, SchedulerOutput from infinilm.llm.scheduler import Scheduler, SchedulerOutput
from infinilm.llm.static_scheduler import StaticScheduler, StaticSchedulerOutput
from infinilm.llm.cache_manager import BlockManager, Block from infinilm.llm.cache_manager import BlockManager, Block
__all__ = [ __all__ = [
...@@ -38,6 +39,8 @@ __all__ = [ ...@@ -38,6 +39,8 @@ __all__ = [
# Internal (for advanced use) # Internal (for advanced use)
"Scheduler", "Scheduler",
"SchedulerOutput", "SchedulerOutput",
"StaticScheduler",
"StaticSchedulerOutput",
"BlockManager", "BlockManager",
"Block", "Block",
] ]
...@@ -23,10 +23,11 @@ from infinilm.llm.request import ( ...@@ -23,10 +23,11 @@ from infinilm.llm.request import (
) )
from infinilm.llm.sampling_params import SamplingParams from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.scheduler import Scheduler from infinilm.llm.scheduler import Scheduler
from infinilm.llm.static_scheduler import StaticScheduler
from infinilm.distributed import DistConfig from infinilm.distributed import DistConfig
from infinilm.infer_engine import InferEngine from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.modeling_utils import load_model_state_dict_by_file
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tokenizers import decoders as _dec from tokenizers import decoders as _dec
...@@ -43,10 +44,12 @@ class EngineConfig: ...@@ -43,10 +44,12 @@ class EngineConfig:
device: Device type string ('cpu', 'cuda', 'mlu', etc.). device: Device type string ('cpu', 'cuda', 'mlu', etc.).
dtype: Data type string ('float16', 'bfloat16', 'float32'). dtype: Data type string ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism. tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference. cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size for inference (only for paged cache).
max_tokens: Default maximum tokens to generate. max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks. num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block. block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
...@@ -57,10 +60,12 @@ class EngineConfig: ...@@ -57,10 +60,12 @@ class EngineConfig:
device: str = "cuda" device: str = "cuda"
dtype: str = "float16" dtype: str = "float16"
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
cache_type: str = "paged" # "paged" or "static"
max_batch_size: int = 16 max_batch_size: int = 16
max_tokens: int = 4096 max_tokens: int = 4096
num_blocks: int = 8 * 1024 num_blocks: int = 8 * 1024
block_size: int = 16 block_size: int = 16
max_cache_len: int = 4096
temperature: float = 1.0 temperature: float = 1.0
top_p: float = 0.8 top_p: float = 0.8
top_k: int = 1 top_k: int = 1
...@@ -76,17 +81,11 @@ class LLMEngine: ...@@ -76,17 +81,11 @@ class LLMEngine:
# Initialize device and dtype # Initialize device and dtype
self._init_device() self._init_device()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
# Initialize model engine # Initialize model engine
self.model_engine = InferEngine( self.model_engine = InferEngine(
model_path=config.model_path, model_path=config.model_path,
device=self.device, device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size), distributed_config=DistConfig(config.tensor_parallel_size),
cache_config=cache_config,
enable_graph_compiling=config.enable_graph, enable_graph_compiling=config.enable_graph,
) )
...@@ -101,12 +100,30 @@ class LLMEngine: ...@@ -101,12 +100,30 @@ class LLMEngine:
) )
self._fix_tokenizer_decoder() self._fix_tokenizer_decoder()
# Initialize scheduler # Initialize KV cache based on cache type
if config.cache_type == "static":
cache_config = StaticKVCacheConfig(
max_batch_size=1, max_cache_len=config.max_cache_len
)
self.scheduler = StaticScheduler(max_cache_len=config.max_cache_len)
logger.info(
f"Using Static KV Cache with max_cache_len={config.max_cache_len}"
)
elif config.cache_type == "paged":
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
self.scheduler = Scheduler( self.scheduler = Scheduler(
max_batch_size=config.max_batch_size, max_batch_size=config.max_batch_size,
num_blocks=config.num_blocks, num_blocks=config.num_blocks,
block_size=config.block_size, block_size=config.block_size,
) )
logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}")
else:
raise ValueError(f"Unsupported cache_type: {config.cache_type}")
self.model_engine.reset_cache(cache_config)
self.cache_type = config.cache_type
# Get EOS token IDs from model config # Get EOS token IDs from model config
self.eos_token_ids = self.model_engine.config.eos_token_id or [] self.eos_token_ids = self.model_engine.config.eos_token_id or []
...@@ -202,19 +219,21 @@ class LLMEngine: ...@@ -202,19 +219,21 @@ class LLMEngine:
"""Convert model input dict to infinicore tensors.""" """Convert model input dict to infinicore tensors."""
model_input = {} model_input = {}
for key, value in model_input_dict.items(): for key, value in model_input_dict.items():
if key == "input_ids": if value is None:
model_input[key] = infinicore.from_list([value], dtype=infinicore.int64) # Skip None values (block_tables/slot_mapping for static cache)
model_input[key] = None
elif key in [ elif key in [
"input_ids",
"position_ids", "position_ids",
"past_kv_lengths", "past_kv_lengths",
"total_kv_lengths", "total_kv_lengths",
"input_offsets", "input_offsets",
"slot_mapping", "slot_mapping",
"block_tables",
]: ]:
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64) model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
elif key == "block_tables":
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
else: else:
# temperature, top_k, top_p, etc.
model_input[key] = value model_input[key] = value
return model_input return model_input
...@@ -225,7 +244,8 @@ class LLMEngine: ...@@ -225,7 +244,8 @@ class LLMEngine:
sampled_tokens: List[int], sampled_tokens: List[int],
): ):
"""Update request status after inference step.""" """Update request status after inference step."""
if is_prefill: # Only reset req blocks for paged cache
if is_prefill and self.cache_type == "paged":
self.scheduler.cache_manager.reset_req_blocks() self.scheduler.cache_manager.reset_req_blocks()
for req, token_id in zip(requests, sampled_tokens): for req, token_id in zip(requests, sampled_tokens):
...@@ -293,7 +313,6 @@ class LLMEngine: ...@@ -293,7 +313,6 @@ class LLMEngine:
# Remove the stop string from the end # Remove the stop string from the end
req.generated_text = req.generated_text[: -len(stop_str)] req.generated_text = req.generated_text[: -len(stop_str)]
break break
# Put output in queue if it exists (for async streaming) # Put output in queue if it exists (for async streaming)
if req._output_queue is not None: if req._output_queue is not None:
output = TokenOutput( output = TokenOutput(
...@@ -363,10 +382,12 @@ class LLM: ...@@ -363,10 +382,12 @@ class LLM:
device: str = "cuda", device: str = "cuda",
dtype: str = "float16", dtype: str = "float16",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
cache_type: str = "paged",
max_batch_size: int = 16, max_batch_size: int = 16,
max_tokens: int = 4096, max_tokens: int = 4096,
num_blocks: int = 8 * 1024, num_blocks: int = 8 * 1024,
block_size: int = 16, block_size: int = 16,
max_cache_len: int = 4096,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
...@@ -379,10 +400,12 @@ class LLM: ...@@ -379,10 +400,12 @@ class LLM:
device: Device type ('cpu', 'cuda', 'mlu', 'moore'). device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32'). dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism. tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference. cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size (only for paged cache).
max_tokens: Default maximum tokens to generate. max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks. num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block. block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
...@@ -393,10 +416,12 @@ class LLM: ...@@ -393,10 +416,12 @@ class LLM:
device=device, device=device,
dtype=dtype, dtype=dtype,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
cache_type=cache_type,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_tokens=max_tokens, max_tokens=max_tokens,
num_blocks=num_blocks, num_blocks=num_blocks,
block_size=block_size, block_size=block_size,
max_cache_len=max_cache_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -510,10 +535,12 @@ class AsyncLLMEngine: ...@@ -510,10 +535,12 @@ class AsyncLLMEngine:
device: str = "cuda", device: str = "cuda",
dtype: str = "float16", dtype: str = "float16",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
cache_type: str = "paged",
max_batch_size: int = 16, max_batch_size: int = 16,
max_tokens: int = 512, max_tokens: int = 512,
num_blocks: int = 8 * 1024, num_blocks: int = 8 * 1024,
block_size: int = 16, block_size: int = 16,
max_cache_len: int = 4096,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
...@@ -526,10 +553,12 @@ class AsyncLLMEngine: ...@@ -526,10 +553,12 @@ class AsyncLLMEngine:
device: Device type ('cpu', 'cuda', 'mlu', 'moore'). device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32'). dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism. tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference. cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size (only for paged cache).
max_tokens: Default maximum tokens to generate. max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks. num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block. block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
...@@ -540,10 +569,12 @@ class AsyncLLMEngine: ...@@ -540,10 +569,12 @@ class AsyncLLMEngine:
device=device, device=device,
dtype=dtype, dtype=dtype,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
cache_type=cache_type,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_tokens=max_tokens, max_tokens=max_tokens,
num_blocks=num_blocks, num_blocks=num_blocks,
block_size=block_size, block_size=block_size,
max_cache_len=max_cache_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
......
...@@ -103,7 +103,7 @@ class SchedulerOutput: ...@@ -103,7 +103,7 @@ class SchedulerOutput:
block_tables.append(padded_block_table) block_tables.append(padded_block_table)
return { return {
"input_ids": tokens, "input_ids": [tokens],
"position_ids": position_ids, "position_ids": position_ids,
"past_kv_lengths": cached_lens, "past_kv_lengths": cached_lens,
"total_kv_lengths": seq_lens, "total_kv_lengths": seq_lens,
...@@ -154,6 +154,10 @@ class Scheduler: ...@@ -154,6 +154,10 @@ class Scheduler:
req = self.waiting_queue.sync_q.get_nowait() req = self.waiting_queue.sync_q.get_nowait()
except queue.Empty: except queue.Empty:
break break
# Skip requests that were already finished (e.g., timed out/canceled while waiting)
if req.is_finished():
self.complete_requests([req])
continue
if not self.can_accept_request(req): if not self.can_accept_request(req):
self.waiting_queue.sync_q.put(req) self.waiting_queue.sync_q.put(req)
......
"""
Static Scheduler - Single-batch request scheduling for Static KV Cache.
"""
import logging
import queue
import janus
from typing import List, Optional
from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason
logger = logging.getLogger(__name__)
class StaticSchedulerOutput:
    """Static scheduler output containing single request and execution phase info."""

    def __init__(
        self,
        scheduled_requests: List[InferenceRequest],
        is_prefill: bool = False,
    ):
        # Static cache runs one request at a time, so this list holds one entry.
        self.scheduled_requests = scheduled_requests
        self.num_requests = len(scheduled_requests)
        # True for the prompt-ingestion step, False for token-by-token decode.
        self.is_prefill = is_prefill

    def build_model_inputs(
        self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
    ):
        """Construct model inputs for prefill or decode phase.

        Static cache model inputs:

        Prefill phase:
            - input_ids: All prompt tokens [1, prompt_length]
            - position_ids: [0, 1, 2, ..., prompt_length-1]
            - past_kv_lengths: [0] (no cached tokens initially)
            - total_kv_lengths: [prompt_length]

        Decode phase:
            - input_ids: Only the last generated token [1, 1]
            - position_ids: [current_position] (position in full sequence)
            - past_kv_lengths: [num_cached_tokens]
            - total_kv_lengths: [total_tokens]
        """
        request = self.scheduled_requests[0]

        if self.is_prefill:
            # Prefill: feed every prompt token at once.
            prompt = request.get_input_tokens()
            prompt_len = len(prompt)
            token_batch = [prompt]
            positions = [list(range(prompt_len))]
            cached_len = 0
            total_len = prompt_len
            offsets = [0, prompt_len]
        else:
            # Decode: feed only the most recently sampled token.
            total_len = request.get_total_length()
            current_pos = total_len - 1
            token_batch = [[request.generated_token_ids[-1]]]
            positions = [[current_pos]]
            cached_len = current_pos
            offsets = [0, 1]

        return {
            "input_ids": token_batch,
            "position_ids": positions,
            "past_kv_lengths": [cached_len],
            "total_kv_lengths": [total_len],
            "input_offsets": offsets,
            # Static cache needs no paged-attention bookkeeping.
            "block_tables": None,
            "slot_mapping": None,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        }
class StaticScheduler:
    """Request scheduler for Static KV Cache with batch_size=1.

    Simplified scheduling logic:
    - Only handles one request at a time
    - No cache block management needed
    - Simple waiting queue for incoming requests
    """

    def __init__(self, max_cache_len: int = 4096):
        # janus.Queue exposes both a sync and an async endpoint; the
        # scheduler itself only touches the sync side.
        self.waiting_queue = janus.Queue()
        # The single request currently in the decode loop (None when idle).
        self.running_request: Optional[InferenceRequest] = None
        # Hard capacity of the static KV cache, in tokens.
        self.max_cache_len = max_cache_len

    def add_request(self, request: InferenceRequest):
        """Enqueue a request; it is picked up by a later schedule() call."""
        if request is not None:
            request.status = RequestStatus.WAITING
            self.waiting_queue.sync_q.put(request)

    def schedule(self) -> Optional[StaticSchedulerOutput]:
        """Schedule and return single request to execute.

        Returns a decode-phase output while a request is running, otherwise
        promotes the next waiting request to a prefill-phase output.
        Returns None when there is nothing to run.
        """
        while True:
            # Case 1: Continue running request (decode phase)
            if self.running_request is not None:
                req = self.running_request
                if req.is_finished():
                    self.running_request = None
                    continue
                # Evict requests whose sequence no longer fits in the cache.
                if req.get_total_length() > self.max_cache_len:
                    logger.warning(
                        f"Request {req.request_id} exceeds max_cache_len={self.max_cache_len}, "
                        "completing request."
                    )
                    self.running_request = None
                    req.mark_failed(FinishReason.LENGTH)
                    continue
                return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)

            # Case 2: Get new request from waiting queue (prefill phase)
            try:
                req = self.waiting_queue.sync_q.get_nowait()
            except queue.Empty:
                return None

            # Skip requests finished while waiting (e.g. timed out/canceled).
            if req.is_finished():
                continue

            prompt_len = req.get_prompt_length()
            if prompt_len > self.max_cache_len:
                logger.error(
                    f"Request {req.request_id} prompt length {prompt_len} "
                    f"exceeds max_cache_len={self.max_cache_len}. Request rejected."
                )
                req.mark_failed(FinishReason.LENGTH)
                continue

            req.status = RequestStatus.RUNNING
            self.running_request = req
            return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True)

    def complete_requests(self, requests: List[InferenceRequest]):
        """Handle completed requests: free the running slot if one of them holds it."""
        for req in requests:
            # Use identity (`is`), not `==`: only the exact running request
            # object should clear the slot, never a merely value-equal one.
            if req.is_finished() and req is self.running_request:
                self.running_request = None
                logger.debug(f"Completed request {req.request_id}")

    def get_cache_stats(self) -> dict:
        """Get cache statistics."""
        return {
            "max_cache_len": self.max_cache_len,
            "running_request": (
                self.running_request.request_id if self.running_request else None
            ),
            "waiting_queue_size": self.waiting_queue.sync_q.qsize(),
        }
...@@ -75,7 +75,7 @@ def load_state_dict( ...@@ -75,7 +75,7 @@ def load_state_dict(
) )
for k in f.keys(): for k in f.keys():
state_dict[k] = f.get_tensor(k).to(device=device, dtype=dtype) state_dict[k] = f.get_tensor(k).to(device=device)
return state_dict return state_dict
...@@ -155,7 +155,6 @@ def load_model_state_dict_by_file( ...@@ -155,7 +155,6 @@ def load_model_state_dict_by_file(
model_param_infini = {} model_param_infini = {}
for key in model_param.keys(): for key in model_param.keys():
model_param_infini[key] = infinicore.from_torch(model_param[key]) model_param_infini[key] = infinicore.from_torch(model_param[key])
model.load_state_dict(model_param_infini, strict=False) model.load_state_dict(model_param_infini, strict=False)
infinicore.sync_device() infinicore.sync_device()
...@@ -168,7 +167,6 @@ def load_model_state_dict_by_file( ...@@ -168,7 +167,6 @@ def load_model_state_dict_by_file(
model_param_infini[key] = infinicore.from_torch( model_param_infini[key] = infinicore.from_torch(
model_params[key].to(dtype=torch_dtype) model_params[key].to(dtype=torch_dtype)
) )
already_loaded_keys.append(key) already_loaded_keys.append(key)
model.load_state_dict(model_param_infini, strict=True) model.load_state_dict(model_param_infini, strict=True)
......
...@@ -50,6 +50,42 @@ def chunk_json( ...@@ -50,6 +50,42 @@ def chunk_json(
} }
def completion_json(
    id_,
    content,
    role="assistant",
    finish_reason="stop",
    model: str = "unknown",
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    total_tokens: int = 0,
):
    """Generate JSON response for non-streaming completion.

    Builds an OpenAI-style ``chat.completion`` payload.

    Args:
        id_: Unique identifier for this completion response.
        content: The generated message text.
        role: Author role of the message (defaults to "assistant").
        finish_reason: Why generation stopped ("stop", "length", ...).
        model: Model identifier reported back to the client.
        prompt_tokens: Token count of the prompt.
        completion_tokens: Token count of the generated text.
        total_tokens: Combined token count (prompt + completion).

    Returns:
        dict shaped like an OpenAI chat-completion response.
    """
    message = {
        "role": role,
        "content": content,
    }
    choice = {
        "index": 0,
        "message": message,
        "logprobs": None,
        "finish_reason": finish_reason,
    }
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    return {
        "id": id_,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "system_fingerprint": None,
        "choices": [choice],
        "usage": usage,
    }
class InferenceServer: class InferenceServer:
"""HTTP server for LLM inference.""" """HTTP server for LLM inference."""
...@@ -59,10 +95,12 @@ class InferenceServer: ...@@ -59,10 +95,12 @@ class InferenceServer:
device: str = "cuda", device: str = "cuda",
dtype: str = "float16", dtype: str = "float16",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
cache_type: str = "paged",
max_tokens: int = 4096, max_tokens: int = 4096,
max_batch_size: int = 16, max_batch_size: int = 16,
num_blocks: int = 8 * 1024, num_blocks: int = 8 * 1024,
block_size: int = 16, block_size: int = 16,
max_cache_len: int = 4096,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
...@@ -77,10 +115,12 @@ class InferenceServer: ...@@ -77,10 +115,12 @@ class InferenceServer:
device: Device type ('cpu', 'cuda', 'mlu', 'moore'). device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32'). dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism. tensor_parallel_size: Number of devices for tensor parallelism.
cache_type: Cache type ('paged' or 'static').
max_tokens: Default maximum tokens to generate. max_tokens: Default maximum tokens to generate.
max_batch_size: Maximum batch size for inference. max_batch_size: Maximum batch size for inference (only for paged cache).
num_blocks: Number of KV cache blocks. num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block. block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
...@@ -94,10 +134,12 @@ class InferenceServer: ...@@ -94,10 +134,12 @@ class InferenceServer:
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.tensor_parallel_size = tensor_parallel_size self.tensor_parallel_size = tensor_parallel_size
self.cache_type = cache_type
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.block_size = block_size self.block_size = block_size
self.max_cache_len = max_cache_len
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
...@@ -124,10 +166,12 @@ class InferenceServer: ...@@ -124,10 +166,12 @@ class InferenceServer:
device=self.device, device=self.device,
dtype=self.dtype, dtype=self.dtype,
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
cache_type=self.cache_type,
max_batch_size=self.max_batch_size, max_batch_size=self.max_batch_size,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_size=self.block_size, block_size=self.block_size,
max_cache_len=self.max_cache_len,
temperature=self.temperature, temperature=self.temperature,
top_p=self.top_p, top_p=self.top_p,
top_k=self.top_k, top_k=self.top_k,
...@@ -396,12 +440,15 @@ class InferenceServer: ...@@ -396,12 +440,15 @@ class InferenceServer:
output_text = output_text.strip() output_text = output_text.strip()
finish_reason = self._convert_finish_reason(req.finish_reason) finish_reason = self._convert_finish_reason(req.finish_reason)
response = chunk_json( response = completion_json(
request_id, request_id,
content=output_text, content=output_text,
role="assistant", role="assistant",
finish_reason=finish_reason or "stop", finish_reason=finish_reason or "stop",
model=self.model_id, model=self.model_id,
prompt_tokens=req.get_prompt_length(),
completion_tokens=req.get_num_generated_tokens(),
total_tokens=req.get_total_length(),
) )
return response return response
...@@ -450,6 +497,13 @@ def parse_args(): ...@@ -450,6 +497,13 @@ def parse_args():
"--model_path", type=str, required=True, help="Path to model directory" "--model_path", type=str, required=True, help="Path to model directory"
) )
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree") parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree")
parser.add_argument(
"--cache_type",
type=str,
default="paged",
choices=["paged", "static"],
help="Cache type: paged or static",
)
parser.add_argument( parser.add_argument(
"--max_tokens", "--max_tokens",
type=int, type=int,
...@@ -457,13 +511,28 @@ def parse_args(): ...@@ -457,13 +511,28 @@ def parse_args():
help="Maximum number of tokens to generate", help="Maximum number of tokens to generate",
) )
parser.add_argument( parser.add_argument(
"--max_batch_size", type=int, default=8, help="Maximum batch size" "--max_batch_size",
type=int,
default=8,
help="Maximum batch size (paged cache only)",
) )
parser.add_argument( parser.add_argument(
"--num_blocks", type=int, default=8 * 1024, help="Number of blocks for KV cache" "--num_blocks",
type=int,
default=8 * 1024,
help="Number of blocks for KV cache (paged cache only)",
)
parser.add_argument(
"--block_size",
type=int,
default=16,
help="Block size for KV cache (paged cache only)",
) )
parser.add_argument( parser.add_argument(
"--block_size", type=int, default=16, help="Block size for KV cache" "--max_cache_len",
type=int,
default=4096,
help="Maximum sequence length (static cache only)",
) )
parser.add_argument( parser.add_argument(
"--dtype", "--dtype",
...@@ -483,6 +552,7 @@ def parse_args(): ...@@ -483,6 +552,7 @@ def parse_args():
parser.add_argument("--port", type=int, default=8000, help="Server port") parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--cpu", action="store_true", help="Use CPU") parser.add_argument("--cpu", action="store_true", help="Use CPU")
parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU") parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
parser.add_argument("--qy", action="store_true", help="Use QY GPU")
parser.add_argument("--metax", action="store_true", help="Use MetaX device") parser.add_argument("--metax", action="store_true", help="Use MetaX device")
parser.add_argument("--moore", action="store_true", help="Use Moore device") parser.add_argument("--moore", action="store_true", help="Use Moore device")
parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device") parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
...@@ -513,6 +583,8 @@ def main(): ...@@ -513,6 +583,8 @@ def main():
device = "cpu" device = "cpu"
elif args.nvidia: elif args.nvidia:
device = "cuda" device = "cuda"
elif args.qy:
device = "cuda"
elif args.metax: elif args.metax:
device = "cuda" device = "cuda"
elif args.moore: elif args.moore:
...@@ -525,7 +597,7 @@ def main(): ...@@ -525,7 +597,7 @@ def main():
device = "cuda" device = "cuda"
else: else:
print( print(
"Usage: python infinilm.server.inference_server [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon | --ali] " "Usage: python infinilm.server.inference_server [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali] "
"--model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE" "--model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE"
"\n" "\n"
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ " "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
...@@ -540,10 +612,12 @@ def main(): ...@@ -540,10 +612,12 @@ def main():
device=device, device=device,
dtype=args.dtype, dtype=args.dtype,
tensor_parallel_size=args.tp, tensor_parallel_size=args.tp,
cache_type=args.cache_type,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
num_blocks=args.num_blocks, num_blocks=args.num_blocks,
block_size=args.block_size, block_size=args.block_size,
max_cache_len=args.max_cache_len,
temperature=args.temperature, temperature=args.temperature,
top_p=args.top_p, top_p=args.top_p,
top_k=args.top_k, top_k=args.top_k,
......
...@@ -81,7 +81,6 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) { ...@@ -81,7 +81,6 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) {
__C void __C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) { loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
std::string name_str(name); std::string name_str(name);
// std::cout << "Loading weight: " << name_str << std::endl;
auto weights = reinterpret_cast<infinicore::weights::Loader *>(weights_); auto weights = reinterpret_cast<infinicore::weights::Loader *>(weights_);
weights->load(name_str, data); weights->load(name_str, data);
} }
...@@ -9,7 +9,6 @@ from infinilm.modeling_utils import load_model_state_dict_by_file ...@@ -9,7 +9,6 @@ from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig from infinilm.distributed import DistConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.infer_engine import GenerationConfig, InferEngine from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig
from datasets import load_dataset, Dataset from datasets import load_dataset, Dataset
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment