Commit 39790c1e authored by Your Name
Browse files

issue/111 - 9g7b分布式

parent 9c256a17
......@@ -9,12 +9,15 @@
#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <vector>
namespace infinilm::models::llama {
LlamaAttention::LlamaAttention(const LlamaConfig &config, const infinicore::Device &device,
LlamaAttention::LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
infinicore::DataType dtype)
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info)
: layer_idx_(layer_idx),
hidden_size_(config.hidden_size),
num_attention_heads_(config.num_attention_heads),
......@@ -23,17 +26,31 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, const infinicore::Devi
kv_dim_(config.kv_dim()),
use_bias_(config.attention_bias),
use_output_bias_(config.attention_output_bias),
max_position_embeddings_(config.max_position_embeddings) {
max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
int tp_rank = rank_info.tp_rank;
int tp_size = rank_info.tp_size;
int num_attention_heads = config.num_attention_heads;
int num_key_value_heads = config.num_key_value_heads;
if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) {
this->num_attention_heads_ = num_attention_heads / tp_size;
this->num_key_value_heads_ = num_key_value_heads / tp_size;
} else {
throw std::runtime_error("num_attention_heads / tp_size error.");
}
// Initialize projection layers
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, hidden_size_, use_bias_,
dtype, device);
dtype, device, tp_rank, tp_size);
INFINICORE_NN_MODULE_INIT(k_proj, hidden_size_, kv_dim_, use_bias_,
dtype, device);
dtype, device, tp_rank, tp_size);
INFINICORE_NN_MODULE_INIT(v_proj, hidden_size_, kv_dim_, use_bias_,
dtype, device);
dtype, device, tp_rank, tp_size);
// Output projection uses attention_output_bias (can be different from qkv)
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
dtype, device);
dtype, device, tp_rank, tp_size, rank_info.comm);
}
infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
......
#pragma once
#include "llama_config.hpp"
#include "cache/kv_cache.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/device.hpp"
#include "llama_config.hpp"
#include <algorithm>
#include <utility>
#include <memory>
#include <utility>
#include "../../engine/distributed/distributed.hpp"
namespace infinilm::models::llama {
......@@ -32,9 +34,11 @@ public:
* @param layer_idx Layer index for cache access
* @param dtype Optional data type for model parameters (defaults to F32)
*/
LlamaAttention(const LlamaConfig &config, const infinicore::Device &device,
LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
infinicore::DataType dtype = infinicore::DataType::F32);
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: compute attention
......@@ -66,10 +70,12 @@ public:
protected:
// Projection layers
INFINICORE_NN_MODULE(infinicore::nn::Linear, q_proj);
INFINICORE_NN_MODULE(infinicore::nn::Linear, k_proj);
INFINICORE_NN_MODULE(infinicore::nn::Linear, v_proj);
INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, q_proj);
INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, k_proj);
INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, v_proj);
INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj);
engine::distributed::RankInfo rank_info_;
// Shared Rotary Position Embeddings (RoPE)
std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
......
......@@ -4,10 +4,11 @@
namespace infinilm::models::llama {
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, const infinicore::Device &device,
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
infinicore::DataType dtype)
: layer_idx_(layer_idx) {
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx) , rank_info_(rank_info){
// Initialize layer normalization layers
INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps,
dtype, device);
......@@ -15,8 +16,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, const infinicore
dtype, device);
// Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, dtype);
INFINICORE_NN_MODULE_INIT(mlp, config, device, dtype);
INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, dtype, rank_info_);
INFINICORE_NN_MODULE_INIT(mlp, config, device, dtype, rank_info_);
}
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
......
#pragma once
#include "llama_config.hpp"
#include "llama_attention.hpp"
#include "llama_mlp.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/device.hpp"
#include "llama_attention.hpp"
#include "llama_config.hpp"
#include "llama_mlp.hpp"
#include "../../engine/distributed/distributed.hpp"
namespace infinilm::models::llama {
......@@ -31,9 +33,11 @@ public:
* @param layer_idx Layer index for cache management and debugging
* @param dtype Optional data type for model parameters (defaults to F32)
*/
LlamaDecoderLayer(const LlamaConfig &config, const infinicore::Device &device,
LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device,
size_t layer_idx,
infinicore::DataType dtype = infinicore::DataType::F32);
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: process one decoder layer
......@@ -58,7 +62,6 @@ public:
}
}
protected:
// Layer normalization
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
......@@ -67,6 +70,7 @@ protected:
// Attention and MLP
INFINICORE_NN_MODULE(LlamaAttention, self_attn);
INFINICORE_NN_MODULE(LlamaMLP, mlp);
engine::distributed::RankInfo rank_info_;
private:
size_t layer_idx_; // Layer index for cache management and debugging
......
#include "llama_for_causal_lm.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/context/context.hpp"
#include <iostream>
namespace infinilm::models::llama {
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype) {
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) {
// Initialize module's device_ member
device_ = device;
// Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, dtype);
INFINICORE_NN_MODULE_INIT(model, config, device, dtype, rank_info);
// Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens
......
......@@ -8,6 +8,8 @@
#include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include "../../engine/distributed/distributed.hpp"
namespace infinilm::models::llama {
/**
......@@ -27,8 +29,10 @@ public:
* @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to BF16)
*/
LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::BF16);
LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::BF16,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: compute language modeling logits
......
......@@ -4,18 +4,24 @@
namespace infinilm::models::llama {
LlamaMLP::LlamaMLP(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype)
LlamaMLP::LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info)
: hidden_size_(config.hidden_size),
intermediate_size_(config.intermediate_size),
use_bias_(config.mlp_bias) {
use_bias_(config.mlp_bias), rank_info_(rank_info) {
int tp_rank = rank_info.tp_rank;
int tp_size = rank_info.tp_size;
// Initialize projection layers
INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size_, intermediate_size_, use_bias_,
dtype, device);
dtype, device, tp_rank, tp_size);
INFINICORE_NN_MODULE_INIT(up_proj, hidden_size_, intermediate_size_, use_bias_,
dtype, device);
dtype, device, tp_rank, tp_size);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_,
dtype, device);
dtype, device, tp_rank, tp_size, rank_info.comm);
}
infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
......
#pragma once
#include "llama_config.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/device.hpp"
#include "llama_config.hpp"
#include "../../engine/distributed/distributed.hpp"
namespace infinilm::models::llama {
......@@ -28,8 +30,10 @@ public:
* @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32)
*/
LlamaMLP(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32);
LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: compute MLP output
......@@ -45,15 +49,16 @@ public:
protected:
// Projection layers
INFINICORE_NN_MODULE(infinicore::nn::Linear, gate_proj);
INFINICORE_NN_MODULE(infinicore::nn::Linear, up_proj);
INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, gate_proj);
INFINICORE_NN_MODULE(infinicore::nn::ColumnParallelLinear, up_proj);
INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, down_proj);
engine::distributed::RankInfo rank_info_;
private:
size_t hidden_size_;
size_t intermediate_size_;
bool use_bias_;
};
} // namespace infinilm::models::llama
......@@ -6,8 +6,10 @@
namespace infinilm::models::llama {
LlamaModel::LlamaModel(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype)
LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info)
: config_(config) {
// Initialize token embeddings
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
......@@ -20,7 +22,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config, const infinicore::Device &devi
layers_.reserve(config.num_hidden_layers);
for (size_t i = 0; i < config.num_hidden_layers; ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), config, device, i, dtype));
"layers." + std::to_string(i), config, device, i, dtype, rank_info));
}
// Initialize final layer normalization
......@@ -57,8 +59,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// First time: create cache
cache_ = std::make_unique<infinilm::cache::DynamicCache>(
config_.num_hidden_layers,
config_.max_position_embeddings
);
config_.max_position_embeddings);
}
cache_to_use = cache_.get();
}
......@@ -76,7 +77,6 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// Logging moved to decoder layer for post-attention normalization
}
// 3. Apply final layer normalization to last token only (aligns with transformers)
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
......
#pragma once
#include "llama_config.hpp"
#include "llama_decoder_layer.hpp"
#include "cache/kv_cache.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/device.hpp"
#include <vector>
#include "llama_config.hpp"
#include "llama_decoder_layer.hpp"
#include <memory>
#include <vector>
#include "../../engine/distributed/distributed.hpp"
namespace infinilm::models::llama {
......@@ -34,8 +36,10 @@ public:
* @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32)
*/
LlamaModel(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32);
LlamaModel(const LlamaConfig &config,
const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/**
* @brief Forward pass: process input through the model
......@@ -49,12 +53,10 @@ public:
const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const;
// Module information
/**
 * @brief Read-only access to the configuration held in config_
 * @return Const reference to the stored LlamaConfig
 */
const LlamaConfig &config() const {
    return config_;
}
/**
 * @brief Number of decoder layers, as recorded in the stored configuration
 * @return config_.num_hidden_layers
 */
size_t num_layers() const {
    return config_.num_hidden_layers;
}
/**
* @brief Reset the internal cache to a specific position
* This should be called when starting a new generation sequence to prevent state
......
......@@ -6,7 +6,7 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(const std::any
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device);
return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
......
......@@ -65,12 +65,9 @@ inline void bind_llama(py::module &m) {
.def_readwrite("pretraining_tp", &LlamaConfig::pretraining_tp)
.def_readwrite("name_or_path", &LlamaConfig::name_or_path)
.def_readwrite("pad_token_id", &LlamaConfig::pad_token_id)
.def_property("bos_token_id",
[](const LlamaConfig &self) {
.def_property("bos_token_id", [](const LlamaConfig &self) {
// Always return as list to match Python config format
return py::cast(self.bos_token_id);
},
[](LlamaConfig &self, py::object value) {
return py::cast(self.bos_token_id); }, [](LlamaConfig &self, py::object value) {
// Accept both single int and list
if (py::isinstance<py::int_>(value)) {
self.bos_token_id = {value.cast<int64_t>()};
......@@ -78,14 +75,10 @@ inline void bind_llama(py::module &m) {
self.bos_token_id = value.cast<std::vector<int64_t>>();
} else {
throw py::type_error("bos_token_id must be int or list of ints");
}
})
.def_property("eos_token_id",
[](const LlamaConfig &self) {
} })
.def_property("eos_token_id", [](const LlamaConfig &self) {
// Always return as list to match Python config format
return py::cast(self.eos_token_id);
},
[](LlamaConfig &self, py::object value) {
return py::cast(self.eos_token_id); }, [](LlamaConfig &self, py::object value) {
// Accept both single int and list
if (py::isinstance<py::int_>(value)) {
self.eos_token_id = {value.cast<int64_t>()};
......@@ -93,8 +86,7 @@ inline void bind_llama(py::module &m) {
self.eos_token_id = value.cast<std::vector<int64_t>>();
} else {
throw py::type_error("eos_token_id must be int or list of ints");
}
})
} })
.def("validate", &LlamaConfig::validate)
.def("kv_dim", &LlamaConfig::kv_dim)
// Add __dir__ to make attributes discoverable via dir() in Python
......@@ -126,26 +118,12 @@ inline void bind_llama(py::module &m) {
dir_list.append("eos_token_id");
dir_list.append("validate");
dir_list.append("kv_dim");
return dir_list;
});
return dir_list; });
// Note: Device is already bound in InfiniCore bindings, so we don't need to bind it here
// Bind LlamaForCausalLM
py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM")
.def(py::init([](const LlamaConfig &config, const Device &device, py::object dtype_obj) {
infinicore::DataType dtype = infinicore::DataType::F32;
if (!dtype_obj.is_none()) {
// Extract dtype from Python object
if (py::hasattr(dtype_obj, "_underlying")) {
dtype = dtype_obj.attr("_underlying").cast<infinicore::DataType>();
} else {
dtype = dtype_obj.cast<infinicore::DataType>();
}
}
return std::make_shared<LlamaForCausalLM>(config, device, dtype);
}),
py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
.def("state_dict", [](const LlamaForCausalLM &model) {
// Return a dictionary containing references to the whole state of the module.
auto state_dict = model.state_dict();
......@@ -182,11 +160,9 @@ inline void bind_llama(py::module &m) {
}
model.load_state_dict(cpp_state_dict); }, py::arg("state_dict"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def(
"reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) {
.def("reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) {
// Reset the internal cache to prevent state from persisting between generations
model.model().reset_cache(pos);
}, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)")
model.model().reset_cache(pos); }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)")
.def("forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) {
// Helper to extract C++ tensor from Python InfiniCore tensor
auto get_tensor = [](py::object obj) -> infinicore::Tensor {
......@@ -219,8 +195,7 @@ inline void bind_llama(py::module &m) {
}
}
return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr);
}, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr); }, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
}
} // namespace infinilm::models::llama
......@@ -55,7 +55,7 @@ def get_args():
parser.add_argument(
"--backend",
type=str,
default="python",
default="cpp",
help="python or cpp model",
)
parser.add_argument(
......@@ -79,7 +79,7 @@ def get_args():
parser.add_argument(
"--tp",
type=int,
default=None,
default=1,
help="total rank for tensor parallel",
)
......@@ -93,6 +93,7 @@ def test(
infini_dtype=infinicore.bfloat16,
infini_device=infinicore.device("cpu", 0),
backend="python",
tp=1,
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
......@@ -103,7 +104,7 @@ def test(
device=infini_device,
dtype=infini_dtype,
backend=backend,
distributed_config=DistConfig(args.tp),
distributed_config=DistConfig(tp),
)
# ---------------------------------------------------------------------------- #
......@@ -134,7 +135,6 @@ def test(
]
)
# ---------------------------------------------------------------------------- #
# token编码
# ---------------------------------------------------------------------------- #
......@@ -201,6 +201,7 @@ if __name__ == "__main__":
model_path = args.model_path
max_new_tokens = args.max_new_tokens
backend = args.backend
tp = args.tp
infini_device = infinicore.device(device_str, 0)
if args.dtype == "float32":
......@@ -219,4 +220,5 @@ if __name__ == "__main__":
infini_device=infini_device,
infini_dtype=infini_dtype,
backend=backend,
tp=tp,
)
......@@ -124,12 +124,13 @@ class GenerationMixin:
# This prevents state from persisting between different questions/prompts
# -------------------------------------------------------------------- #
# Check if this is a cpp backend model (has _model attribute with reset_cache method)
if hasattr(self, '_model') and hasattr(self._model, 'reset_cache'):
if hasattr(self, "_model") and hasattr(self._model, "reset_cache"):
try:
self._model.reset_cache()
except Exception as e:
# If reset_cache fails, log but continue (shouldn't happen)
import warnings
warnings.warn(f"Failed to reset cache: {e}")
# -------------------------------------------------------------------- #
......@@ -210,6 +211,7 @@ class GenerationMixin:
start_time = time.time()
logits = self(**model_inputs)
infinicore.sync_device()
# -------------------------------------------------------------------------- #
# 处理输出
......
......@@ -68,7 +68,9 @@ class LlamaConfig:
pass
# Handle num_key_value_heads with validation
python_num_kv_heads = getattr(self._python_config, "num_key_value_heads", None)
python_num_kv_heads = getattr(
self._python_config, "num_key_value_heads", None
)
if python_num_kv_heads is None or python_num_kv_heads == 0:
self._cpp_config.num_key_value_heads = (
self._cpp_config.num_attention_heads
......@@ -80,8 +82,14 @@ class LlamaConfig:
python_head_dim = getattr(self._python_config, "head_dim", None)
if python_head_dim is None or python_head_dim == 0:
# Compute from hidden_size and num_attention_heads
if self._cpp_config.hidden_size > 0 and self._cpp_config.num_attention_heads > 0:
computed_head_dim = self._cpp_config.hidden_size // self._cpp_config.num_attention_heads
if (
self._cpp_config.hidden_size > 0
and self._cpp_config.num_attention_heads > 0
):
computed_head_dim = (
self._cpp_config.hidden_size
// self._cpp_config.num_attention_heads
)
self._cpp_config.head_dim = computed_head_dim
else:
raise ValueError(
......@@ -92,10 +100,17 @@ class LlamaConfig:
# Use from Python config
self._cpp_config.head_dim = python_head_dim
# Validate it matches expected value (warn but allow for flexibility)
if self._cpp_config.hidden_size > 0 and self._cpp_config.num_attention_heads > 0:
expected_head_dim = self._cpp_config.hidden_size // self._cpp_config.num_attention_heads
if (
self._cpp_config.hidden_size > 0
and self._cpp_config.num_attention_heads > 0
):
expected_head_dim = (
self._cpp_config.hidden_size
// self._cpp_config.num_attention_heads
)
if self._cpp_config.head_dim != expected_head_dim:
import warnings
warnings.warn(
f"head_dim ({self._cpp_config.head_dim}) != hidden_size/num_attention_heads ({expected_head_dim}). "
f"Using head_dim from config."
......@@ -107,10 +122,13 @@ class LlamaConfig:
# Validate config after setting all values (especially important for jiuge models)
if not self._cpp_config.validate():
raise ValueError("C++ LlamaConfig validation failed. Check config values.")
raise ValueError(
"C++ LlamaConfig validation failed. Check config values."
)
# Log key config values for debugging (especially useful for jiuge models)
import logging
logger = logging.getLogger(__name__)
logger.info(
f"LlamaConfig ({self._python_config.model_type}) C++ LlamaConfig created: vocab_size={self._cpp_config.vocab_size}, "
......@@ -189,6 +207,9 @@ class LlamaForCausalLM(GenerationMixin):
for name, param in state_dict.items():
self._model.load_param(name, param._underlying)
def load_param(self, name: str, weight: infinicore.Tensor):
    """Forward one named weight to the wrapped model object.

    Args:
        name: Parameter name, passed through unchanged.
        weight: Tensor wrapper; its ``_underlying`` attribute is what is
            handed to ``self._model.load_param``.
    """
    # Resolve the bound loader first so attribute lookups happen in the
    # same order as a direct call expression would perform them.
    loader = self._model.load_param
    loader(name, weight._underlying)
def get_parameter(self, name):
"""
Get a parameter tensor by name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment