Unverified commit a4ced800 authored by thatPepe, committed by GitHub

Merge pull request #205 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 96ecf490 04c37f3f
#pragma once
#include "../cache/cache.hpp"
#include "../config/model_config.hpp"
#include "../models/model_factory.hpp"
#include "compiler/general_compiler.hpp"
#include "distributed/distributed.hpp"
#include "rank_barrier.hpp"
#include <any>
#include <condition_variable>
......@@ -20,6 +23,7 @@ class RankWorker {
LOAD,
RUN,
RESET_CACHE,
COMPILE,
STOP
};
......@@ -55,7 +59,15 @@ public:
RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config);
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
// Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name,
......@@ -70,6 +82,9 @@ public:
// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config);
// Compile the model graph if enabled.
void compile();
// Wait until run job completes. The result can be retrieved with get_output().
void wait();
......@@ -86,11 +101,16 @@ private:
private:
// Worker properties
const InfinilmModel::Config &model_config_;
const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::Cache> cache_;
// Graph Compiling
bool enable_graph_compiling_;
std::unique_ptr<GraphCompiler> compiler_;
// Command for the pending job (protected by mutex_)
Command job_cmd_;
......@@ -116,6 +136,8 @@ private:
// Random
std::mt19937 rng_;
RankBarrier *barrier_;
};
} // namespace infinilm::engine
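For orientation, a minimal sketch of the worker lifecycle this diff enables when graph compiling is turned on. Only the constructor, load_param, compile, wait, and get_output appear in this header; the namespaces and the run-submission step are inferred and should be treated as assumptions.

// Hypothetical single-rank driver; run submission is elided because its
// signature is not shown in this diff.
void drive_rank(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                const infinilm::engine::distributed::RankInfo &rank_info,
                const infinilm::cache::CacheConfig *cache_config,
                infinilm::engine::RankBarrier *barrier) {
    infinilm::engine::RankWorker worker(model_config, rank_info, cache_config,
                                        barrier, /*enable_graph_compiling=*/true);
    // worker.load_param(name, tensor);  // repeated for every checkpoint tensor
    worker.compile();                    // builds the model graph; only acts if enabled
    // ... submit a RUN job ...
    worker.wait();                       // block until the job finishes; result via get_output()
}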
......@@ -6,6 +6,18 @@ namespace infinilm::layers {
// ---------------------------------------------------------
// QKV Parallel Linear
// ---------------------------------------------------------
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
size_t head_dim,
size_t num_q_head,
......@@ -57,6 +69,61 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
}
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
size_t head_dim,
size_t num_q_head,
size_t num_kv_head,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
bool bias,
const infinicore::DataType &dtype,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: QKVParallelLinear(hidden_size,
head_dim, head_dim, head_dim,
num_q_head, num_kv_head, num_kv_head,
bias, bias, bias,
quantization,
dtype, device, rank_info) {}
QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
size_t q_dim, size_t k_dim, size_t v_dim,
size_t num_q_head, size_t num_k_head, size_t num_v_head,
bool q_bias, bool k_bias, bool v_bias,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: infinicore::nn::ColumnParallelLinear(
hidden_size,
num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
quantization,
(q_bias || k_bias || v_bias),
dtype,
device,
rank_info.tp_rank,
rank_info.tp_size),
q_dim_(q_dim),
k_dim_(k_dim),
v_dim_(v_dim),
num_q_head_(num_q_head),
num_k_head_(num_k_head),
num_v_head_(num_v_head),
q_bias_(q_bias),
k_bias_(k_bias),
v_bias_(v_bias) {
if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
}
if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
throw std::runtime_error("q_bias, k_bias, v_bias must all match");
}
q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
}
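As a migration illustration for the deprecation notice above, a hedged sketch of constructing the new quantization-aware overload; model_config, device, and rank_info are assumed to be in scope, and the sizes are placeholders.

// The quantization method is taken from the model config, as llama_attention.cpp
// does later in this diff; sizes use the common equal-head-dim convenience overload.
auto quant = model_config->get_quantization_method();
auto qkv = std::make_shared<infinilm::layers::QKVParallelLinear>(
    /*hidden_size=*/4096, /*head_dim=*/128,
    /*num_q_head=*/32, /*num_kv_head=*/8,
    quant, /*bias=*/false,
    infinicore::DataType::F32, device, rank_info);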
std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
QKVParallelLinear::forward_split(infinicore::Tensor &input) {
auto output = this->forward(input);
......@@ -86,6 +153,40 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const {
0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale() const {
return infinicore::nn::Parameter(
weight_scale_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale() const {
return infinicore::nn::Parameter(
weight_scale_->narrow({{0, q_out_size_, k_out_size_}}),
0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const {
return infinicore::nn::Parameter(
weight_scale_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}),
0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_weight_zeros() const {
return infinicore::nn::Parameter(
weight_zeros_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_k_weight_zeros() const {
return infinicore::nn::Parameter(
weight_zeros_->narrow({{0, q_out_size_, k_out_size_}}),
0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_v_weight_zeros() const {
return infinicore::nn::Parameter(
weight_zeros_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}),
0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const {
if (!q_bias_) {
return infinicore::nn::Parameter();
......@@ -120,6 +221,18 @@ bool QKVParallelLinear::has_v_bias() const { return v_bias_; }
// ---------------------------------------------------------
// Gate-Up Parallel Linear
// ---------------------------------------------------------
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias,
const infinicore::DataType &dtype, const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
......@@ -135,6 +248,22 @@ GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermedia
}
}
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias,
const infinicore::DataType &dtype, const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, quantization, dtype, device, rank_info) {
}
GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype, const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, quantization, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size), gate_bias_(gate_bias), up_bias_(up_bias) {
if (gate_bias_ != up_bias_) {
throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
}
}
std::tuple<infinicore::Tensor, infinicore::Tensor> GateUpParallelLinear::forward_split(infinicore::Tensor &input) {
auto output = this->forward(input);
auto cols = output->shape()[2];
......@@ -168,6 +297,22 @@ infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const {
}
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale() const {
return infinicore::nn::Parameter(weight_scale_->narrow({{0, 0, weight_scale_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale() const {
return infinicore::nn::Parameter(weight_scale_->narrow({{0, weight_scale_->size(0) / 2, weight_scale_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_zeros() const {
return infinicore::nn::Parameter(weight_zeros_->narrow({{0, 0, weight_zeros_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros() const {
return infinicore::nn::Parameter(weight_zeros_->narrow({{0, weight_zeros_->size(0) / 2, weight_zeros_->size(0) / 2}}), 0, tp_rank_, tp_size_);
}
bool GateUpParallelLinear::has_gate_bias() const {
return gate_bias_;
}
......
#pragma once
#include "infinicore/nn/linear.hpp"
#include "infinicore/quantization.hpp"
#include "../engine/distributed/communication_group.hpp"
......@@ -23,6 +24,25 @@ public:
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
explicit QKVParallelLinear(size_t hidden_size,
size_t q_dim, size_t k_dim, size_t v_dim,
size_t num_q_head, size_t num_k_head, size_t num_v_head,
bool q_bias, bool k_bias, bool v_bias,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
// A more common case where all heads have the same dimension
explicit QKVParallelLinear(size_t hidden_size,
size_t head_dim,
size_t num_q_head, size_t num_kv_head,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
bool bias = false,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
forward_split(infinicore::Tensor &input);
......@@ -30,6 +50,14 @@ public:
infinicore::nn::Parameter get_k_weight() const;
infinicore::nn::Parameter get_v_weight() const;
infinicore::nn::Parameter get_q_weight_scale() const;
infinicore::nn::Parameter get_k_weight_scale() const;
infinicore::nn::Parameter get_v_weight_scale() const;
infinicore::nn::Parameter get_q_weight_zeros() const;
infinicore::nn::Parameter get_k_weight_zeros() const;
infinicore::nn::Parameter get_v_weight_zeros() const;
infinicore::nn::Parameter get_q_bias() const;
infinicore::nn::Parameter get_k_bias() const;
infinicore::nn::Parameter get_v_bias() const;
......@@ -55,6 +83,18 @@ private:
class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false,
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
......@@ -63,14 +103,33 @@ public:
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
bool bias = false,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
std::tuple<infinicore::Tensor, infinicore::Tensor> forward_split(infinicore::Tensor &input);
infinicore::nn::Parameter get_gate_weight() const;
infinicore::nn::Parameter get_gate_weight_scale() const;
infinicore::nn::Parameter get_gate_weight_zeros() const;
infinicore::nn::Parameter get_gate_bias() const;
infinicore::nn::Parameter get_up_weight() const;
infinicore::nn::Parameter get_up_weight_scale() const;
infinicore::nn::Parameter get_up_weight_zeros() const;
infinicore::nn::Parameter get_up_bias() const;
bool has_gate_bias() const;
......@@ -103,4 +162,62 @@ private:
if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
// ========================= QKV Quantization ==================================
#define INFINILM_QKV_LINEAR_W8A8_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \
this->register_parameter(std::string(q_name) + ".weight_scale", name##_->get_q_weight_scale()); \
this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \
this->register_parameter(std::string(k_name) + ".weight_scale", name##_->get_k_weight_scale()); \
this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \
this->register_parameter(std::string(v_name) + ".weight_scale", name##_->get_v_weight_scale()); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
if (name##_->has_v_bias()) \
this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
#define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \
name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight()); \
this->register_parameter(std::string(q_name) + ".qzeros", name##_->get_q_weight_zeros()); \
this->register_parameter(std::string(q_name) + ".scales", name##_->get_q_weight_scale()); \
this->register_parameter(std::string(k_name) + ".qweight", name##_->get_k_weight()); \
this->register_parameter(std::string(k_name) + ".qzeros", name##_->get_k_weight_zeros()); \
this->register_parameter(std::string(k_name) + ".scales", name##_->get_k_weight_scale()); \
this->register_parameter(std::string(v_name) + ".qweight", name##_->get_v_weight()); \
this->register_parameter(std::string(v_name) + ".qzeros", name##_->get_v_weight_zeros()); \
this->register_parameter(std::string(v_name) + ".scales", name##_->get_v_weight_scale()); \
if (name##_->has_q_bias()) \
this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \
if (name##_->has_k_bias()) \
this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \
if (name##_->has_v_bias()) \
this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());
// ========================= Gate-Up Quantization ==============================
#define INFINILM_GATE_UP_LINEAR_W8A8_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \
this->register_parameter(std::string(gate_name) + ".weight_scale", name##_->get_gate_weight_scale()); \
this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \
this->register_parameter(std::string(up_name) + ".weight_scale", name##_->get_up_weight_scale()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
#define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) \
name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__); \
this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight()); \
this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale()); \
this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros()); \
this->register_parameter(std::string(up_name) + ".qweight", name##_->get_up_weight()); \
this->register_parameter(std::string(up_name) + ".scales", name##_->get_up_weight_scale()); \
this->register_parameter(std::string(up_name) + ".qzeros", name##_->get_up_weight_zeros()); \
if (name##_->has_gate_bias()) \
this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
if (name##_->has_up_bias()) \
this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());
} // namespace infinilm::layers
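To make the quantized registration macros above concrete, a sketch of invoking the AWQ QKV variant inside a module constructor; the qkv_proj_ member and the argument values are assumptions, while the registered key names follow directly from the macro body.

// Expands to construct qkv_proj_ and register the AWQ tensors for each projection.
INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj",
                                  hidden_size_, head_dim_, num_q_head, num_kv_head,
                                  quantization, /*bias=*/false, dtype, device, rank_info);
// Registered keys: "q_proj.qweight", "q_proj.qzeros", "q_proj.scales",
//                  "k_proj.qweight", "k_proj.qzeros", "k_proj.scales",
//                  "v_proj.qweight", "v_proj.qzeros", "v_proj.scales",
// plus "<proj>.bias" for any projection whose bias flag is set.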
#pragma once
#include "infinicore/nn/module.hpp"
#include "../cache/cache.hpp"
#include "infinicore/nn/module.hpp"
#include "nlohmann/json.hpp"
#include <any>
......@@ -13,7 +13,6 @@ class InfinilmModel : public infinicore::nn::Module {
public:
struct Config {
std::string model_type;
virtual ~Config() = default;
};
......@@ -43,5 +42,6 @@ public:
virtual Output forward(const Input &input) const = 0;
virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
virtual const cache::CacheConfig *get_cache_config() const = 0;
};
} // namespace infinilm
......@@ -16,9 +16,9 @@
* - LlamaForCausalLM: Complete model with language modeling head
*/
#include "llama_config.hpp"
#include "../../config/model_config.hpp"
#include "llama_attention.hpp"
#include "llama_mlp.hpp"
#include "llama_decoder_layer.hpp"
#include "llama_model.hpp"
#include "llama_for_causal_lm.hpp"
#include "llama_mlp.hpp"
#include "llama_model.hpp"
......@@ -9,7 +9,6 @@
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <optional>
#include <spdlog/spdlog.h>
#include <stdexcept>
......@@ -17,6 +16,18 @@
namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaAttention::LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
......@@ -61,6 +72,65 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
}
}
LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info)
: model_config_(model_config),
layer_idx_(layer_idx),
hidden_size_(model_config->get<size_t>("hidden_size")),
num_attention_heads_(model_config->get<size_t>("num_attention_heads")),
num_key_value_heads_(model_config->get<size_t>("num_key_value_heads")),
head_dim_(model_config->get_head_dim()),
kv_dim_(model_config->get_kv_dim()),
use_bias_(model_config->get_or<bool>("attention_bias", true)),
use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)),
max_position_embeddings_(model_config->get<size_t>("max_position_embeddings")),
rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
int tp_rank = rank_info.tp_rank;
int tp_size = rank_info.tp_size;
int num_attention_heads = model_config_->get<size_t>("num_attention_heads");
int num_key_value_heads = model_config_->get<size_t>("num_key_value_heads");
if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
this->num_attention_heads_ = num_attention_heads / tp_size;
this->num_key_value_heads_ = num_key_value_heads / tp_size;
} else {
throw std::runtime_error("num_attention_heads / tp_size error.");
}
scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));
auto quant_scheme = this->model_config_->get_quant_scheme();
switch (quant_scheme) {
case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info);
INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
case infinicore::quantization::QuantScheme::AWQ_W4A16:
INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info);
INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
default:
INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info);
INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
}
if (model_config_->get<std::string>("model_type") == "qwen3") {
INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
}
}
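For reference, the new constructors pull everything from ModelConfig through a small typed-accessor API; a minimal illustration using the same keys as above (behaviour on a missing key is not specified by this diff).

// Typed reads from the JSON-backed ModelConfig, mirroring the calls used above.
size_t hidden_size = model_config->get<size_t>("hidden_size");
bool   attn_bias   = model_config->get_or<bool>("attention_bias", /*fallback=*/true);
double rms_eps     = model_config->get<double>("rms_norm_eps");
auto   dtype       = model_config->get_dtype();     // parameter data type
size_t head_dim    = model_config->get_head_dim();  // derived helper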
infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
......@@ -75,7 +145,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
// 1. Project Q, K, V
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
if (use_qk_norm_) {
if (use_qk_norm_ || model_config_->get_or<std::string>("model_type", "None") == "qwen3") {
q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));
}
......@@ -112,8 +182,8 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
infinicore::Tensor k_total; // [bs, n_kv_head, max_seq_len, head_dim]
infinicore::Tensor v_total; // [bs, n_kv_head, max_seq_len, head_dim]
if (kv_cache == nullptr) {
k_total = k_permuted;
v_total = v_permuted;
......@@ -124,7 +194,18 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
} else {
throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
}
auto total_seq_len = k_total->shape()[2];
infinicore::Tensor attn_output;
if (false) {
// experimental ninetoothed flash attention
attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
attn_output = attn_output->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
// 6. Compute attention
size_t ngroup = num_attention_heads_ / num_key_value_heads_;
......@@ -141,10 +222,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
->permute({0, 2, 1, 3})
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
}
auto output = o_proj_->forward(attn_output);
......@@ -184,7 +266,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
if (use_qk_norm_) {
if (use_qk_norm_ || model_config_->get_or<std::string>("model_type", "None") == "qwen3") {
q_reshaped = q_norm_->forward(q_reshaped);
k_reshaped = k_norm_->forward(k_reshaped);
}
......
#pragma once
#include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp"
#include "../../engine/distributed/distributed.hpp"
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"
......@@ -36,11 +37,28 @@ public:
* @param layer_idx Layer index for cache access
* @param dtype Optional data type for model parameters (defaults to F32)
*/
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: compute attention
*
......@@ -101,6 +119,7 @@ protected:
std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
private:
std::shared_ptr<infinilm::config::ModelConfig> model_config_ = std::make_shared<infinilm::config::ModelConfig>();
size_t layer_idx_; // Layer index for cache access
size_t hidden_size_;
size_t num_attention_heads_;
......@@ -109,7 +128,7 @@ private:
size_t kv_dim_;
bool use_bias_; // Bias for Q/K/V projections
bool use_output_bias_; // Bias for output projection (o_proj)
bool use_qk_norm_; // Whether to use QK RMSNorm
bool use_qk_norm_ = false; // Whether to use QK RMSNorm
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
float scaling_;
......
#include "llama_decoder_layer.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp"
#include <optional>
namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
......@@ -23,7 +33,25 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
}
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
// Initialize layer normalization layers
INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
dtype, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
dtype, device);
// Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_);
INFINICORE_NN_MODULE_INIT(mlp, model_config_, device, rank_info_);
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
infinicore::Tensor &residual,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
......@@ -31,30 +59,19 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
std::optional<infinicore::Tensor> input_offsets,
std::optional<infinicore::Tensor> block_tables,
std::optional<infinicore::Tensor> slot_mapping) const {
// Save residual for attention
auto residual = hidden_states;
// 1. Attention layer normalization
input_layernorm_->forward_inplace(hidden_states, residual);
// 1. Pre-attention layer normalization
auto normed_states = input_layernorm_->forward(hidden_states);
// 2. Self-attention with residual connection
auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// Add residual: hidden_states = hidden_states + attn_output
auto output = infinicore::op::add(residual, attn_output);
// Save residual for MLP
residual = output;
// 2. Self-attention
hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
// 3. Post-attention layer normalization
normed_states = post_attention_layernorm_->forward(output);
// 4. MLP with residual connection
auto mlp_output = mlp_->forward(normed_states);
post_attention_layernorm_->forward_inplace(hidden_states, residual);
// Add residual: output = output + mlp_output
output = infinicore::op::add(residual, mlp_output);
// 4. MLP
hidden_states = mlp_->forward(hidden_states);
return output;
return std::make_tuple(hidden_states, residual);
}
} // namespace infinilm::models::llama
......@@ -33,20 +33,41 @@ public:
* @param layer_idx Layer index for cache management and debugging
* @param dtype Optional data type for model parameters (defaults to F32)
*/
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
size_t layer_idx,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: process one decoder layer
*
* @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
* @param hidden_states [batch, seq_len, hidden_size], will be modified
* @param residual [batch, seq_len, hidden_size], will be modified
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional KV cache for incremental decoding
* @return Output tensor of shape [batch, seq_len, hidden_size]
* Updated residual tensor of shape [batch, seq_len, hidden_size]
*/
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
std::tuple<infinicore::Tensor, infinicore::Tensor>
forward(infinicore::Tensor &hidden_states,
infinicore::Tensor &residual,
const infinicore::Tensor &position_ids,
std::shared_ptr<infinilm::cache::Cache> kv_cache,
std::optional<infinicore::Tensor> past_sequence_lengths,
......@@ -75,6 +96,7 @@ protected:
INFINICORE_NN_MODULE(LlamaAttention, self_attn);
INFINICORE_NN_MODULE(LlamaMLP, mlp);
engine::distributed::RankInfo rank_info_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
private:
size_t layer_idx_; // Layer index for cache management and debugging
......
......@@ -2,19 +2,26 @@
#include "infinicore/context/context.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"
#include <iostream>
namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info) {
// Initialize module's device_ member
device_ = device;
const auto &dtype{config.dtype};
// Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, rank_info);
......@@ -25,6 +32,24 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
dtype, device);
}
LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info) {
// Initialize module's device_ member
device_ = device;
const auto &dtype{model_config->get_dtype()};
// Initialize base model
INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info);
// Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens
// For now, we create a separate linear layer
INFINICORE_NN_MODULE_INIT(lm_head, model_config->get<size_t>("hidden_size"), model_config->get<size_t>("vocab_size"), false,
dtype, device);
}
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto input_ids = input.input_ids.value();
auto position_ids = input.position_ids.value();
......@@ -40,12 +65,16 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
return {logits};
}
void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
model_->reset_cache(cache_config);
cache_config_ = cache_config->unique_copy();
model_->reset_cache(cache_config_.get());
}
const cache::CacheConfig *LlamaForCausalLM::get_cache_config() const {
return cache_config_.get();
}
} // namespace infinilm::models::llama
......@@ -28,10 +28,26 @@ public:
* @param config Model configuration
* @param device Device to create tensors on
*/
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: compute language modeling logits
*
......@@ -42,8 +58,9 @@ public:
void reset_cache(const cache::CacheConfig *cache_config) override;
const cache::CacheConfig *get_cache_config() const override;
// Module information
const LlamaConfig &config() const { return model_->config(); }
LlamaModel &model() { return *model_; }
const LlamaModel &model() const { return *model_; }
......@@ -53,6 +70,8 @@ protected:
// Language modeling head
INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
std::unique_ptr<cache::CacheConfig> cache_config_;
};
} // namespace infinilm::models::llama
......@@ -3,7 +3,18 @@
#include "infinicore/ops.hpp"
namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaMLP::LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
......@@ -22,6 +33,43 @@ LlamaMLP::LlamaMLP(const LlamaConfig &config,
dtype, device, tp_rank, tp_size, rank_info.comm);
}
LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: model_config_(model_config), hidden_size_(model_config->get<size_t>("hidden_size")),
intermediate_size_(model_config->get<size_t>("intermediate_size")),
use_bias_(model_config->get_or<bool>("mlp_bias", false)), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
int tp_rank = rank_info.tp_rank;
int tp_size = rank_info.tp_size;
// Initialize projection layers
auto quant_scheme = this->model_config_->get_quant_scheme();
switch (quant_scheme) {
case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
INFINILM_GATE_UP_LINEAR_W8A8_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info_);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
case infinicore::quantization::QuantScheme::AWQ_W4A16:
INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info_);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
default:
INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
dtype, device, rank_info_);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
dtype, device, tp_rank, tp_size, rank_info.comm);
break;
}
}
infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
// 1. Project to gate and up
auto hidden_states_mutable = hidden_states;
......
......@@ -3,6 +3,7 @@
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"
#include "../../config/model_config.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
......@@ -33,10 +34,26 @@ public:
* @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32)
*/
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: compute MLP output
*
......@@ -57,6 +74,8 @@ protected:
size_t hidden_size_;
size_t intermediate_size_;
bool use_bias_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
};
} // namespace infinilm::models::llama
......@@ -6,7 +6,18 @@
#include <iostream>
namespace infinilm::models::llama {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
......@@ -43,6 +54,39 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
}
}
LlamaModel::LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: model_config_(model_config), rank_info_(rank_info) {
const auto &dtype{model_config_->get_dtype()};
// Initialize token embeddings
INFINICORE_NN_MODULE_INIT(embed_tokens, model_config_->get<size_t>("vocab_size"), model_config_->get<size_t>("hidden_size"),
std::nullopt, dtype, device);
// Initialize decoder layers with layer indices
// TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments
// (e.g., via a factory function or lambda that receives the layer index)
// Currently, we can't use the macro because each layer needs a different layer_idx
layers_.reserve(model_config_->get<size_t>("num_hidden_layers"));
for (size_t i = 0; i < model_config_->get<size_t>("num_hidden_layers"); ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), model_config_, device, i, rank_info));
}
// Initialize final layer normalization
INFINICORE_NN_MODULE_INIT(norm, model_config_->get<size_t>("hidden_size"), model_config_->get<double>("rms_norm_eps"),
dtype, device);
// Initialize Rotary Position Embeddings (shared across all layers)
// Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
INFINICORE_NN_MODULE_INIT(rotary_emb, model_config_->get_head_dim(), model_config_->get<size_t>("max_position_embeddings"),
model_config_->get<double>("rope_theta"), infinicore::nn::RoPE::Algo::GPT_NEOX,
dtype, device, model_config_->get_rope_scaling());
for (auto &layer : layers_) {
if (layer) {
layer->set_rotary_emb(rotary_emb_);
}
}
}
infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
std::optional<infinicore::Tensor> past_sequence_lengths,
......@@ -55,11 +99,23 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers
size_t num_layers = layers_.size();
infinicore::Tensor residual;
for (size_t i = 0; i < num_layers; ++i) {
hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
layers_.at(i)->forward(
hidden_states,
residual,
position_ids,
kv_cache_,
past_sequence_lengths,
total_sequence_lengths,
input_offsets,
block_tables,
slot_mapping);
}
return norm_->forward(hidden_states);
norm_->forward_inplace(hidden_states, residual);
return hidden_states;
}
void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
......@@ -67,7 +123,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
kv_cache_ = nullptr;
return;
}
if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config);
kv_cache_config && model_config_ == nullptr) {
kv_cache_ = std::make_shared<cache::StaticKVCache>(
config_.head_dim,
config_.head_dim,
......@@ -78,8 +135,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
config_.dtype,
*kv_cache_config,
rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
paged_kv_cache_config && model_config_ == nullptr) {
kv_cache_ = std::make_shared<cache::PagedKVCache>(
config_.head_dim,
config_.head_dim,
......@@ -89,6 +146,27 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
config_.dtype,
*paged_kv_cache_config,
rank_info_);
} else if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::StaticKVCache>(
model_config_->get_head_dim(),
model_config_->get_head_dim(),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_hidden_layers"),
model_config_->get<size_t>("max_position_embeddings"),
model_config_->get_dtype(),
*kv_cache_config,
rank_info_);
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
kv_cache_ = std::make_shared<cache::PagedKVCache>(
model_config_->get_head_dim(),
model_config_->get_head_dim(),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_key_value_heads"),
model_config_->get<size_t>("num_hidden_layers"),
model_config_->get_dtype(),
*paged_kv_cache_config,
rank_info_);
} else {
throw std::runtime_error("Unsupported cache type");
}
......
#pragma once
#include "../../cache/kv_cache.hpp"
#include "llama_config.hpp"
#include "llama_decoder_layer.hpp"
#include "infinicore/nn/embedding.hpp"
......@@ -38,10 +37,26 @@ public:
* @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32)
*/
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
LlamaModel(std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: process input through the model
*
......@@ -64,8 +79,7 @@ public:
void reset_cache(const cache::CacheConfig *cache_config);
// Module information
const LlamaConfig &config() const { return config_; }
size_t num_layers() const { return config_.num_hidden_layers; }
size_t num_layers() const { return model_config_->get<size_t>("num_hidden_layers"); }
protected:
// Token embeddings
......@@ -86,6 +100,8 @@ protected:
private:
LlamaConfig config_;
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
};
} // namespace infinilm::models::llama
......@@ -2,11 +2,22 @@
#include "llama/llama.hpp"
namespace infinilm {
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) {
std::shared_ptr<InfinilmModel> model;
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
......@@ -22,4 +33,24 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
return model;
}
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info,
const cache::CacheConfig *cache) {
std::shared_ptr<InfinilmModel> model;
// For now, the ModelConfig-based path always builds the Llama-family causal LM.
if (true) {
model = std::make_shared<models::llama::LlamaForCausalLM>(
model_config, rank_info.device, rank_info);
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
if (cache) {
model->reset_cache(cache);
}
return model;
}
} // namespace infinilm
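A hedged call-site sketch for the new factory overload. Nothing in this diff shows how a ModelConfig is built from a checkpoint directory, so load_model_config() below is a purely hypothetical stand-in.

// createModel's signature is taken from this diff; everything else is assumed.
std::shared_ptr<infinilm::config::ModelConfig> cfg = load_model_config("/path/to/model");
auto model = infinilm::InfinilmModelFactory::createModel(
    cfg,
    infinilm::engine::distributed::RankInfo{},  // default-constructed rank info
    /*cache=*/nullptr);                         // a cache can be attached later via reset_cache()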
#pragma once
#include "../config/model_config.hpp"
#include "infinilm_model.hpp"
#include "../engine/distributed/distributed.hpp"
......@@ -7,9 +8,26 @@
namespace infinilm {
class InfinilmModelFactory {
public:
/**
* @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the polymorphic overload below
*
* Replacement: Use the polymorphic overload of this same function name with updated signature
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
static std::shared_ptr<InfinilmModel> createModel(
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
static std::shared_ptr<InfinilmModel> createModel(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
};
} // namespace infinilm
......@@ -35,17 +35,20 @@ inline void bind_infer_engine(py::module &m) {
const InfinilmModel::Config &cfg,
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg) {
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
return std::make_shared<InferEngine>(
cfg,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr);
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
}),
py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none())
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
......@@ -60,20 +63,52 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
self.reset_cache(cfg ? cfg.get() : nullptr);
},
py::arg("cache_config") = py::none())
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy()));
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
infer_engine
.def(py::init([](
const std::string &model_path,
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
bool enable_graph_compiling) {
return std::make_shared<InferEngine>(
model_path,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr,
enable_graph_compiling);
}),
py::arg("model_path") = "",
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = py::none(),
py::arg("enable_graph_compiling") = false)
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
py::dict result;
for (const auto &[name, param] : state_dict_tp) {
result[py::cast(name)] = infinicore::Tensor(param);
}
state_dict_tp_all.append(result);
}
return state_dict_tp_all;
})
.def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
});
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })
.def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });
py::class_<InferEngine::Input>(infer_engine, "Input")
.def(
......
......@@ -3,7 +3,7 @@ from transformers import AutoTokenizer
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
import argparse
import sys
import time
......@@ -137,11 +137,36 @@ def get_args():
action="store_true",
help="Run nvidia test",
)
parser.add_argument(
"--qy",
action="store_true",
help="Run qy test",
)
parser.add_argument(
"--metax",
action="store_true",
help="Run metax test",
)
parser.add_argument(
"--moore",
action="store_true",
help="Run moore test",
)
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run iluvatar test",
)
parser.add_argument(
"--cambricon",
action="store_true",
help="Run cambricon test",
)
parser.add_argument(
"--ali",
action="store_true",
help="Run alippu test",
)
parser.add_argument(
"--model",
type=str,
......@@ -199,7 +224,21 @@ def get_args():
default=1.0,
help="sampling temperature",
)
parser.add_argument(
"--enable-paged-attn",
action="store_true",
help="use paged cache",
)
parser.add_argument(
"--enable-graph",
action="store_true",
help="enable graph compiling",
)
parser.add_argument(
"--warmup",
action="store_true",
help="Perform a warmup run before benchmarking/inference."
)
return parser.parse_args()
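# A sample invocation exercising the new flags (paths and sizes are illustrative and
# extend the usage line printed further below):
#   python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0 \
#       --batch-size=2 --tp=1 --input-len=50 --output-len=50 \
#       --enable-paged-attn --enable-graph --warmup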
......@@ -223,6 +262,8 @@ class TestModel:
infini_device=infinicore.device("cpu", 0),
tp=1,
skip_load=False,
cache_config=None,
enable_graph=False,
) -> None:
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
......@@ -232,6 +273,8 @@ class TestModel:
model_path,
device=infini_device,
distributed_config=DistConfig(tp),
cache_config=cache_config,
enable_graph_compiling=enable_graph,
)
# ---------------------------------------------------------------------------- #
......@@ -245,6 +288,13 @@ class TestModel:
# ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# ---------------------------------------------------------------------------- #
    # Token encoding
# ---------------------------------------------------------------------------- #
......@@ -257,7 +307,16 @@ class TestModel:
]
# print(input_content, end="", flush=True)
input_ids_list = tokenizer.batch_encode_plus(input_content)["input_ids"]
        # batch_encode_plus is deprecated in Transformers >= 5.0, so use the tokenizer's __call__ API instead
encoding = tokenizer(
input_content,
padding=True,
truncation=True,
max_length=2048,
return_tensors="pt"
)
input_ids_list = encoding["input_ids"]
self.model = model
self.tokenizer = tokenizer
......@@ -315,8 +374,18 @@ if __name__ == "__main__":
device_str = "cpu"
elif args.nvidia:
device_str = "cuda"
elif args.qy:
device_str = "cuda"
elif args.metax:
device_str = "cuda"
elif args.moore:
device_str = "musa"
elif args.iluvatar:
device_str = "cuda"
elif args.cambricon:
device_str = "mlu"
elif args.ali:
device_str = "cuda"
else:
print(
"python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50"
......@@ -336,6 +405,8 @@ if __name__ == "__main__":
batch_size = args.batch_size
input_len = args.input_len
output_len = args.output_len
enable_paged_attn = args.enable_paged_attn
enable_graph = args.enable_graph
if isinstance(batch_size, int):
batch_size = [batch_size]
......@@ -350,15 +421,81 @@ if __name__ == "__main__":
# -------------------------------------------------------- #
    # Test
# -------------------------------------------------------- #
# print("=================== start test ====================", type(batch_size))
if enable_paged_attn:
paged_kv_block_size = 16
max_num_blocks = max(
[
((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"]
for _, c_ in cases_dict.items()
]
)
cache_config = PagedKVCacheConfig(max_num_blocks, paged_kv_block_size)
else:
cache_config = None
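    # Worked example (illustrative numbers, not taken from cases_dict): a case with
    # input_len=50, output_len=50 and batch_size=2 needs ceil(100 / 16) = 7 blocks per
    # sequence, so max_num_blocks = 7 * 2 = 14 blocks of 16 tokens each.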
test = TestModel(
model_path,
infini_device=infini_device,
tp=tp,
skip_load=skip_load,
cache_config=cache_config,
enable_graph=enable_graph,
)
# ---------------------------------------------------------------------------- #
# Warmup
# ---------------------------------------------------------------------------- #
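    # Assumption: with --enable-graph the first forward pass also triggers graph
    # capture/compilation, so an optional warmup keeps that one-time cost out of the
    # timed benchmark loop below.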
if args.warmup:
warmup_steps = 1
# warmup cache capacity
warmup_cache_len = 128
warmup_batch = len(test.input_ids_list)
test.model.reset_cache(
StaticKVCacheConfig(
max_batch_size=warmup_batch,
max_cache_len=warmup_cache_len,
)
)
        warmup_prompt_len = min(
64,
max(len(ids) for ids in test.input_ids_list)
)
warmup_ids = [
            ids[:warmup_prompt_len] if len(ids) >= warmup_prompt_len else ids
for ids in test.input_ids_list
]
input_ids_infini = infinicore.from_list(warmup_ids)
print("=================== warmup start ===================")
for _ in range(warmup_steps):
_ = test.model.generate(
input_ids_infini,
GenerationConfig(
max_new_tokens=5, # decode kernel warmup
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
),
_measure_and_log_time=False,
)
print("=================== warmup done ====================")
# reset cache back to benchmark config
if cache_config is not None:
test.model.reset_cache(cache_config)
# ---------------------------------------------------------------------------- #
# Warmup done
# ---------------------------------------------------------------------------- #
for idx, case in tqdm(cases_dict.items(), desc="Processing cases"):
tqdm.write(f"\033[92mProcessing : {case}\033[0m")
......@@ -366,7 +503,8 @@ if __name__ == "__main__":
input_len = case["input_len"]
output_len = case["output_len"]
# reset cache for each case
if not enable_paged_attn:
# reset cache if static kvcache is used
initial_capacity = input_len + output_len
test.model.reset_cache(
StaticKVCacheConfig(
......
import infinicore
import transformers
from transformers import AutoTokenizer
from tokenizers import decoders as _dec
from infinilm.modeling_utils import load_model_state_dict_by_file
......@@ -10,6 +11,7 @@ import time
import os
import numpy as np
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from packaging import version
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
......@@ -27,6 +29,11 @@ def get_args():
action="store_true",
help="Run nvidia test",
)
parser.add_argument(
"--qy",
action="store_true",
help="Run qy test",
)
parser.add_argument(
"--metax",
action="store_true",
......@@ -47,6 +54,11 @@ def get_args():
action="store_true",
help="Run cambricon test",
)
parser.add_argument(
"--ali",
action="store_true",
help="Run alippu test",
)
parser.add_argument(
"--hygon",
action="store_true",
......@@ -93,6 +105,11 @@ def get_args():
action="store_true",
help="use paged cache",
)
parser.add_argument(
"--enable-graph",
action="store_true",
help="enable graph compiling",
)
parser.add_argument(
"--top-k",
......@@ -125,6 +142,7 @@ def test(
infini_device=infinicore.device("cpu", 0),
tp=1,
enable_paged_attn=False,
enable_graph=False,
top_k=1,
top_p=1.0,
temperature=1.0,
......@@ -137,8 +155,8 @@ def test(
model_path,
device=infini_device,
distributed_config=DistConfig(tp),
enable_graph_compiling=enable_graph,
)
# ---------------------------------------------------------------------------- #
# Load Weights
# ---------------------------------------------------------------------------- #
......@@ -148,7 +166,6 @@ def test(
# create tokenizer
# ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if "llama" == model.config.model_type:
backend = getattr(tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend)
......@@ -182,9 +199,24 @@ def test(
for prompt in prompts
]
input_ids_list = tokenizer.batch_encode_plus(input_contents)[
"input_ids"
] # List: [[1, 1128, 526, 366, 29892]]
# input_ids_list = tokenizer.batch_encode_plus(input_contents)[
# "input_ids"
# ] # List: [[1, 1128, 526, 366, 29892]]
if version.parse(transformers.__version__) < version.parse("5.0.0"):
        # Ideally this would be handled by upgrading transformers, but that currently
        # breaks the version pairing between transformers and the MLU PyTorch build on
        # machines with Phytium CPUs, so a version branch is used temporarily.
input_ids_list = [
tokenizer.encode_plus(
text, truncation=True, max_length=2048, add_special_tokens=True
)["input_ids"]
for text in input_contents
]
else:
input_ids_list = [
tokenizer._encode_plus(
text, truncation=True, max_length=2048, add_special_tokens=True
)["input_ids"]
for text in input_contents
]
# ---------------------------------------------------------------------------- #
# Create KVCache
......@@ -193,7 +225,7 @@ def test(
        batch_size = 1 if isinstance(prompts, str) else len(prompts)
max_total_tokens = max_new_tokens + len(input_ids_list[0])
cache_config = PagedKVCacheConfig(
num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16
num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16
)
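        # Ceiling division: e.g. max_total_tokens = 32 now yields (32 + 15) // 16 = 2
        # blocks per sequence, where the previous 32 // 16 + 1 = 3 allocated one spare block.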
else:
        batch_size = 1 if isinstance(prompts, str) else len(prompts)
......@@ -242,6 +274,8 @@ if __name__ == "__main__":
device_str = "cpu"
elif args.nvidia:
device_str = "cuda"
elif args.qy:
device_str = "cuda"
elif args.metax:
device_str = "cuda"
elif args.moore:
......@@ -250,11 +284,13 @@ if __name__ == "__main__":
device_str = "cuda"
elif args.cambricon:
device_str = "mlu"
elif args.ali:
device_str = "cuda"
elif args.hygon:
device_str = "cuda"
else:
print(
"Usage: python examples/jiuge.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon | --hygon] --model_path=<path/to/model_dir>\n"
"Usage: python examples/jiuge.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali | --hygon] --model_path=<path/to/model_dir>\n"
"such as, python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
)
sys.exit(1)
......@@ -265,6 +301,7 @@ if __name__ == "__main__":
backend = args.backend
tp = args.tp
enable_paged_attn = args.enable_paged_attn
enable_graph = args.enable_graph
if backend != "cpp":
raise ValueError(f"Unsupported backend: {backend}.")
......@@ -277,6 +314,7 @@ if __name__ == "__main__":
infini_device=infini_device,
tp=tp,
enable_paged_attn=enable_paged_attn,
enable_graph=enable_graph,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
......