Unverified Commit a4ced800 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #205 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 96ecf490 04c37f3f
#pragma once #pragma once
#include "../cache/cache.hpp" #include "../cache/cache.hpp"
#include "../config/model_config.hpp"
#include "../models/model_factory.hpp" #include "../models/model_factory.hpp"
#include "compiler/general_compiler.hpp"
#include "distributed/distributed.hpp" #include "distributed/distributed.hpp"
#include "rank_barrier.hpp"
#include <any> #include <any>
#include <condition_variable> #include <condition_variable>
...@@ -20,6 +23,7 @@ class RankWorker { ...@@ -20,6 +23,7 @@ class RankWorker {
LOAD, LOAD,
RUN, RUN,
RESET_CACHE, RESET_CACHE,
COMPILE,
STOP STOP
}; };
...@@ -55,7 +59,15 @@ public: ...@@ -55,7 +59,15 @@ public:
RankWorker(const InfinilmModel::Config &model_config, RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info, const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config); const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
// Submit a parameter load job and wait until the load completes on the worker thread. // Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name, void load_param(const std::string &name,
...@@ -70,6 +82,9 @@ public: ...@@ -70,6 +82,9 @@ public:
// Reset the internal cache with a new configuration // Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config); void reset_cache(const cache::CacheConfig *new_config);
// Compile the model graph if enabled.
void compile();
// Wait until run job completes. The result can be retrieved with get_output(). // Wait until run job completes. The result can be retrieved with get_output().
void wait(); void wait();
...@@ -86,11 +101,16 @@ private: ...@@ -86,11 +101,16 @@ private:
private: private:
// Worker properties // Worker properties
const InfinilmModel::Config &model_config_; const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
distributed::RankInfo rank_info_; distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_; std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::Cache> cache_; std::shared_ptr<cache::Cache> cache_;
// Graph Compiling
bool enable_graph_compiling_;
std::unique_ptr<GraphCompiler> compiler_;
// Command for the pending job (protected by mutex_) // Command for the pending job (protected by mutex_)
Command job_cmd_; Command job_cmd_;
...@@ -116,6 +136,8 @@ private: ...@@ -116,6 +136,8 @@ private:
// Random // Random
std::mt19937 rng_; std::mt19937 rng_;
RankBarrier *barrier_;
}; };
} // namespace infinilm::engine } // namespace infinilm::engine
...@@ -6,6 +6,18 @@ namespace infinilm::layers { ...@@ -6,6 +6,18 @@ namespace infinilm::layers {
// --------------------------------------------------------- // ---------------------------------------------------------
// QKV Parallel Linear // QKV Parallel Linear
// --------------------------------------------------------- // ---------------------------------------------------------
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
QKVParallelLinear::QKVParallelLinear(size_t hidden_size, QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
size_t head_dim, size_t head_dim,
size_t num_q_head, size_t num_q_head,
...@@ -57,6 +69,61 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size, ...@@ -57,6 +69,61 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
v_out_size_ = num_v_head_ * v_dim_ / tp_size_; v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
} }
// Convenience quantized constructor: all Q/K/V heads share one head_dim and a
// single bias flag. Simply forwards to the fully general overload with the
// shared values fanned out per projection.
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t head_dim,
                                     size_t num_q_head,
                                     size_t num_kv_head,
                                     std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                     bool bias,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : QKVParallelLinear(hidden_size,
                        /*q_dim=*/head_dim, /*k_dim=*/head_dim, /*v_dim=*/head_dim,
                        /*num_q_head=*/num_q_head, /*num_k_head=*/num_kv_head, /*num_v_head=*/num_kv_head,
                        /*q_bias=*/bias, /*k_bias=*/bias, /*v_bias=*/bias,
                        quantization,
                        dtype, device, rank_info) {}
// Fully general quantized QKV projection. Implemented as ONE fused
// ColumnParallelLinear whose output columns are laid out [Q | K | V]; each
// tensor-parallel rank owns a 1/tp_size slice of every head group. The
// per-projection getters below recover the Q/K/V sub-tensors by narrowing.
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                     size_t q_dim, size_t k_dim, size_t v_dim,
                                     size_t num_q_head, size_t num_k_head, size_t num_v_head,
                                     bool q_bias, bool k_bias, bool v_bias,
                                     std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                     const infinicore::DataType &dtype,
                                     const infinicore::Device &device,
                                     engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(
          hidden_size,
          // Fused output width: the three projections concatenated.
          num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
          quantization,
          // The fused layer carries a bias if any projection does; mixed
          // flags are rejected below, so in practice all three agree.
          (q_bias || k_bias || v_bias),
          dtype,
          device,
          rank_info.tp_rank,
          rank_info.tp_size),
      q_dim_(q_dim),
      k_dim_(k_dim),
      v_dim_(v_dim),
      num_q_head_(num_q_head),
      num_k_head_(num_k_head),
      num_v_head_(num_v_head),
      q_bias_(q_bias),
      k_bias_(k_bias),
      v_bias_(v_bias) {
    // Every head group must shard evenly across tensor-parallel ranks.
    if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
        throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
    }
    // The fused layout stores a single bias vector, so per-projection bias
    // flags cannot differ.
    if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
        throw std::runtime_error("q_bias, k_bias, v_bias must all match");
    }
    // Per-rank output widths of each projection's slice of the fused matrix.
    q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
    k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
    v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
}
std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor> std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
QKVParallelLinear::forward_split(infinicore::Tensor &input) { QKVParallelLinear::forward_split(infinicore::Tensor &input) {
auto output = this->forward(input); auto output = this->forward(input);
...@@ -86,6 +153,40 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const { ...@@ -86,6 +153,40 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const {
0, tp_rank_, tp_size_); 0, tp_rank_, tp_size_);
} }
// Slice of the fused weight_scale_ tensor covering the Q projection rows.
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale() const {
    auto q_slice = weight_scale_->narrow({{0, 0, q_out_size_}});
    return infinicore::nn::Parameter(q_slice, 0, tp_rank_, tp_size_);
}
// Slice of the fused weight_scale_ tensor covering the K projection rows,
// which sit immediately after the Q rows.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale() const {
    auto k_slice = weight_scale_->narrow({{0, q_out_size_, k_out_size_}});
    return infinicore::nn::Parameter(k_slice, 0, tp_rank_, tp_size_);
}
// Slice of the fused weight_scale_ tensor covering the V projection rows,
// which follow the Q and K rows.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
    auto v_slice = weight_scale_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}});
    return infinicore::nn::Parameter(v_slice, 0, tp_rank_, tp_size_);
}
// Slice of the fused weight_zeros_ tensor covering the Q projection rows.
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
    auto q_slice = weight_zeros_->narrow({{0, 0, q_out_size_}});
    return infinicore::nn::Parameter(q_slice, 0, tp_rank_, tp_size_);
}
// Slice of the fused weight_zeros_ tensor covering the K projection rows.
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros() const {
    auto k_slice = weight_zeros_->narrow({{0, q_out_size_, k_out_size_}});
    return infinicore::nn::Parameter(k_slice, 0, tp_rank_, tp_size_);
}
// Slice of the fused weight_zeros_ tensor covering the V projection rows.
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros() const {
    auto v_slice = weight_zeros_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}});
    return infinicore::nn::Parameter(v_slice, 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const { infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const {
if (!q_bias_) { if (!q_bias_) {
return infinicore::nn::Parameter(); return infinicore::nn::Parameter();
...@@ -120,6 +221,18 @@ bool QKVParallelLinear::has_v_bias() const { return v_bias_; } ...@@ -120,6 +221,18 @@ bool QKVParallelLinear::has_v_bias() const { return v_bias_; }
// --------------------------------------------------------- // ---------------------------------------------------------
// Gate-Up Parallel Linear // Gate-Up Parallel Linear
// --------------------------------------------------------- // ---------------------------------------------------------
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias, GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias,
const infinicore::DataType &dtype, const infinicore::Device &device, const infinicore::DataType &dtype, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
...@@ -135,6 +248,22 @@ GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermedia ...@@ -135,6 +248,22 @@ GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermedia
} }
} }
// Convenience quantized constructor: one `bias` flag applies to both the gate
// and the up projection. Forwards to the per-projection-bias overload.
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size,
                                           std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                           bool bias,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : GateUpParallelLinear(hidden_size, intermediate_size,
                           /*gate_bias=*/bias, /*up_bias=*/bias,
                           quantization, dtype, device, rank_info) {}
// Fused gate+up projection: a single ColumnParallelLinear of output width
// 2 * intermediate_size with columns laid out [gate | up]. The per-projection
// getters recover each half by narrowing.
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                                           std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
                                           const infinicore::DataType &dtype, const infinicore::Device &device,
                                           engine::distributed::RankInfo rank_info)
    : infinicore::nn::ColumnParallelLinear(hidden_size,
                                           intermediate_size * 2,
                                           quantization,
                                           gate_bias || up_bias,
                                           dtype,
                                           device,
                                           rank_info.tp_rank,
                                           rank_info.tp_size),
      gate_bias_(gate_bias),
      up_bias_(up_bias) {
    // The fused layout stores a single bias vector, so mixed flags are rejected.
    if (gate_bias_ != up_bias_) {
        throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
    }
}
std::tuple<infinicore::Tensor, infinicore::Tensor> GateUpParallelLinear::forward_split(infinicore::Tensor &input) { std::tuple<infinicore::Tensor, infinicore::Tensor> GateUpParallelLinear::forward_split(infinicore::Tensor &input) {
auto output = this->forward(input); auto output = this->forward(input);
auto cols = output->shape()[2]; auto cols = output->shape()[2];
...@@ -168,6 +297,22 @@ infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const { ...@@ -168,6 +297,22 @@ infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const {
} }
} }
// First half of the fused weight_scale_ tensor: the gate projection rows.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale() const {
    const auto half = weight_scale_->size(0) / 2;
    return infinicore::nn::Parameter(weight_scale_->narrow({{0, 0, half}}), 0, tp_rank_, tp_size_);
}
// Second half of the fused weight_scale_ tensor: the up projection rows.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale() const {
    const auto half = weight_scale_->size(0) / 2;
    return infinicore::nn::Parameter(weight_scale_->narrow({{0, half, half}}), 0, tp_rank_, tp_size_);
}
// First half of the fused weight_zeros_ tensor: the gate projection rows.
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros() const {
    const auto half = weight_zeros_->size(0) / 2;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{0, 0, half}}), 0, tp_rank_, tp_size_);
}
// Second half of the fused weight_zeros_ tensor: the up projection rows.
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros() const {
    const auto half = weight_zeros_->size(0) / 2;
    return infinicore::nn::Parameter(weight_zeros_->narrow({{0, half, half}}), 0, tp_rank_, tp_size_);
}
bool GateUpParallelLinear::has_gate_bias() const { bool GateUpParallelLinear::has_gate_bias() const {
return gate_bias_; return gate_bias_;
} }
......
#pragma once #pragma once
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/quantization.hpp"
#include "../engine/distributed/communication_group.hpp" #include "../engine/distributed/communication_group.hpp"
...@@ -23,6 +24,25 @@ public: ...@@ -23,6 +24,25 @@ public:
const infinicore::Device &device = infinicore::Device(), const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
explicit QKVParallelLinear(size_t hidden_size,
size_t q_dim, size_t k_dim, size_t v_dim,
size_t num_q_head, size_t num_k_head, size_t num_v_head,
bool q_bias, bool k_bias, bool v_bias,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
// A more common case where all heads have the same dimension
explicit QKVParallelLinear(size_t hidden_size,
size_t head_dim,
size_t num_q_head, size_t num_kv_head,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
bool bias = false,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor> std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
forward_split(infinicore::Tensor &input); forward_split(infinicore::Tensor &input);
...@@ -30,6 +50,14 @@ public: ...@@ -30,6 +50,14 @@ public:
infinicore::nn::Parameter get_k_weight() const; infinicore::nn::Parameter get_k_weight() const;
infinicore::nn::Parameter get_v_weight() const; infinicore::nn::Parameter get_v_weight() const;
infinicore::nn::Parameter get_q_weight_scale() const;
infinicore::nn::Parameter get_k_weight_scale() const;
infinicore::nn::Parameter get_v_weight_scale() const;
infinicore::nn::Parameter get_q_weight_zeros() const;
infinicore::nn::Parameter get_k_weight_zeros() const;
infinicore::nn::Parameter get_v_weight_zeros() const;
infinicore::nn::Parameter get_q_bias() const; infinicore::nn::Parameter get_q_bias() const;
infinicore::nn::Parameter get_k_bias() const; infinicore::nn::Parameter get_k_bias() const;
infinicore::nn::Parameter get_v_bias() const; infinicore::nn::Parameter get_v_bias() const;
...@@ -55,6 +83,18 @@ private: ...@@ -55,6 +83,18 @@ private:
class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
public: public:
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false, GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false,
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
...@@ -63,14 +103,33 @@ public: ...@@ -63,14 +103,33 @@ public:
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
bool bias = false,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
std::tuple<infinicore::Tensor, infinicore::Tensor> forward_split(infinicore::Tensor &input); std::tuple<infinicore::Tensor, infinicore::Tensor> forward_split(infinicore::Tensor &input);
infinicore::nn::Parameter get_gate_weight() const; infinicore::nn::Parameter get_gate_weight() const;
infinicore::nn::Parameter get_gate_weight_scale() const;
infinicore::nn::Parameter get_gate_weight_zeros() const;
infinicore::nn::Parameter get_gate_bias() const; infinicore::nn::Parameter get_gate_bias() const;
infinicore::nn::Parameter get_up_weight() const; infinicore::nn::Parameter get_up_weight() const;
infinicore::nn::Parameter get_up_weight_scale() const;
infinicore::nn::Parameter get_up_weight_zeros() const;
infinicore::nn::Parameter get_up_bias() const; infinicore::nn::Parameter get_up_bias() const;
bool has_gate_bias() const; bool has_gate_bias() const;
...@@ -103,4 +162,62 @@ private: ...@@ -103,4 +162,62 @@ private:
if (name##_->has_up_bias()) \ if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias()); this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
// ========================= QKV Quantization ==================================
// Build a QKVParallelLinear for the W8A8 scheme and register, per projection,
// the ".weight" and ".weight_scale" tensors (plus ".bias" when present) under
// checkpoint-style names derived from q_name/k_name/v_name.
// NOTE: comments cannot appear inside the macro body (line splicing precedes
// comment removal), hence this header comment only.
#define INFINILM_QKV_LINEAR_W8A8_INIT(name, q_name, k_name, v_name, ...) \
    name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
    this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \
    this->register_parameter(std::string(q_name) + ".weight_scale", name##_->get_q_weight_scale()); \
    this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \
    this->register_parameter(std::string(k_name) + ".weight_scale", name##_->get_k_weight_scale()); \
    this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \
    this->register_parameter(std::string(v_name) + ".weight_scale", name##_->get_v_weight_scale()); \
    if (name##_->has_q_bias()) \
        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
    if (name##_->has_k_bias()) \
        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
    if (name##_->has_v_bias()) \
        this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
// Build a QKVParallelLinear for the W4A16 AWQ scheme and register, per
// projection, the ".qweight", ".qzeros" and ".scales" tensors (plus ".bias"
// when present) — the tensor names used by AWQ-format checkpoints.
#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
    name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
    this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight()); \
    this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros()); \
    this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale()); \
    this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight()); \
    this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros()); \
    this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale()); \
    this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight()); \
    this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros()); \
    this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale()); \
    if (name##_->has_q_bias()) \
        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
    if (name##_->has_k_bias()) \
        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
    if (name##_->has_v_bias()) \
        this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
// ========================= Gate-Up Quantization ==============================
// Build a GateUpParallelLinear for the W8A8 scheme and register the gate/up
// ".weight" and ".weight_scale" tensors (plus ".bias" when present).
#define INFINILM_GATE_UP_LINEAR_W8A8_INIT(name, gate_name, up_name, ...) \
    name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
    this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \
    this->register_parameter(std::string(gate_name) + ".weight_scale", name##_->get_gate_weight_scale()); \
    this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \
    this->register_parameter(std::string(up_name) + ".weight_scale", name##_->get_up_weight_scale()); \
    if (name##_->has_gate_bias()) \
        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
    if (name##_->has_up_bias()) \
        this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
// Build a GateUpParallelLinear for the W4A16 AWQ scheme and register the
// gate/up ".qweight", ".scales" and ".qzeros" tensors (plus ".bias" when
// present) — the tensor names used by AWQ-format checkpoints.
#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
    name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
    this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight()); \
    this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale()); \
    this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros()); \
    this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight()); \
    this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale()); \
    this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros()); \
    if (name##_->has_gate_bias()) \
        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
    if (name##_->has_up_bias()) \
        this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
} // namespace infinilm::layers } // namespace infinilm::layers
#pragma once #pragma once
#include "infinicore/nn/module.hpp"
#include "../cache/cache.hpp" #include "../cache/cache.hpp"
#include "infinicore/nn/module.hpp"
#include "nlohmann/json.hpp"
#include <any> #include <any>
...@@ -13,7 +13,6 @@ class InfinilmModel : public infinicore::nn::Module { ...@@ -13,7 +13,6 @@ class InfinilmModel : public infinicore::nn::Module {
public: public:
struct Config { struct Config {
std::string model_type; std::string model_type;
virtual ~Config() = default; virtual ~Config() = default;
}; };
...@@ -43,5 +42,6 @@ public: ...@@ -43,5 +42,6 @@ public:
virtual Output forward(const Input &input) const = 0; virtual Output forward(const Input &input) const = 0;
virtual void reset_cache(const cache::CacheConfig *cache_config) = 0; virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
virtual const cache::CacheConfig *get_cache_config() const = 0;
}; };
} // namespace infinilm } // namespace infinilm
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
* - LlamaForCausalLM: Complete model with language modeling head * - LlamaForCausalLM: Complete model with language modeling head
*/ */
#include "llama_config.hpp" #include "../../config/model_config.hpp"
#include "llama_attention.hpp" #include "llama_attention.hpp"
#include "llama_mlp.hpp"
#include "llama_decoder_layer.hpp" #include "llama_decoder_layer.hpp"
#include "llama_model.hpp"
#include "llama_for_causal_lm.hpp" #include "llama_for_causal_lm.hpp"
#include "llama_mlp.hpp"
#include "llama_model.hpp"
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <iostream>
#include <optional> #include <optional>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
...@@ -17,6 +16,18 @@ ...@@ -17,6 +16,18 @@
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaAttention::LlamaAttention(const LlamaConfig &config, LlamaAttention::LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
...@@ -61,6 +72,65 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -61,6 +72,65 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
} }
} }
// Config-driven attention constructor: reads hyper-parameters from the
// generic ModelConfig, shards head counts across tensor-parallel ranks, and
// builds the fused QKV projection plus output projection, choosing which
// parameter names get registered based on the config's quantization scheme.
LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                               const infinicore::Device &device,
                               size_t layer_idx,
                               engine::distributed::RankInfo rank_info)
    : model_config_(model_config),
      layer_idx_(layer_idx),
      hidden_size_(model_config->get<size_t>("hidden_size")),
      // NOTE: the two head-count members below are overwritten in the body
      // with the per-rank (tp-sliced) counts.
      num_attention_heads_(model_config->get<size_t>("num_attention_heads")),
      num_key_value_heads_(model_config->get<size_t>("num_key_value_heads")),
      head_dim_(model_config->get_head_dim()),
      kv_dim_(model_config->get_kv_dim()),
      // NOTE(review): attention_bias defaults to true here; verify this is
      // intended — configs that omit the key will get biased projections.
      use_bias_(model_config->get_or<bool>("attention_bias", true)),
      use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
      max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
      rank_info_(rank_info) {
    const auto &dtype{model_config_->get_dtype()};
    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;
    int num_attention_heads = model_config_->get<size_t>("num_attention_heads");
    int num_key_value_heads = model_config_->get<size_t>("num_key_value_heads");

    // Shard heads across tensor-parallel ranks: each rank owns 1/tp_size of
    // the heads. NOTE(review): only num_key_value_heads divisibility is
    // checked; num_attention_heads % tp_size is assumed divisible, and the
    // error message does not reflect the actual condition tested.
    if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
        this->num_attention_heads_ = num_attention_heads / tp_size;
        this->num_key_value_heads_ = num_key_value_heads / tp_size;
    } else {
        throw std::runtime_error("num_attention_heads / tp_size error.");
    }

    // Standard scaled-dot-product attention scaling factor 1/sqrt(head_dim).
    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

    // Select the registration macro matching the quantization scheme: W8A8
    // registers weight/weight_scale pairs, AWQ registers qweight/qzeros/scales,
    // and the default path registers plain weights. The head counts passed to
    // the macros are the global (pre-shard) counts; the layer shards internally.
    auto quant_scheme = this->model_config_->get_quant_scheme();
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                      dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    case infinicore::quantization::QuantScheme::AWQ_W4A16:
        INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                          dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    default:
        INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
                                 dtype, device, rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    }

    // qwen3 adds per-head Q/K normalization layers (built with rms_norm_eps,
    // so presumably RMSNorm — confirm against the norm module's definition).
    if (model_config_->get<std::string>("model_type") == "qwen3") {
        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
    }
}
infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
...@@ -75,7 +145,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -75,7 +145,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
// 1. Project Q, K, V // 1. Project Q, K, V
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
if (use_qk_norm_) { if (use_qk_norm_ || model_config_->get_or<std::string>("model_type", "None") == "qwen3") {
q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
} }
...@@ -112,8 +182,8 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -112,8 +182,8 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor k_total; // [bs, n_kv_head, max_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim] infinicore::Tensor v_total; // [bs, n_kv_head, max_seq_len, head_dim]
if (kv_cache == nullptr) { if (kv_cache == nullptr) {
k_total = k_permuted; k_total = k_permuted;
v_total = v_permuted; v_total = v_permuted;
...@@ -124,7 +194,18 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -124,7 +194,18 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
} else { } else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type"); throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
} }
auto total_seq_len = k_total->shape()[2];
infinicore::Tensor attn_output;
if (false) {
// experimental nineoothed flash attention
attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
attn_output = attn_output->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
// 6. Compute attention // 6. Compute attention
size_t ngroup = num_attention_heads_ / num_key_value_heads_; size_t ngroup = num_attention_heads_ / num_key_value_heads_;
...@@ -141,10 +222,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ...@@ -141,10 +222,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim] auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_}) attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
->permute({0, 2, 1, 3}) ->permute({0, 2, 1, 3})
->contiguous() ->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
}
auto output = o_proj_->forward(attn_output); auto output = o_proj_->forward(attn_output);
...@@ -184,7 +266,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd ...@@ -184,7 +266,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_}); auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_}); auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
if (use_qk_norm_) { if (use_qk_norm_ || model_config_->get_or<std::string>("model_type", "None") == "qwen3") {
q_reshaped = q_norm_->forward(q_reshaped); q_reshaped = q_norm_->forward(q_reshaped);
k_reshaped = k_norm_->forward(k_reshaped); k_reshaped = k_norm_->forward(k_reshaped);
} }
......
#pragma once #pragma once
#include "../../cache/kv_cache.hpp" #include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp"
#include "../../engine/distributed/distributed.hpp" #include "../../engine/distributed/distributed.hpp"
#include "../../layers/fused_linear.hpp" #include "../../layers/fused_linear.hpp"
#include "llama_config.hpp" #include "llama_config.hpp"
...@@ -36,11 +37,28 @@ public: ...@@ -36,11 +37,28 @@ public:
* @param layer_idx Layer index for cache access * @param layer_idx Layer index for cache access
* @param dtype Optional data type for model parameters (defaults to F32) * @param dtype Optional data type for model parameters (defaults to F32)
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaAttention(const LlamaConfig &config, LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: compute attention * @brief Forward pass: compute attention
* *
...@@ -101,6 +119,7 @@ protected: ...@@ -101,6 +119,7 @@ protected:
std::shared_ptr<infinicore::nn::RoPE> rotary_emb_; std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
private: private:
std::shared_ptr<infinilm::config::ModelConfig> model_config_ = std::make_shared<infinilm::config::ModelConfig>();
size_t layer_idx_; // Layer index for cache access size_t layer_idx_; // Layer index for cache access
size_t hidden_size_; size_t hidden_size_;
size_t num_attention_heads_; size_t num_attention_heads_;
...@@ -109,7 +128,7 @@ private: ...@@ -109,7 +128,7 @@ private:
size_t kv_dim_; size_t kv_dim_;
bool use_bias_; // Bias for Q/K/V projections bool use_bias_; // Bias for Q/K/V projections
bool use_output_bias_; // Bias for output projection (o_proj) bool use_output_bias_; // Bias for output projection (o_proj)
bool use_qk_norm_; // Whether to use QK RMSNorm bool use_qk_norm_ = false; // Whether to use QK RMSNorm
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility) size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
float scaling_; float scaling_;
......
#include "llama_decoder_layer.hpp" #include "llama_decoder_layer.hpp"
#include "infinicore/nn/rmsnorm.hpp" #include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include <optional> #include <optional>
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
...@@ -23,7 +33,25 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, ...@@ -23,7 +33,25 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_); INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
} }
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
// Initialize layer normalization layers
INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
dtype, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
dtype, device);
// Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_);
INFINICORE_NN_MODULE_INIT(mlp, model_config_, device, rank_info_);
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
infinicore::Tensor &residual,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
...@@ -31,30 +59,19 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s ...@@ -31,30 +59,19 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
std::optional<infinicore::Tensor> input_offsets, std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables, std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const { std::optional<infinicore::Tensor> slot_mapping) const {
// Save residual for attention // 1. Attention layer normalization
auto residual = hidden_states; input_layernorm_->forward_inplace(hidden_states, residual);
// 1. Pre-attention layer normalization // 2. Self-attention
auto normed_states = input_layernorm_->forward(hidden_states); hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
// Save residual for MLP
residual = output;
// 3. Post-attention layer normalization // 3. Post-attention layer normalization
normed_states = post_attention_layernorm_->forward(output); post_attention_layernorm_->forward_inplace(hidden_states, residual);
// 4. MLP with residual connection
auto mlp_output = mlp_->forward(normed_states);
// Add residual: output = output + mlp_output // 4. MLP
output = infinicore::op::add(residual, mlp_output); hidden_states = mlp_->forward(hidden_states);
return output; return std::make_tuple(hidden_states, residual);
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -33,20 +33,41 @@ public: ...@@ -33,20 +33,41 @@ public:
* @param layer_idx Layer index for cache management and debugging * @param layer_idx Layer index for cache management and debugging
* @param dtype Optional data type for model parameters (defaults to F32) * @param dtype Optional data type for model parameters (defaults to F32)
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaDecoderLayer(const LlamaConfig &config, LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: process one decoder layer * @brief Forward pass: process one decoder layer
* *
* @param hidden_states Input tensor of shape [batch, seq_len, hidden_size] * @param hidden_states [batch, seq_len, hidden_size], will be modified
* @param residual [batch, seq_len, hidden_size], will be modified
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional KV cache for incremental decoding * @param kv_cache Optional KV cache for incremental decoding
* @return Output tensor of shape [batch, seq_len, hidden_size] * @return Output tensor of shape [batch, seq_len, hidden_size]
* Updated residual tensor of shape [batch, seq_len, hidden_size]
*/ */
infinicore::Tensor forward(const infinicore::Tensor &hidden_states, std::tuple<infinicore::Tensor, infinicore::Tensor>
forward(infinicore::Tensor &hidden_states,
infinicore::Tensor &residual,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache, std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
...@@ -75,6 +96,7 @@ protected: ...@@ -75,6 +96,7 @@ protected:
INFINICORE_NN_MODULE(LlamaAttention, self_attn); INFINICORE_NN_MODULE(LlamaAttention, self_attn);
INFINICORE_NN_MODULE(LlamaMLP, mlp); INFINICORE_NN_MODULE(LlamaMLP, mlp);
engine::distributed::RankInfo rank_info_; engine::distributed::RankInfo rank_info_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
private: private:
size_t layer_idx_; // Layer index for cache management and debugging size_t layer_idx_; // Layer index for cache management and debugging
......
...@@ -2,19 +2,26 @@ ...@@ -2,19 +2,26 @@
#include "infinicore/context/context.hpp" #include "infinicore/context/context.hpp"
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include <iostream>
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) { engine::distributed::RankInfo rank_info) {
// Initialize module's device_ member // Initialize module's device_ member
device_ = device; device_ = device;
const auto &dtype{config.dtype}; const auto &dtype{config.dtype};
// Initialize base model // Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, rank_info); INFINICORE_NN_MODULE_INIT(model, config, device, rank_info);
...@@ -25,6 +32,24 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, ...@@ -25,6 +32,24 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
dtype, device); dtype, device);
} }
LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                                   const infinicore::Device &device,
                                   engine::distributed::RankInfo rank_info) {
    // Record the device this module's parameters are created on.
    device_ = device;

    const auto &dtype{model_config->get_dtype()};
    const auto hidden_size = model_config->get<size_t>("hidden_size");
    const auto vocab_size = model_config->get<size_t>("vocab_size");

    // Initialize the base Llama model.
    INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info);

    // Initialize the language modeling head (hidden_size -> vocab_size, no bias).
    // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
    // For now, we create a separate linear layer
    INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false,
                              dtype, device);
}
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto input_ids = input.input_ids.value(); auto input_ids = input.input_ids.value();
auto position_ids = input.position_ids.value(); auto position_ids = input.position_ids.value();
...@@ -40,12 +65,16 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { ...@@ -40,12 +65,16 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
// 2. Apply language modeling head to get logits // 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states); auto logits = lm_head_->forward(hidden_states);
return {logits}; return {logits};
} }
void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
model_->reset_cache(cache_config); cache_config_ = cache_config->unique_copy();
model_->reset_cache(cache_config_.get());
}
const cache::CacheConfig *LlamaForCausalLM::get_cache_config() const {
return cache_config_.get();
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -28,10 +28,26 @@ public: ...@@ -28,10 +28,26 @@ public:
* @param config Model configuration * @param config Model configuration
* @param device Device to create tensors on * @param device Device to create tensors on
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: compute language modeling logits * @brief Forward pass: compute language modeling logits
* *
...@@ -42,8 +58,9 @@ public: ...@@ -42,8 +58,9 @@ public:
void reset_cache(const cache::CacheConfig *cache_config) override; void reset_cache(const cache::CacheConfig *cache_config) override;
const cache::CacheConfig *get_cache_config() const override;
// Module information // Module information
const LlamaConfig &config() const { return model_->config(); }
LlamaModel &model() { return *model_; } LlamaModel &model() { return *model_; }
const LlamaModel &model() const { return *model_; } const LlamaModel &model() const { return *model_; }
...@@ -53,6 +70,8 @@ protected: ...@@ -53,6 +70,8 @@ protected:
// Language modeling head // Language modeling head
INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head); INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
std::unique_ptr<cache::CacheConfig> cache_config_;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -3,7 +3,18 @@ ...@@ -3,7 +3,18 @@
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaMLP::LlamaMLP(const LlamaConfig &config, LlamaMLP::LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
...@@ -22,6 +33,43 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config, ...@@ -22,6 +33,43 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config,
dtype, device, tp_rank, tp_size, rank_info.comm); dtype, device, tp_rank, tp_size, rank_info.comm);
} }
// Constructs the MLP block from a generic ModelConfig.
// Reads "hidden_size", "intermediate_size" and the optional "mlp_bias" flag
// (defaulting to false) from the config, then initializes the fused
// gate/up projection and the down projection, selecting the linear-layer
// variant that matches the config's quantization scheme.
LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                   const infinicore::Device &device,
                   engine::distributed::RankInfo rank_info)
    : model_config_(model_config), hidden_size_(model_config->get<size_t>("hidden_size")),
      intermediate_size_(model_config->get<size_t>("intermediate_size")),
      use_bias_(model_config->get_or<bool>("mlp_bias", false)), rank_info_(rank_info) {
    const auto &dtype{model_config_->get_dtype()};
    // Tensor-parallel coordinates, forwarded to down_proj so it can be
    // split across ranks — presumably row-parallel; confirm in layer impl.
    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    // Initialize projection layers
    // Only the gate_up_proj init macro differs per quantization scheme;
    // the down_proj initialization is identical in all three branches.
    auto quant_scheme = this->model_config_->get_quant_scheme();
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        // Compressed-tensor W8A8 (int8) quantized projections.
        INFINILM_GATE_UP_LINEAR_W8A8_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                          dtype, device, rank_info_);
        INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    case infinicore::quantization::QuantScheme::AWQ_W4A16:
        // AWQ W4A16 (4-bit weight, 16-bit activation) quantized projections.
        INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                              dtype, device, rank_info_);
        INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    default:
        // Unquantized path.
        INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                     dtype, device, rank_info_);
        INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
                                  dtype, device, tp_rank, tp_size, rank_info.comm);
        break;
    }
}
infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const { infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
// 1. Project to gate and up // 1. Project to gate and up
auto hidden_states_mutable = hidden_states; auto hidden_states_mutable = hidden_states;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../layers/fused_linear.hpp" #include "../../layers/fused_linear.hpp"
#include "llama_config.hpp" #include "llama_config.hpp"
#include "../../config/model_config.hpp"
#include "infinicore/device.hpp" #include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
...@@ -33,10 +34,26 @@ public: ...@@ -33,10 +34,26 @@ public:
* @param device Device to create tensors on * @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32) * @param dtype Optional data type for model parameters (defaults to F32)
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaMLP(const LlamaConfig &config, LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: compute MLP output * @brief Forward pass: compute MLP output
* *
...@@ -57,6 +74,8 @@ protected: ...@@ -57,6 +74,8 @@ protected:
size_t hidden_size_; size_t hidden_size_;
size_t intermediate_size_; size_t intermediate_size_;
bool use_bias_; bool use_bias_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -6,7 +6,18 @@ ...@@ -6,7 +6,18 @@
#include <iostream> #include <iostream>
namespace infinilm::models::llama { namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaModel::LlamaModel(const LlamaConfig &config, LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
...@@ -43,6 +54,39 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -43,6 +54,39 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
} }
} }
LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                       const infinicore::Device &device,
                       engine::distributed::RankInfo rank_info)
    : model_config_(model_config), rank_info_(rank_info) {
    const auto &dtype{model_config_->get_dtype()};
    const auto hidden_size = model_config_->get<size_t>("hidden_size");
    const auto num_layers = model_config_->get<size_t>("num_hidden_layers");

    // Token embedding table (vocab_size x hidden_size).
    INFINICORE_NN_MODULE_INIT(embed_tokens, model_config_->get<size_t>("vocab_size"), hidden_size,
                              std::nullopt, dtype, device);

    // Decoder stack.
    // TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments
    // (e.g., via a factory function or lambda that receives the layer index)
    // Currently, we can't use the macro because each layer needs a different layer_idx,
    // so the layers are registered one by one here.
    layers_.reserve(num_layers);
    for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
        layers_.push_back(this->register_module<LlamaDecoderLayer>(
            "layers." + std::to_string(layer_idx), model_config_, device, layer_idx, rank_info));
    }

    // Final layer normalization applied after the decoder stack.
    INFINICORE_NN_MODULE_INIT(norm, hidden_size, model_config_->get<double>("rms_norm_eps"),
                              dtype, device);

    // Rotary position embeddings, created once and shared across all layers.
    // Uses GPT-J-style inverse frequencies (default) with GPT_NEOX rotation pairing.
    INFINICORE_NN_MODULE_INIT(rotary_emb, model_config_->get_head_dim(), model_config_->get<size_t>("max_position_embeddings"),
                              model_config_->get<double>("rope_theta"), infinicore::nn::RoPE::Algo::GPT_NEOX,
                              dtype, device, model_config_->get_rope_scaling());
    for (auto &decoder_layer : layers_) {
        if (decoder_layer) {
            decoder_layer->set_rotary_emb(rotary_emb_);
        }
    }
}
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids, const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> past_sequence_lengths, std::optional<infinicore::Tensor> past_sequence_lengths,
...@@ -55,11 +99,23 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, ...@@ -55,11 +99,23 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers // 2. Process through all decoder layers
size_t num_layers = layers_.size(); size_t num_layers = layers_.size();
infinicore::Tensor residual;
for (size_t i = 0; i < num_layers; ++i) { for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); layers_.at(i)->forward(
hidden_states,
residual,
position_ids,
kv_cache_,
past_sequence_lengths,
total_sequence_lengths,
input_offsets,
block_tables,
slot_mapping);
} }
return norm_->forward(hidden_states); norm_->forward_inplace(hidden_states, residual);
return hidden_states;
} }
void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
...@@ -67,7 +123,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { ...@@ -67,7 +123,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
kv_cache_ = nullptr; kv_cache_ = nullptr;
return; return;
} }
if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) { if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config);
kv_cache_config && model_config_ == nullptr) {
kv_cache_ = std::make_shared<cache::StaticKVCache>( kv_cache_ = std::make_shared<cache::StaticKVCache>(
config_.head_dim, config_.head_dim,
config_.head_dim, config_.head_dim,
...@@ -78,8 +135,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { ...@@ -78,8 +135,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
config_.dtype, config_.dtype,
*kv_cache_config, *kv_cache_config,
rank_info_); rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) { paged_kv_cache_config && model_config_ == nullptr) {
kv_cache_ = std::make_shared<cache::PagedKVCache>( kv_cache_ = std::make_shared<cache::PagedKVCache>(
config_.head_dim, config_.head_dim,
config_.head_dim, config_.head_dim,
...@@ -89,6 +146,27 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { ...@@ -89,6 +146,27 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
config_.dtype, config_.dtype,
*paged_kv_cache_config, *paged_kv_cache_config,
rank_info_); rank_info_);
} else if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::StaticKVCache>(
model_config_->get_head_dim(),
model_config_->get_head_dim(),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_hidden_layers"),
model_config_->get<size_t>("max_position_embeddings"),
model_config_->get_dtype(),
*kv_cache_config,
rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::PagedKVCache>(
model_config_->get_head_dim(),
model_config_->get_head_dim(),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_hidden_layers"),
model_config_->get_dtype(),
*paged_kv_cache_config,
rank_info_);
} else { } else {
throw std::runtime_error("Unsupported cache type"); throw std::runtime_error("Unsupported cache type");
} }
......
#pragma once #pragma once
#include "../../cache/kv_cache.hpp" #include "../../cache/kv_cache.hpp"
#include "llama_config.hpp"
#include "llama_decoder_layer.hpp" #include "llama_decoder_layer.hpp"
#include "infinicore/nn/embedding.hpp" #include "infinicore/nn/embedding.hpp"
...@@ -38,10 +37,26 @@ public: ...@@ -38,10 +37,26 @@ public:
* @param device Device to create tensors on * @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32) * @param dtype Optional data type for model parameters (defaults to F32)
*/ */
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaModel(const LlamaConfig &config, LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
* @brief Forward pass: process input through the model * @brief Forward pass: process input through the model
* *
...@@ -64,8 +79,7 @@ public: ...@@ -64,8 +79,7 @@ public:
void reset_cache(const cache::CacheConfig *cache_config); void reset_cache(const cache::CacheConfig *cache_config);
// Module information // Module information
const LlamaConfig &config() const { return config_; } size_t num_layers() const { return model_config_->get<size_t>("num_hidden_layers"); }
size_t num_layers() const { return config_.num_hidden_layers; }
protected: protected:
// Token embeddings // Token embeddings
...@@ -86,6 +100,8 @@ protected: ...@@ -86,6 +100,8 @@ protected:
private: private:
LlamaConfig config_; LlamaConfig config_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
}; };
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -2,11 +2,22 @@ ...@@ -2,11 +2,22 @@
#include "llama/llama.hpp" #include "llama/llama.hpp"
namespace infinilm { namespace infinilm {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info, engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) { const cache::CacheConfig *cache) {
std::shared_ptr<InfinilmModel> model; std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) { if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr; const auto &llama_config = *llama_config_ptr;
...@@ -22,4 +33,24 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( ...@@ -22,4 +33,24 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
return model; return model;
} }
// Create a model from a polymorphic ModelConfig (new configuration path).
//
// @param model_config  Parsed model configuration; must be non-null.
// @param rank_info     Distributed rank info (device, rank, ...) for this worker.
// @param cache         Optional cache configuration; when non-null the model's
//                      cache is initialized via reset_cache() before returning.
// @return              The constructed model as a shared InfinilmModel.
// @throws std::invalid_argument if model_config is null.
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
    std::shared_ptr<infinilm::config::ModelConfig> model_config,
    engine::distributed::RankInfo rank_info,
    const cache::CacheConfig *cache) {
    // Previous code used `if (true)`, which made the error branch dead code
    // and silently accepted a null config (crashing later in the model ctor).
    // Guard explicitly so the failure is reported at the call site.
    if (!model_config) {
        throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
    }
    // NOTE(review): only the llama architecture is wired up on this path;
    // dispatch on the config's architecture field once more model types exist.
    std::shared_ptr<InfinilmModel> model = std::make_shared<models::llama::LlamaForCausalLM>(
        model_config, rank_info.device, rank_info);
    if (cache) {
        model->reset_cache(cache);
    }
    return model;
}
} // namespace infinilm } // namespace infinilm
#pragma once #pragma once
#include "../config/model_config.hpp"
#include "infinilm_model.hpp" #include "infinilm_model.hpp"
#include "../engine/distributed/distributed.hpp" #include "../engine/distributed/distributed.hpp"
...@@ -7,9 +8,26 @@ ...@@ -7,9 +8,26 @@
namespace infinilm { namespace infinilm {
class InfinilmModelFactory { class InfinilmModelFactory {
public: public:
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
static std::shared_ptr<InfinilmModel> createModel( static std::shared_ptr<InfinilmModel> createModel(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr); const cache::CacheConfig *cache = nullptr);
static std::shared_ptr<InfinilmModel> createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
}; };
} // namespace infinilm } // namespace infinilm
...@@ -35,17 +35,20 @@ inline void bind_infer_engine(py::module &m) { ...@@ -35,17 +35,20 @@ inline void bind_infer_engine(py::module &m) {
const InfinilmModel::Config &cfg, const InfinilmModel::Config &cfg,
const distributed::DistConfig &dist, const distributed::DistConfig &dist,
infinicore::Device::Type dev, infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg) { std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
return std::make_shared<InferEngine>( return std::make_shared<InferEngine>(
cfg, cfg,
dist, dist,
dev, dev,
cache_cfg ? cache_cfg.get() : nullptr); cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
}), }),
py::arg("config"), py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(), py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none()) py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
.def("load_param", &InferEngine::load_param, .def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"), py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)") "Load a parameter tensor into all workers (each worker picks its shard)")
...@@ -60,20 +63,52 @@ inline void bind_infer_engine(py::module &m) { ...@@ -60,20 +63,52 @@ inline void bind_infer_engine(py::module &m) {
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def( .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
self.reset_cache(cfg ? cfg.get() : nullptr);
},
py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) { .def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config(); auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
infer_engine
.def(py::init([](
const std::string &model_path,
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
return std::make_shared<InferEngine>(
model_path,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
}),
py::arg("model_path") = "",
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
py::dict result;
for (const auto &[name, param] : state_dict_tp) {
result[py::cast(name)] = infinicore::Tensor(param);
}
state_dict_tp_all.append(result);
}
return state_dict_tp_all;
}) })
.def("__repr__", [](const InferEngine &self) { .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
}); .def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
py::class_<InferEngine::Input>(infer_engine, "Input") py::class_<InferEngine::Input>(infer_engine, "Input")
.def( .def(
......
...@@ -3,7 +3,7 @@ from transformers import AutoTokenizer ...@@ -3,7 +3,7 @@ from transformers import AutoTokenizer
from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
import argparse import argparse
import sys import sys
import time import time
...@@ -137,11 +137,36 @@ def get_args(): ...@@ -137,11 +137,36 @@ def get_args():
action="store_true", action="store_true",
help="Run nvidia test", help="Run nvidia test",
) )
parser.add_argument(
"--qy",
action="store_true",
help="Run qy test",
)
parser.add_argument(
"--metax",
action="store_true",
help="Run metax test",
)
parser.add_argument(
"--moore",
action="store_true",
help="Run moore test",
)
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run iluvatar test",
)
parser.add_argument( parser.add_argument(
"--cambricon", "--cambricon",
action="store_true", action="store_true",
help="Run cambricon test", help="Run cambricon test",
) )
parser.add_argument(
"--ali",
action="store_true",
help="Run alippu test",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
...@@ -199,7 +224,21 @@ def get_args(): ...@@ -199,7 +224,21 @@ def get_args():
default=1.0, default=1.0,
help="sampling temperature", help="sampling temperature",
) )
parser.add_argument(
"--enable-paged-attn",
action="store_true",
help="use paged cache",
)
parser.add_argument(
"--enable-graph",
action="store_true",
help="enable graph compiling",
)
parser.add_argument(
"--warmup",
action="store_true",
help="Perform a warmup run before benchmarking/inference."
)
return parser.parse_args() return parser.parse_args()
...@@ -223,6 +262,8 @@ class TestModel: ...@@ -223,6 +262,8 @@ class TestModel:
infini_device=infinicore.device("cpu", 0), infini_device=infinicore.device("cpu", 0),
tp=1, tp=1,
skip_load=False, skip_load=False,
cache_config=None,
enable_graph=False,
) -> None: ) -> None:
model_path = os.path.expanduser(model_path) model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -232,6 +273,8 @@ class TestModel: ...@@ -232,6 +273,8 @@ class TestModel:
model_path, model_path,
device=infini_device, device=infini_device,
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
cache_config=cache_config,
enable_graph_compiling=enable_graph,
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -245,6 +288,13 @@ class TestModel: ...@@ -245,6 +288,13 @@ class TestModel:
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# token编码 # token编码
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -257,7 +307,16 @@ class TestModel: ...@@ -257,7 +307,16 @@ class TestModel:
] ]
# print(input_content, end="", flush=True) # print(input_content, end="", flush=True)
input_ids_list = tokenizer.batch_encode_plus(input_content)["input_ids"] # Support Transformers >= 5.0 for batch_encode_plus deprecation
encoding = tokenizer(
input_content,
padding=True,
truncation=True,
max_length=2048,
return_tensors="pt"
)
input_ids_list = encoding["input_ids"]
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -315,8 +374,18 @@ if __name__ == "__main__": ...@@ -315,8 +374,18 @@ if __name__ == "__main__":
device_str = "cpu" device_str = "cpu"
elif args.nvidia: elif args.nvidia:
device_str = "cuda" device_str = "cuda"
elif args.qy:
device_str = "cuda"
elif args.metax:
device_str = "cuda"
elif args.moore:
device_str = "musa"
elif args.iluvatar:
device_str = "cuda"
elif args.cambricon: elif args.cambricon:
device_str = "mlu" device_str = "mlu"
elif args.ali:
device_str = "cuda"
else: else:
print( print(
"python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50" "python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50"
...@@ -336,6 +405,8 @@ if __name__ == "__main__": ...@@ -336,6 +405,8 @@ if __name__ == "__main__":
batch_size = args.batch_size batch_size = args.batch_size
input_len = args.input_len input_len = args.input_len
output_len = args.output_len output_len = args.output_len
enable_paged_attn = args.enable_paged_attn
enable_graph = args.enable_graph
if isinstance(batch_size, int): if isinstance(batch_size, int):
batch_size = [batch_size] batch_size = [batch_size]
...@@ -350,15 +421,81 @@ if __name__ == "__main__": ...@@ -350,15 +421,81 @@ if __name__ == "__main__":
# -------------------------------------------------------- # # -------------------------------------------------------- #
# 测试 # 测试
# -------------------------------------------------------- # # -------------------------------------------------------- #
# print("=================== start test ====================", type(batch_size)) if enable_paged_attn:
paged_kv_block_size = 16
max_num_blocks = max(
[
((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"]
for _, c_ in cases_dict.items()
]
)
cache_config = PagedKVCacheConfig(max_num_blocks, paged_kv_block_size)
else:
cache_config = None
test = TestModel( test = TestModel(
model_path, model_path,
infini_device=infini_device, infini_device=infini_device,
tp=tp, tp=tp,
skip_load=skip_load, skip_load=skip_load,
cache_config=cache_config,
enable_graph=enable_graph,
)
# ---------------------------------------------------------------------------- #
# Warmup: run one short generation so kernels/allocators are primed before
# the timed benchmark cases.
# ---------------------------------------------------------------------------- #
if args.warmup:
    warmup_steps = 1
    # A small static-cache capacity is enough for the short warmup generation.
    warmup_cache_len = 128
    warmup_batch = len(test.input_ids_list)
    test.model.reset_cache(
        StaticKVCacheConfig(
            max_batch_size=warmup_batch,
            max_cache_len=warmup_cache_len,
        )
    )
    # Cap each prompt at 64 tokens (or the longest prompt, if shorter) so the
    # warmup stays fast regardless of benchmark input lengths. Renamed from
    # `avg_prompt_len`: this is a min/max cap, not an average.
    warmup_prompt_len = min(
        64,
        max(len(ids) for ids in test.input_ids_list)
    )
    # Slicing already handles prompts shorter than the cap
    # (ids[:k] == ids when len(ids) <= k), so no conditional is needed.
    warmup_ids = [ids[:warmup_prompt_len] for ids in test.input_ids_list]

    input_ids_infini = infinicore.from_list(warmup_ids)

    print("=================== warmup start ===================")
    for _ in range(warmup_steps):
        _ = test.model.generate(
            input_ids_infini,
            GenerationConfig(
                max_new_tokens=5,  # enough to warm up the decode kernels
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
            ),
            _measure_and_log_time=False,
        )
    print("=================== warmup done ====================")

    # Restore the benchmark cache config (paged) if one was requested; the
    # static-cache path re-resets per case later anyway.
    if cache_config is not None:
        test.model.reset_cache(cache_config)
# ---------------------------------------------------------------------------- #
# Warmup done
# ---------------------------------------------------------------------------- #
for idx, case in tqdm(cases_dict.items(), desc="Processing cases"): for idx, case in tqdm(cases_dict.items(), desc="Processing cases"):
tqdm.write(f"\033[92mProcessing : {case}\033[0m") tqdm.write(f"\033[92mProcessing : {case}\033[0m")
...@@ -366,7 +503,8 @@ if __name__ == "__main__": ...@@ -366,7 +503,8 @@ if __name__ == "__main__":
input_len = case["input_len"] input_len = case["input_len"]
output_len = case["output_len"] output_len = case["output_len"]
# reset cache for each case if not enable_paged_attn:
# reset cache if static kvcache is used
initial_capacity = input_len + output_len initial_capacity = input_len + output_len
test.model.reset_cache( test.model.reset_cache(
StaticKVCacheConfig( StaticKVCacheConfig(
......
import infinicore import infinicore
import transformers
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tokenizers import decoders as _dec from tokenizers import decoders as _dec
from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.modeling_utils import load_model_state_dict_by_file
...@@ -10,6 +11,7 @@ import time ...@@ -10,6 +11,7 @@ import time
import os import os
import numpy as np import numpy as np
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from packaging import version
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
...@@ -27,6 +29,11 @@ def get_args(): ...@@ -27,6 +29,11 @@ def get_args():
action="store_true", action="store_true",
help="Run nvidia test", help="Run nvidia test",
) )
parser.add_argument(
"--qy",
action="store_true",
help="Run qy test",
)
parser.add_argument( parser.add_argument(
"--metax", "--metax",
action="store_true", action="store_true",
...@@ -47,6 +54,11 @@ def get_args(): ...@@ -47,6 +54,11 @@ def get_args():
action="store_true", action="store_true",
help="Run cambricon test", help="Run cambricon test",
) )
parser.add_argument(
"--ali",
action="store_true",
help="Run alippu test",
)
parser.add_argument( parser.add_argument(
"--hygon", "--hygon",
action="store_true", action="store_true",
...@@ -93,6 +105,11 @@ def get_args(): ...@@ -93,6 +105,11 @@ def get_args():
action="store_true", action="store_true",
help="use paged cache", help="use paged cache",
) )
parser.add_argument(
"--enable-graph",
action="store_true",
help="enable graph compiling",
)
parser.add_argument( parser.add_argument(
"--top-k", "--top-k",
...@@ -125,6 +142,7 @@ def test( ...@@ -125,6 +142,7 @@ def test(
infini_device=infinicore.device("cpu", 0), infini_device=infinicore.device("cpu", 0),
tp=1, tp=1,
enable_paged_attn=False, enable_paged_attn=False,
enable_graph=False,
top_k=1, top_k=1,
top_p=1.0, top_p=1.0,
temperature=1.0, temperature=1.0,
...@@ -137,8 +155,8 @@ def test( ...@@ -137,8 +155,8 @@ def test(
model_path, model_path,
device=infini_device, device=infini_device,
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
enable_graph_compiling=enable_graph,
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Load Weights # Load Weights
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -148,7 +166,6 @@ def test( ...@@ -148,7 +166,6 @@ def test(
# create tokenizer # create tokenizer
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if "llama" == model.config.model_type: if "llama" == model.config.model_type:
backend = getattr(tokenizer, "backend_tokenizer", None) backend = getattr(tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend) target = getattr(backend, "_tokenizer", backend)
...@@ -182,9 +199,24 @@ def test( ...@@ -182,9 +199,24 @@ def test(
for prompt in prompts for prompt in prompts
] ]
input_ids_list = tokenizer.batch_encode_plus(input_contents)[ # input_ids_list = tokenizer.batch_encode_plus(input_contents)[
"input_ids" # "input_ids"
] # List: [[1, 1128, 526, 366, 29892]] # ] # List: [[1, 1128, 526, 366, 29892]]
if version.parse(transformers.__version__) < version.parse("5.0.0"):
# Ideally this is solved by upgrading transformers. However, doing so causes version mismatch between transformers and mlu pytorch on devices with Phytium CPU. So a branch is temporarily used.
input_ids_list = [
tokenizer.encode_plus(
text, truncation=True, max_length=2048, add_special_tokens=True
)["input_ids"]
for text in input_contents
]
else:
input_ids_list = [
tokenizer._encode_plus(
text, truncation=True, max_length=2048, add_special_tokens=True
)["input_ids"]
for text in input_contents
]
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Create KVCache # Create KVCache
...@@ -193,7 +225,7 @@ def test( ...@@ -193,7 +225,7 @@ def test(
batch_size = 1 if prompts is str else len(prompts) batch_size = 1 if prompts is str else len(prompts)
max_total_tokens = max_new_tokens + len(input_ids_list[0]) max_total_tokens = max_new_tokens + len(input_ids_list[0])
cache_config = PagedKVCacheConfig( cache_config = PagedKVCacheConfig(
num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16 num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16
) )
else: else:
batch_size = 1 if prompts is str else len(prompts) batch_size = 1 if prompts is str else len(prompts)
...@@ -242,6 +274,8 @@ if __name__ == "__main__": ...@@ -242,6 +274,8 @@ if __name__ == "__main__":
device_str = "cpu" device_str = "cpu"
elif args.nvidia: elif args.nvidia:
device_str = "cuda" device_str = "cuda"
elif args.qy:
device_str = "cuda"
elif args.metax: elif args.metax:
device_str = "cuda" device_str = "cuda"
elif args.moore: elif args.moore:
...@@ -250,11 +284,13 @@ if __name__ == "__main__": ...@@ -250,11 +284,13 @@ if __name__ == "__main__":
device_str = "cuda" device_str = "cuda"
elif args.cambricon: elif args.cambricon:
device_str = "mlu" device_str = "mlu"
elif args.ali:
device_str = "cuda"
elif args.hygon: elif args.hygon:
device_str = "cuda" device_str = "cuda"
else: else:
print( print(
"Usage: python examples/jiuge.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon | --hygon] --model_path=<path/to/model_dir>\n" "Usage: python examples/jiuge.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali | --hygon] --model_path=<path/to/model_dir>\n"
"such as, python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0" "such as, python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
) )
sys.exit(1) sys.exit(1)
...@@ -265,6 +301,7 @@ if __name__ == "__main__": ...@@ -265,6 +301,7 @@ if __name__ == "__main__":
backend = args.backend backend = args.backend
tp = args.tp tp = args.tp
enable_paged_attn = args.enable_paged_attn enable_paged_attn = args.enable_paged_attn
enable_graph = args.enable_graph
if backend != "cpp": if backend != "cpp":
raise ValueError(f"Unsupported backend: {backend}.") raise ValueError(f"Unsupported backend: {backend}.")
...@@ -277,6 +314,7 @@ if __name__ == "__main__": ...@@ -277,6 +314,7 @@ if __name__ == "__main__":
infini_device=infini_device, infini_device=infini_device,
tp=tp, tp=tp,
enable_paged_attn=enable_paged_attn, enable_paged_attn=enable_paged_attn,
enable_graph=enable_graph,
top_k=args.top_k, top_k=args.top_k,
top_p=args.top_p, top_p=args.top_p,
temperature=args.temperature, temperature=args.temperature,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment