"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "1ec9268dfd09dbefa24f4a049f04b79ebdac1905"
Commit 39790c1e authored by Your Name's avatar Your Name
Browse files

issue/111 - 9g7b distributed (tensor parallelism)

parent 9c256a17
@@ -9,12 +9,15 @@
 #include <iostream>
 #include <spdlog/spdlog.h>
 #include <stdexcept>
+#include <vector>
 
 namespace infinilm::models::llama {
 
-LlamaAttention::LlamaAttention(const LlamaConfig &config, const infinicore::Device &device,
+LlamaAttention::LlamaAttention(const LlamaConfig &config,
+                               const infinicore::Device &device,
                                size_t layer_idx,
-                               infinicore::DataType dtype)
+                               infinicore::DataType dtype,
+                               engine::distributed::RankInfo rank_info)
     : layer_idx_(layer_idx),
       hidden_size_(config.hidden_size),
       num_attention_heads_(config.num_attention_heads),
@@ -23,17 +26,31 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, const infinicore::Devi
       kv_dim_(config.kv_dim()),
       use_bias_(config.attention_bias),
       use_output_bias_(config.attention_output_bias),
-      max_position_embeddings_(config.max_position_embeddings) {
+      max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
+    int tp_rank = rank_info.tp_rank;
+    int tp_size = rank_info.tp_size;
+    int num_attention_heads = config.num_attention_heads;
+    int num_key_value_heads = config.num_key_value_heads;
+    if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
+        this->num_attention_heads_ = num_attention_heads / tp_size;
+        this->num_key_value_heads_ = num_key_value_heads / tp_size;
+    } else {
+        throw std::runtime_error("num_attention_heads / tp_size error.");
+    }
+
     // Initialize projection layers
     INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, hidden_size_, use_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size);
     INFINICORE_NN_MODULE_INIT(k_proj, hidden_size_, kv_dim_, use_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size);
     INFINICORE_NN_MODULE_INIT(v_proj, hidden_size_, kv_dim_, use_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size);
 
     // Output projection uses attention_output_bias (can be different from qkv)
     INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size, rank_info.comm);
 }
 
 infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
......
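The constructor change above divides the query and key/value heads evenly across tensor-parallel ranks before the projection layers are sized. A minimal sketch of that partitioning arithmetic (illustrative only, not part of the repository; the standalone helper name is made up):

```python
# Sketch of the per-rank head split performed in the LlamaAttention constructor above.
def split_heads(num_attention_heads: int, num_key_value_heads: int, tp_size: int):
    # Mirrors the guard in the C++ code: KV heads must be a positive multiple
    # of tp_size, otherwise construction throws.
    if num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0:
        raise ValueError("num_key_value_heads must be divisible by tp_size")
    return num_attention_heads // tp_size, num_key_value_heads // tp_size

# Example: 32 query heads and 8 KV heads on 2 ranks -> 16 and 4 heads per rank.
print(split_heads(32, 8, 2))  # (16, 4)
```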
 #pragma once
 
-#include "llama_config.hpp"
 #include "cache/kv_cache.hpp"
-#include "infinicore/nn/module.hpp"
+#include "infinicore/device.hpp"
 #include "infinicore/nn/linear.hpp"
+#include "infinicore/nn/module.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/tensor.hpp"
-#include "infinicore/device.hpp"
+#include "llama_config.hpp"
 #include <algorithm>
-#include <utility>
 #include <memory>
+#include <utility>
+
+#include "../../engine/distributed/distributed.hpp"
 
 namespace infinilm::models::llama {
@@ -32,9 +34,11 @@ public:
      * @param layer_idx Layer index for cache access
      * @param dtype Optional data type for model parameters (defaults to F32)
      */
-    LlamaAttention(const LlamaConfig &config, const infinicore::Device &device,
+    LlamaAttention(const LlamaConfig &config,
+                   const infinicore::Device &device,
                    size_t layer_idx,
-                   infinicore::DataType dtype = infinicore::DataType::F32);
+                   infinicore::DataType dtype = infinicore::DataType::F32,
+                   engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
 
     /**
      * @brief Forward pass: compute attention
@@ -66,10 +70,12 @@ public:
 protected:
     // Projection layers
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, q_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, k_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, v_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, q_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, k_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, v_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
+
+    engine::distributed::RankInfo rank_info_;
 
     // Shared Rotary Position Embeddings (RoPE)
     std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
......
@@ -4,10 +4,11 @@
 namespace infinilm::models::llama {
 
-LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, const infinicore::Device &device,
+LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
+                                     const infinicore::Device &device,
                                      size_t layer_idx,
-                                     infinicore::DataType dtype)
-    : layer_idx_(layer_idx) {
+                                     infinicore::DataType dtype,
+                                     engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx), rank_info_(rank_info) {
     // Initialize layer normalization layers
     INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps,
                               dtype, device);
@@ -15,8 +16,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, const infinicore
                               dtype, device);
 
     // Initialize attention and MLP modules
-    INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, dtype);
-    INFINICORE_NN_MODULE_INIT(mlp, config, device, dtype);
+    INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, dtype, rank_info_);
+    INFINICORE_NN_MODULE_INIT(mlp, config, device, dtype, rank_info_);
 }
 
 infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
......
 #pragma once
 
-#include "llama_config.hpp"
-#include "llama_attention.hpp"
-#include "llama_mlp.hpp"
+#include "infinicore/device.hpp"
 #include "infinicore/nn/module.hpp"
 #include "infinicore/nn/rmsnorm.hpp"
 #include "infinicore/tensor.hpp"
-#include "infinicore/device.hpp"
+#include "llama_attention.hpp"
+#include "llama_config.hpp"
+#include "llama_mlp.hpp"
+
+#include "../../engine/distributed/distributed.hpp"
 
 namespace infinilm::models::llama {
@@ -31,9 +33,11 @@ public:
      * @param layer_idx Layer index for cache management and debugging
      * @param dtype Optional data type for model parameters (defaults to F32)
      */
-    LlamaDecoderLayer(const LlamaConfig &config, const infinicore::Device &device,
+    LlamaDecoderLayer(const LlamaConfig &config,
+                      const infinicore::Device &device,
                       size_t layer_idx,
-                      infinicore::DataType dtype = infinicore::DataType::F32);
+                      infinicore::DataType dtype = infinicore::DataType::F32,
+                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
 
     /**
      * @brief Forward pass: process one decoder layer
@@ -58,7 +62,6 @@ public:
         }
     }
-
 protected:
     // Layer normalization
     INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
@@ -67,6 +70,7 @@ protected:
     // Attention and MLP
     INFINICORE_NN_MODULE(LlamaAttention, self_attn);
     INFINICORE_NN_MODULE(LlamaMLP, mlp);
+    engine::distributed::RankInfo rank_info_;
 
 private:
     size_t layer_idx_; // Layer index for cache management and debugging
......
#include "llama_for_causal_lm.hpp" #include "llama_for_causal_lm.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/nn/linear.hpp" #include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp" #include "infinicore/ops.hpp"
#include "infinicore/context/context.hpp"
#include <iostream> #include <iostream>
namespace infinilm::models::llama { namespace infinilm::models::llama {
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device, LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
infinicore::DataType dtype) { const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) {
// Initialize module's device_ member // Initialize module's device_ member
device_ = device; device_ = device;
// Initialize base model // Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, dtype); INFINICORE_NN_MODULE_INIT(model, config, device, dtype, rank_info);
// Initialize language modeling head // Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
......
@@ -8,6 +8,8 @@
 #include "infinicore/nn/module.hpp"
 #include "infinicore/tensor.hpp"
 
+#include "../../engine/distributed/distributed.hpp"
+
 namespace infinilm::models::llama {
 
 /**
@@ -27,8 +29,10 @@ public:
      * @param device Device to create tensors on
      * @param dtype Optional data type for model parameters (defaults to BF16)
      */
-    LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device,
-                     infinicore::DataType dtype = infinicore::DataType::BF16);
+    LlamaForCausalLM(const LlamaConfig &config,
+                     const infinicore::Device &device,
+                     infinicore::DataType dtype = infinicore::DataType::BF16,
+                     engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
 
     /**
      * @brief Forward pass: compute language modeling logits
......
@@ -4,18 +4,24 @@
 namespace infinilm::models::llama {
 
-LlamaMLP::LlamaMLP(const LlamaConfig &config, const infinicore::Device &device,
-                   infinicore::DataType dtype)
+LlamaMLP::LlamaMLP(const LlamaConfig &config,
+                   const infinicore::Device &device,
+                   infinicore::DataType dtype,
+                   engine::distributed::RankInfo rank_info)
     : hidden_size_(config.hidden_size),
       intermediate_size_(config.intermediate_size),
-      use_bias_(config.mlp_bias) {
+      use_bias_(config.mlp_bias), rank_info_(rank_info) {
+    int tp_rank = rank_info.tp_rank;
+    int tp_size = rank_info.tp_size;
+
     // Initialize projection layers
     INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size_, intermediate_size_, use_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size);
     INFINICORE_NN_MODULE_INIT(up_proj, hidden_size_, intermediate_size_, use_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size);
     INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_,
-                              dtype, device);
+                              dtype, device, tp_rank, tp_size, rank_info.comm);
 }
 
 infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
......
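With gate_proj and up_proj switched to ColumnParallelLinear and down_proj to RowParallelLinear, each rank keeps only a slice of intermediate_size and the down projection's partial outputs are summed across ranks through rank_info.comm. A NumPy sketch of why that sharding reproduces the unsharded MLP (this models the math only, not the InfiniCore API, and assumes a SiLU gate as in standard Llama MLPs):

```python
# Illustrative sketch: column-parallel gate/up plus row-parallel down with an
# all-reduce (modeled as a plain sum) equals the unsharded MLP output.
import numpy as np

rng = np.random.default_rng(0)
hidden, inter, tp = 8, 16, 2
x = rng.standard_normal((1, hidden))
w_gate = rng.standard_normal((hidden, inter))
w_up = rng.standard_normal((hidden, inter))
w_down = rng.standard_normal((inter, hidden))
silu = lambda t: t / (1.0 + np.exp(-t))

# Reference: full (non-parallel) computation.
ref = (silu(x @ w_gate) * (x @ w_up)) @ w_down

# Tensor-parallel computation: each rank owns a column slice of gate/up and the
# matching row slice of down, then partial results are summed across ranks.
partials = []
for rank in range(tp):
    cols = slice(rank * inter // tp, (rank + 1) * inter // tp)
    h_rank = silu(x @ w_gate[:, cols]) * (x @ w_up[:, cols])  # column-parallel slice
    partials.append(h_rank @ w_down[cols, :])                 # row-parallel partial
out = sum(partials)                                           # all-reduce(sum)

assert np.allclose(ref, out)
```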
 #pragma once
 
-#include "llama_config.hpp"
-#include "infinicore/nn/module.hpp"
+#include "infinicore/device.hpp"
 #include "infinicore/nn/linear.hpp"
+#include "infinicore/nn/module.hpp"
 #include "infinicore/tensor.hpp"
-#include "infinicore/device.hpp"
+#include "llama_config.hpp"
+
+#include "../../engine/distributed/distributed.hpp"
 
 namespace infinilm::models::llama {
@@ -28,8 +30,10 @@ public:
      * @param device Device to create tensors on
      * @param dtype Optional data type for model parameters (defaults to F32)
      */
-    LlamaMLP(const LlamaConfig &config, const infinicore::Device &device,
-             infinicore::DataType dtype = infinicore::DataType::F32);
+    LlamaMLP(const LlamaConfig &config,
+             const infinicore::Device &device,
+             infinicore::DataType dtype = infinicore::DataType::F32,
+             engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
 
     /**
      * @brief Forward pass: compute MLP output
@@ -45,15 +49,16 @@ public:
 protected:
     // Projection layers
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, gate_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, up_proj);
-    INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, gate_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, up_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, down_proj);
+    engine::distributed::RankInfo rank_info_;
 
 private:
     size_t hidden_size_;
     size_t intermediate_size_;
     bool use_bias_;
 };
 
 } // namespace infinilm::models::llama
@@ -6,8 +6,10 @@
 namespace infinilm::models::llama {
 
-LlamaModel::LlamaModel(const LlamaConfig &config, const infinicore::Device &device,
-                       infinicore::DataType dtype)
+LlamaModel::LlamaModel(const LlamaConfig &config,
+                       const infinicore::Device &device,
+                       infinicore::DataType dtype,
+                       engine::distributed::RankInfo rank_info)
     : config_(config) {
     // Initialize token embeddings
     INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
@@ -20,7 +22,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config, const infinicore::Device &devi
     layers_.reserve(config.num_hidden_layers);
     for (size_t i = 0; i < config.num_hidden_layers; ++i) {
         layers_.push_back(this->register_module<LlamaDecoderLayer>(
-            "layers." + std::to_string(i), config, device, i, dtype));
+            "layers." + std::to_string(i), config, device, i, dtype, rank_info));
     }
 
     // Initialize final layer normalization
@@ -57,8 +59,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
             // First time: create cache
             cache_ = std::make_unique<infinilm::cache::DynamicCache>(
                 config_.num_hidden_layers,
-                config_.max_position_embeddings
-            );
+                config_.max_position_embeddings);
         }
         cache_to_use = cache_.get();
     }
@@ -76,7 +77,6 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
             // Logging moved to decoder layer for post-attention normalization
         }
     }
-
     // 3. Apply final layer normalization to last token only (aligns with transformers)
     // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
......
 #pragma once
 
-#include "llama_config.hpp"
-#include "llama_decoder_layer.hpp"
 #include "cache/kv_cache.hpp"
-#include "infinicore/nn/module.hpp"
+#include "infinicore/device.hpp"
 #include "infinicore/nn/embedding.hpp"
+#include "infinicore/nn/module.hpp"
 #include "infinicore/nn/rmsnorm.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/tensor.hpp"
-#include "infinicore/device.hpp"
-#include <vector>
+#include "llama_config.hpp"
+#include "llama_decoder_layer.hpp"
 #include <memory>
+#include <vector>
+
+#include "../../engine/distributed/distributed.hpp"
 
 namespace infinilm::models::llama {
@@ -34,8 +36,10 @@ public:
      * @param device Device to create tensors on
      * @param dtype Optional data type for model parameters (defaults to F32)
      */
-    LlamaModel(const LlamaConfig &config, const infinicore::Device &device,
-               infinicore::DataType dtype = infinicore::DataType::F32);
+    LlamaModel(const LlamaConfig &config,
+               const infinicore::Device &device,
+               infinicore::DataType dtype = infinicore::DataType::F32,
+               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
 
     /**
      * @brief Forward pass: process input through the model
@@ -49,12 +53,10 @@ public:
                               const infinicore::Tensor &position_ids,
                               void *kv_cache = nullptr) const;
 
-
     // Module information
     const LlamaConfig &config() const { return config_; }
     size_t num_layers() const { return config_.num_hidden_layers; }
-
     /**
      * @brief Reset the internal cache to a specific position
      * This should be called when starting a new generation sequence to prevent state
......
@@ -6,7 +6,7 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(const std::any
     if (config.type() == typeid(models::llama::LlamaConfig)) {
         const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
-        return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device);
+        return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
     } else {
         throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
     }
......
@@ -65,12 +65,9 @@ inline void bind_llama(py::module &m) {
         .def_readwrite("pretraining_tp", &LlamaConfig::pretraining_tp)
         .def_readwrite("name_or_path", &LlamaConfig::name_or_path)
         .def_readwrite("pad_token_id", &LlamaConfig::pad_token_id)
-        .def_property("bos_token_id",
-                      [](const LlamaConfig &self) {
+        .def_property("bos_token_id", [](const LlamaConfig &self) {
             // Always return as list to match Python config format
-                          return py::cast(self.bos_token_id);
-                      },
-                      [](LlamaConfig &self, py::object value) {
+            return py::cast(self.bos_token_id); }, [](LlamaConfig &self, py::object value) {
             // Accept both single int and list
             if (py::isinstance<py::int_>(value)) {
                 self.bos_token_id = {value.cast<int64_t>()};
@@ -78,14 +75,10 @@ inline void bind_llama(py::module &m) {
                 self.bos_token_id = value.cast<std::vector<int64_t>>();
             } else {
                 throw py::type_error("bos_token_id must be int or list of ints");
-            }
-        })
-        .def_property("eos_token_id",
-                      [](const LlamaConfig &self) {
+            } })
+        .def_property("eos_token_id", [](const LlamaConfig &self) {
             // Always return as list to match Python config format
-                          return py::cast(self.eos_token_id);
-                      },
-                      [](LlamaConfig &self, py::object value) {
+            return py::cast(self.eos_token_id); }, [](LlamaConfig &self, py::object value) {
             // Accept both single int and list
             if (py::isinstance<py::int_>(value)) {
                 self.eos_token_id = {value.cast<int64_t>()};
@@ -93,8 +86,7 @@ inline void bind_llama(py::module &m) {
                 self.eos_token_id = value.cast<std::vector<int64_t>>();
             } else {
                 throw py::type_error("eos_token_id must be int or list of ints");
-            }
-        })
+            } })
         .def("validate", &LlamaConfig::validate)
         .def("kv_dim", &LlamaConfig::kv_dim)
         // Add __dir__ to make attributes discoverable via dir() in Python
@@ -126,26 +118,12 @@ inline void bind_llama(py::module &m) {
             dir_list.append("eos_token_id");
             dir_list.append("validate");
             dir_list.append("kv_dim");
-            return dir_list;
-        });
+            return dir_list; });
 
     // Note: Device is already bound in InfiniCore bindings, so we don't need to bind it here
 
     // Bind LlamaForCausalLM
     py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM")
-        .def(py::init([](const LlamaConfig &config, const Device &device, py::object dtype_obj) {
-                 infinicore::DataType dtype = infinicore::DataType::F32;
-                 if (!dtype_obj.is_none()) {
-                     // Extract dtype from Python object
-                     if (py::hasattr(dtype_obj, "_underlying")) {
-                         dtype = dtype_obj.attr("_underlying").cast<infinicore::DataType>();
-                     } else {
-                         dtype = dtype_obj.cast<infinicore::DataType>();
-                     }
-                 }
-                 return std::make_shared<LlamaForCausalLM>(config, device, dtype);
-             }),
-             py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
         .def("state_dict", [](const LlamaForCausalLM &model) {
             // Return a dictionary containing references to the whole state of the module.
             auto state_dict = model.state_dict();
@@ -182,11 +160,9 @@ inline void bind_llama(py::module &m) {
             }
             model.load_state_dict(cpp_state_dict); }, py::arg("state_dict"))
         .def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
-        .def(
-            "reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) {
+        .def("reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) {
             // Reset the internal cache to prevent state from persisting between generations
-            model.model().reset_cache(pos);
-        }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)")
+            model.model().reset_cache(pos); }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)")
         .def("forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) {
             // Helper to extract C++ tensor from Python InfiniCore tensor
             auto get_tensor = [](py::object obj) -> infinicore::Tensor {
@@ -219,8 +195,7 @@ inline void bind_llama(py::module &m) {
                 }
             }
 
-            return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr);
-        }, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
+            return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr); }, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
 }
 
 } // namespace infinilm::models::llama
@@ -55,7 +55,7 @@ def get_args():
     parser.add_argument(
         "--backend",
         type=str,
-        default="python",
+        default="cpp",
         help="python or cpp model",
     )
     parser.add_argument(
@@ -79,7 +79,7 @@ def get_args():
     parser.add_argument(
         "--tp",
        type=int,
-        default=None,
+        default=1,
         help="total rank for tensor parallel",
     )
@@ -93,6 +93,7 @@ def test(
     infini_dtype=infinicore.bfloat16,
     infini_device=infinicore.device("cpu", 0),
     backend="python",
+    tp=1,
 ):
     model_path = os.path.expanduser(model_path)
     # ---------------------------------------------------------------------------- #
@@ -103,7 +104,7 @@ def test(
         device=infini_device,
         dtype=infini_dtype,
         backend=backend,
-        distributed_config=DistConfig(args.tp),
+        distributed_config=DistConfig(tp),
     )
     # ---------------------------------------------------------------------------- #
@@ -134,7 +135,6 @@ def test(
         ]
     )
-
     # ---------------------------------------------------------------------------- #
     # Token encoding
     # ---------------------------------------------------------------------------- #
@@ -201,6 +201,7 @@ if __name__ == "__main__":
     model_path = args.model_path
     max_new_tokens = args.max_new_tokens
     backend = args.backend
+    tp = args.tp
 
     infini_device = infinicore.device(device_str, 0)
     if args.dtype == "float32":
@@ -219,4 +220,5 @@ if __name__ == "__main__":
         infini_device=infini_device,
         infini_dtype=infini_dtype,
         backend=backend,
+        tp=tp,
     )
@@ -124,12 +124,13 @@ class GenerationMixin:
         # This prevents state from persisting between different questions/prompts
         # -------------------------------------------------------------------- #
         # Check if this is a cpp backend model (has _model attribute with reset_cache method)
-        if hasattr(self, '_model') and hasattr(self._model, 'reset_cache'):
+        if hasattr(self, "_model") and hasattr(self._model, "reset_cache"):
             try:
                 self._model.reset_cache()
             except Exception as e:
                 # If reset_cache fails, log but continue (shouldn't happen)
                 import warnings
+
                 warnings.warn(f"Failed to reset cache: {e}")
         # -------------------------------------------------------------------- #
@@ -210,6 +211,7 @@ class GenerationMixin:
         start_time = time.time()
         logits = self(**model_inputs)
+        infinicore.sync_device()
 
         # -------------------------------------------------------------------------- #
         # Process output
......
@@ -68,7 +68,9 @@ class LlamaConfig:
                 pass
 
         # Handle num_key_value_heads with validation
-        python_num_kv_heads = getattr(self._python_config, "num_key_value_heads", None)
+        python_num_kv_heads = getattr(
+            self._python_config, "num_key_value_heads", None
+        )
         if python_num_kv_heads is None or python_num_kv_heads == 0:
             self._cpp_config.num_key_value_heads = (
                 self._cpp_config.num_attention_heads
@@ -80,8 +82,14 @@ class LlamaConfig:
         python_head_dim = getattr(self._python_config, "head_dim", None)
         if python_head_dim is None or python_head_dim == 0:
             # Compute from hidden_size and num_attention_heads
-            if self._cpp_config.hidden_size > 0 and self._cpp_config.num_attention_heads > 0:
-                computed_head_dim = self._cpp_config.hidden_size // self._cpp_config.num_attention_heads
+            if (
+                self._cpp_config.hidden_size > 0
+                and self._cpp_config.num_attention_heads > 0
+            ):
+                computed_head_dim = (
+                    self._cpp_config.hidden_size
+                    // self._cpp_config.num_attention_heads
+                )
                 self._cpp_config.head_dim = computed_head_dim
             else:
                 raise ValueError(
@@ -92,10 +100,17 @@ class LlamaConfig:
             # Use from Python config
             self._cpp_config.head_dim = python_head_dim
             # Validate it matches expected value (warn but allow for flexibility)
-            if self._cpp_config.hidden_size > 0 and self._cpp_config.num_attention_heads > 0:
-                expected_head_dim = self._cpp_config.hidden_size // self._cpp_config.num_attention_heads
+            if (
+                self._cpp_config.hidden_size > 0
+                and self._cpp_config.num_attention_heads > 0
+            ):
+                expected_head_dim = (
+                    self._cpp_config.hidden_size
+                    // self._cpp_config.num_attention_heads
+                )
                 if self._cpp_config.head_dim != expected_head_dim:
                     import warnings
+
                     warnings.warn(
                         f"head_dim ({self._cpp_config.head_dim}) != hidden_size/num_attention_heads ({expected_head_dim}). "
                         f"Using head_dim from config."
@@ -107,10 +122,13 @@ class LlamaConfig:
         # Validate config after setting all values (especially important for jiuge models)
         if not self._cpp_config.validate():
-            raise ValueError("C++ LlamaConfig validation failed. Check config values.")
+            raise ValueError(
+                "C++ LlamaConfig validation failed. Check config values."
+            )
 
         # Log key config values for debugging (especially useful for jiuge models)
         import logging
+
         logger = logging.getLogger(__name__)
         logger.info(
             f"LlamaConfig ({self._python_config.model_type}) C++ LlamaConfig created: vocab_size={self._cpp_config.vocab_size}, "
@@ -189,6 +207,9 @@ class LlamaForCausalLM(GenerationMixin):
         for name, param in state_dict.items():
             self._model.load_param(name, param._underlying)
 
+    def load_param(self, name: str, weight: infinicore.Tensor):
+        self._model.load_param(name, weight._underlying)
+
     def get_parameter(self, name):
         """
         Get a parameter tensor by name
......
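The Python wrapper's new load_param method forwards a single named weight to the C++ model, which is convenient when each tensor-parallel rank only materializes its own weight shards rather than a full state dict. A hypothetical usage sketch (`model` and `named_weights` are placeholders for illustration, not APIs defined in this commit):

```python
# `model` is assumed to be an infinilm LlamaForCausalLM (Python wrapper);
# `named_weights()` is a placeholder iterator yielding (name, infinicore.Tensor)
# pairs for the shards belonging to the current rank.
for name, weight in named_weights():
    model.load_param(name, weight)  # delegates to self._model.load_param(name, weight._underlying)
```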